Skip to content

Commit be25c6d

Browse files
committed
Merge branch 'master' into ci/godriver-3659-await-min-pool-size-in-ust
2 parents c2611f2 + d5ac9a4 commit be25c6d

File tree

12 files changed

+441
-116
lines changed

12 files changed

+441
-116
lines changed

.evergreen/config.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ functions:
198198
params:
199199
binary: bash
200200
env:
201-
GO_BUILD_TAGS: cse
201+
GO_BUILD_TAGS: "cse,mongointernal"
202202
include_expansions_in_env: ["TOPOLOGY", "AUTH", "SSL", "SKIP_CSOT_TESTS", "MONGODB_URI", "CRYPT_SHARED_LIB_PATH", "SKIP_CRYPT_SHARED_LIB", "RACE", "MONGO_GO_DRIVER_COMPRESSOR", "REQUIRE_API_VERSION", "LOAD_BALANCER"]
203203
args: [*task-runner, setup-test]
204204
- command: subprocess.exec

internal/integration/cursor_test.go

Lines changed: 157 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -349,116 +349,192 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
349349
return got
350350
}
351351

352-
func TestCursor_tailableAwaitData(t *testing.T) {
353-
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
352+
func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
353+
mt.Helper()
354354

355-
cappedOpts := options.CreateCollection().SetCapped(true).
356-
SetSizeInBytes(1024 * 64)
355+
initCollection(mt, &coll)
356+
cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}},
357+
options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait))
358+
require.NoError(mt, err, "Find error: %v", err)
357359

358-
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
359-
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
360-
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
361-
CollectionCreateOptions(cappedOpts)
360+
return cur
361+
}
362362

363-
mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
364-
initCollection(mt, mt.Coll)
363+
func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
364+
mt.Helper()
365365

366-
// Create a 30ms failpoint for getMore.
367-
mt.SetFailPoint(failpoint.FailPoint{
368-
ConfigureFailPoint: "failCommand",
369-
Mode: failpoint.Mode{
370-
Times: 1,
371-
},
372-
Data: failpoint.Data{
373-
FailCommands: []string{"getMore"},
374-
BlockConnection: true,
375-
BlockTimeMS: 30,
376-
},
377-
})
366+
initCollection(mt, &coll)
367+
opts := options.Aggregate()
368+
pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}},
369+
{{"$match", bson.D{
370+
{"operationType", "insert"},
371+
{"fullDocment.__nomatch", 1},
372+
}}},
373+
}
378374

379-
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
380-
// cursor type.
381-
opts := options.Find().
382-
SetBatchSize(1).
383-
SetMaxAwaitTime(100 * time.Millisecond).
384-
SetCursorType(options.TailableAwait)
375+
cursor, err := coll.Aggregate(ctx, pipeline, opts)
376+
require.NoError(mt, err, "Aggregate error: %v", err)
385377

386-
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
387-
require.NoError(mt, err)
378+
return cursor
379+
}
388380

389-
defer cursor.Close(context.Background())
381+
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
382+
mt.Helper()
390383

391-
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
392-
// getMore loop should run at least two times: the first getMore will block
393-
// for 30ms on the getMore and then an additional 100ms for the
394-
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
395-
// left on the timeout.
396-
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
397-
defer cancel()
384+
initCollection(mt, &coll)
398385

399-
// Iterate twice to force a getMore
400-
cursor.Next(ctx)
386+
cur, err := coll.Database().RunCommandCursor(ctx, bson.D{
387+
{"find", coll.Name()},
388+
{"filter", bson.D{{"__nomatch", 1}}},
389+
{"tailable", true},
390+
{"awaitData", true},
391+
{"batchSize", int32(1)},
392+
})
393+
require.NoError(mt, err, "RunCommandCursor error: %v", err)
401394

402-
mt.ClearEvents()
403-
cursor.Next(ctx)
395+
return cur
396+
}
404397

405-
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
406-
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
398+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
399+
// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the
400+
// server more opportunities to respond with an empty batch before a
401+
// client-side timeout.
402+
func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
403+
// These values reflect what is used in the unified spec tests, see
404+
// DRIVERS-2868.
405+
const timeoutMS = 200
406+
const maxAwaitTimeMS = 100
407+
const blockTimeMS = 30
408+
const getMoreBound = 71
409+
410+
// TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS.
411+
baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet}
412+
413+
type testCase struct {
414+
name string
415+
factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor
416+
opTimeout bool
417+
topologies []mtest.TopologyKind
418+
}
407419

408-
// Collect all started events to find the getMore commands.
409-
startedEvents := mt.GetAllStartedEvents()
420+
cases := []testCase{
421+
// TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec
422+
// tests for tailable/awaitData cursors and so these tests can be removed
423+
// once the driver supports timeoutMode.
424+
{
425+
name: "find client-level timeout",
426+
factory: tadcFindFactory,
427+
topologies: baseTopologies,
428+
opTimeout: false,
429+
},
430+
{
431+
name: "find operation-level timeout",
432+
factory: tadcFindFactory,
433+
topologies: baseTopologies,
434+
opTimeout: true,
435+
},
410436

411-
var getMoreStartedEvents []*event.CommandStartedEvent
412-
for _, evt := range startedEvents {
413-
if evt.CommandName == "getMore" {
414-
getMoreStartedEvents = append(getMoreStartedEvents, evt)
415-
}
416-
}
437+
// There is no analogue to tailable/awaiData cursor unified spec tests for
438+
// aggregate and runnCommand.
439+
{
440+
name: "aggregate with changeStream client-level timeout",
441+
factory: tadcAggregateFactory,
442+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
443+
opTimeout: false,
444+
},
445+
{
446+
name: "aggregate with changeStream operation-level timeout",
447+
factory: tadcAggregateFactory,
448+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
449+
opTimeout: true,
450+
},
451+
{
452+
name: "runCommandCursor client-level timeout",
453+
factory: tadcRunCommandCursorFactory,
454+
topologies: baseTopologies,
455+
opTimeout: false,
456+
},
457+
{
458+
name: "runCommandCursor operation-level timeout",
459+
factory: tadcRunCommandCursorFactory,
460+
topologies: baseTopologies,
461+
opTimeout: true,
462+
},
463+
}
417464

418-
// The first getMore should have a maxTimeMS of <= 100ms.
419-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
465+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2"))
420466

421-
// The second getMore should have a maxTimeMS of <=71, indicating that we
422-
// are using the time remaining in the context rather than the
423-
// maxAwaitTimeMS.
424-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
425-
})
467+
for _, tc := range cases {
468+
// Reset the collection between test cases to avoid leaking timeouts
469+
// between tests.
470+
cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64)
471+
caseOpts := mtest.NewOptions().
472+
CollectionCreateOptions(cappedOpts).
473+
Topologies(tc.topologies...).
474+
CreateClient(true)
426475

427-
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
476+
if !tc.opTimeout {
477+
caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond))
478+
}
428479

429-
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
430-
initCollection(mt, mt.Coll)
431-
mt.ClearEvents()
480+
mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) {
481+
mt.SetFailPoint(failpoint.FailPoint{
482+
ConfigureFailPoint: "failCommand",
483+
Mode: failpoint.Mode{Times: 1},
484+
Data: failpoint.Data{
485+
FailCommands: []string{"getMore"},
486+
BlockConnection: true,
487+
BlockTimeMS: int32(blockTimeMS),
488+
},
489+
})
490+
491+
ctx := context.Background()
432492

433-
// Create a find cursor
434-
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
493+
var cancel context.CancelFunc
494+
if tc.opTimeout {
495+
ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond)
496+
defer cancel()
497+
}
435498

436-
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
437-
require.NoError(mt, err)
499+
cur := tc.factory(ctx, mt, *mt.Coll)
500+
defer func() { assert.NoError(mt, cur.Close(context.Background())) }()
438501

439-
_ = mt.GetStartedEvent() // Empty find from started list.
502+
require.NoError(mt, cur.Err())
440503

441-
defer cursor.Close(context.Background())
504+
cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond)
442505

443-
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
444-
defer cancel()
506+
mt.ClearEvents()
445507

446-
// Iterate twice to force a getMore
447-
cursor.Next(ctx)
448-
cursor.Next(ctx)
508+
assert.False(mt, cur.Next(ctx))
449509

450-
cmd := mt.GetStartedEvent().Command
510+
require.Error(mt, cur.Err(), "expected error from cursor.Next")
511+
assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
451512

452-
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
453-
require.NoError(mt, err)
513+
getMoreEvts := []*event.CommandStartedEvent{}
514+
for _, evt := range mt.GetAllStartedEvents() {
515+
if evt.CommandName == "getMore" {
516+
getMoreEvts = append(getMoreEvts, evt)
517+
}
518+
}
454519

455-
got, ok := maxTimeMSRaw.AsInt64OK()
456-
require.True(mt, ok)
520+
// It's possible that three getMore events are called: 100ms, 70ms, and
521+
// then some small leftover of remaining time (e.g. 20µs).
522+
require.GreaterOrEqual(mt, len(getMoreEvts), 2)
457523

458-
assert.LessOrEqual(mt, got, int64(50))
459-
})
524+
// The first getMore should have a maxTimeMS of <= 100ms but greater
525+
// than 71ms, indicating that the maxAwaitTimeMS was used.
526+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS))
527+
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound))
528+
529+
// The second getMore should have a maxTimeMS of <=71, indicating that we
530+
// are using the time remaining in the context rather than the
531+
// maxAwaitTimeMS.
532+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound))
533+
})
534+
}
460535
}
461536

537+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
462538
func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
463539
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
464540

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
// Copyright (C) MongoDB, Inc. 2025-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
//go:build mongointernal
8+
9+
package integration
10+
11+
import (
12+
"context"
13+
"testing"
14+
15+
"go.mongodb.org/mongo-driver/v2/bson"
16+
"go.mongodb.org/mongo-driver/v2/internal/assert"
17+
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
18+
"go.mongodb.org/mongo-driver/v2/internal/require"
19+
"go.mongodb.org/mongo-driver/v2/mongo"
20+
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
21+
)
22+
23+
func TestNewSessionWithLSID(t *testing.T) {
24+
mt := mtest.New(t)
25+
26+
mt.Run("can be used to pass a specific session ID to CRUD commands", func(mt *mtest.T) {
27+
mt.Parallel()
28+
29+
// Create a session ID document, which is a BSON document with field
30+
// "id" containing a 16-byte UUID (binary subtype 4).
31+
sessionID := bson.Raw(bsoncore.NewDocumentBuilder().
32+
AppendBinary("id", 4, []byte{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}).
33+
Build())
34+
35+
sess := mongo.NewSessionWithLSID(mt.Client, sessionID)
36+
37+
ctx := mongo.NewSessionContext(context.Background(), sess)
38+
_, err := mt.Coll.InsertOne(ctx, bson.D{{"foo", "bar"}})
39+
require.NoError(mt, err)
40+
41+
evt := mt.GetStartedEvent()
42+
val, err := evt.Command.LookupErr("lsid")
43+
require.NoError(mt, err, "lsid should be present in the command document")
44+
45+
doc, ok := val.DocumentOK()
46+
require.True(mt, ok, "lsid should be a document")
47+
48+
assert.Equal(mt, sessionID, doc)
49+
})
50+
51+
mt.Run("EndSession panics", func(mt *mtest.T) {
52+
mt.Parallel()
53+
54+
sessionID := bson.Raw(bsoncore.NewDocumentBuilder().
55+
AppendBinary("id", 4, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}).
56+
Build())
57+
sess := mongo.NewSessionWithLSID(mt.Client, sessionID)
58+
59+
// Use a defer-recover block to catch the expected panic and assert that
60+
// the recovered error is not nil.
61+
defer func() {
62+
err := recover()
63+
assert.NotNil(mt, err, "expected EndSession to panic")
64+
}()
65+
66+
// Expect this call to panic.
67+
sess.EndSession(context.Background())
68+
69+
// We expect that calling EndSession on a Session returned by
70+
// NewSessionWithLSID panics. This code will only be reached if EndSession
71+
// doesn't panic.
72+
t.Errorf("expected EndSession to panic")
73+
})
74+
75+
mt.Run("ClientSession.SetServer panics", func(mt *mtest.T) {
76+
mt.Parallel()
77+
78+
sessionID := bson.Raw(bsoncore.NewDocumentBuilder().
79+
AppendBinary("id", 4, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}).
80+
Build())
81+
sess := mongo.NewSessionWithLSID(mt.Client, sessionID)
82+
83+
// Use a defer-recover block to catch the expected panic and assert that
84+
// the recovered error is not nil.
85+
defer func() {
86+
err := recover()
87+
assert.NotNil(mt, err, "expected ClientSession.SetServer to panic")
88+
}()
89+
90+
// Expect this call to panic.
91+
sess.ClientSession().SetServer()
92+
93+
// We expect that calling ClientSession.SetServer on a Session returned
94+
// by NewSessionWithLSID panics. This code will only be reached if
95+
// ClientSession.SetServer doesn't panic.
96+
t.Errorf("expected ClientSession.SetServer to panic")
97+
})
98+
}

0 commit comments

Comments
 (0)