Skip to content

Commit d5ac9a4

Browse files
Merge release/2.3 into master (#2190)
Merge release/2.3 into master
2 parents 40b5f17 + 73f684f commit d5ac9a4

File tree

6 files changed

+240
-102
lines changed

6 files changed

+240
-102
lines changed

internal/integration/cursor_test.go

Lines changed: 157 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -349,116 +349,192 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
349349
return got
350350
}
351351

352-
func TestCursor_tailableAwaitData(t *testing.T) {
353-
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
352+
func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
353+
mt.Helper()
354354

355-
cappedOpts := options.CreateCollection().SetCapped(true).
356-
SetSizeInBytes(1024 * 64)
355+
initCollection(mt, &coll)
356+
cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}},
357+
options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait))
358+
require.NoError(mt, err, "Find error: %v", err)
357359

358-
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
359-
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
360-
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
361-
CollectionCreateOptions(cappedOpts)
360+
return cur
361+
}
362362

363-
mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
364-
initCollection(mt, mt.Coll)
363+
func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
364+
mt.Helper()
365365

366-
// Create a 30ms failpoint for getMore.
367-
mt.SetFailPoint(failpoint.FailPoint{
368-
ConfigureFailPoint: "failCommand",
369-
Mode: failpoint.Mode{
370-
Times: 1,
371-
},
372-
Data: failpoint.Data{
373-
FailCommands: []string{"getMore"},
374-
BlockConnection: true,
375-
BlockTimeMS: 30,
376-
},
377-
})
366+
initCollection(mt, &coll)
367+
opts := options.Aggregate()
368+
pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}},
369+
{{"$match", bson.D{
370+
{"operationType", "insert"},
371+
{"fullDocment.__nomatch", 1},
372+
}}},
373+
}
378374

379-
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
380-
// cursor type.
381-
opts := options.Find().
382-
SetBatchSize(1).
383-
SetMaxAwaitTime(100 * time.Millisecond).
384-
SetCursorType(options.TailableAwait)
375+
cursor, err := coll.Aggregate(ctx, pipeline, opts)
376+
require.NoError(mt, err, "Aggregate error: %v", err)
385377

386-
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
387-
require.NoError(mt, err)
378+
return cursor
379+
}
388380

389-
defer cursor.Close(context.Background())
381+
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
382+
mt.Helper()
390383

391-
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
392-
// getMore loop should run at least two times: the first getMore will block
393-
// for 30ms on the getMore and then an additional 100ms for the
394-
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
395-
// left on the timeout.
396-
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
397-
defer cancel()
384+
initCollection(mt, &coll)
398385

399-
// Iterate twice to force a getMore
400-
cursor.Next(ctx)
386+
cur, err := coll.Database().RunCommandCursor(ctx, bson.D{
387+
{"find", coll.Name()},
388+
{"filter", bson.D{{"__nomatch", 1}}},
389+
{"tailable", true},
390+
{"awaitData", true},
391+
{"batchSize", int32(1)},
392+
})
393+
require.NoError(mt, err, "RunCommandCursor error: %v", err)
401394

402-
mt.ClearEvents()
403-
cursor.Next(ctx)
395+
return cur
396+
}
404397

405-
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
406-
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
398+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
399+
// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the
400+
// server more opportunities to respond with an empty batch before a
401+
// client-side timeout.
402+
func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
403+
// These values reflect what is used in the unified spec tests, see
404+
// DRIVERS-2868.
405+
const timeoutMS = 200
406+
const maxAwaitTimeMS = 100
407+
const blockTimeMS = 30
408+
const getMoreBound = 71
409+
410+
// TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS.
411+
baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet}
412+
413+
type testCase struct {
414+
name string
415+
factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor
416+
opTimeout bool
417+
topologies []mtest.TopologyKind
418+
}
407419

408-
// Collect all started events to find the getMore commands.
409-
startedEvents := mt.GetAllStartedEvents()
420+
cases := []testCase{
421+
// TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec
422+
// tests for tailable/awaitData cursors and so these tests can be removed
423+
// once the driver supports timeoutMode.
424+
{
425+
name: "find client-level timeout",
426+
factory: tadcFindFactory,
427+
topologies: baseTopologies,
428+
opTimeout: false,
429+
},
430+
{
431+
name: "find operation-level timeout",
432+
factory: tadcFindFactory,
433+
topologies: baseTopologies,
434+
opTimeout: true,
435+
},
410436

411-
var getMoreStartedEvents []*event.CommandStartedEvent
412-
for _, evt := range startedEvents {
413-
if evt.CommandName == "getMore" {
414-
getMoreStartedEvents = append(getMoreStartedEvents, evt)
415-
}
416-
}
437+
// There is no analogue to tailable/awaiData cursor unified spec tests for
438+
// aggregate and runnCommand.
439+
{
440+
name: "aggregate with changeStream client-level timeout",
441+
factory: tadcAggregateFactory,
442+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
443+
opTimeout: false,
444+
},
445+
{
446+
name: "aggregate with changeStream operation-level timeout",
447+
factory: tadcAggregateFactory,
448+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
449+
opTimeout: true,
450+
},
451+
{
452+
name: "runCommandCursor client-level timeout",
453+
factory: tadcRunCommandCursorFactory,
454+
topologies: baseTopologies,
455+
opTimeout: false,
456+
},
457+
{
458+
name: "runCommandCursor operation-level timeout",
459+
factory: tadcRunCommandCursorFactory,
460+
topologies: baseTopologies,
461+
opTimeout: true,
462+
},
463+
}
417464

418-
// The first getMore should have a maxTimeMS of <= 100ms.
419-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
465+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2"))
420466

421-
// The second getMore should have a maxTimeMS of <=71, indicating that we
422-
// are using the time remaining in the context rather than the
423-
// maxAwaitTimeMS.
424-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
425-
})
467+
for _, tc := range cases {
468+
// Reset the collection between test cases to avoid leaking timeouts
469+
// between tests.
470+
cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64)
471+
caseOpts := mtest.NewOptions().
472+
CollectionCreateOptions(cappedOpts).
473+
Topologies(tc.topologies...).
474+
CreateClient(true)
426475

427-
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
476+
if !tc.opTimeout {
477+
caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond))
478+
}
428479

429-
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
430-
initCollection(mt, mt.Coll)
431-
mt.ClearEvents()
480+
mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) {
481+
mt.SetFailPoint(failpoint.FailPoint{
482+
ConfigureFailPoint: "failCommand",
483+
Mode: failpoint.Mode{Times: 1},
484+
Data: failpoint.Data{
485+
FailCommands: []string{"getMore"},
486+
BlockConnection: true,
487+
BlockTimeMS: int32(blockTimeMS),
488+
},
489+
})
490+
491+
ctx := context.Background()
432492

433-
// Create a find cursor
434-
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
493+
var cancel context.CancelFunc
494+
if tc.opTimeout {
495+
ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond)
496+
defer cancel()
497+
}
435498

436-
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
437-
require.NoError(mt, err)
499+
cur := tc.factory(ctx, mt, *mt.Coll)
500+
defer func() { assert.NoError(mt, cur.Close(context.Background())) }()
438501

439-
_ = mt.GetStartedEvent() // Empty find from started list.
502+
require.NoError(mt, cur.Err())
440503

441-
defer cursor.Close(context.Background())
504+
cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond)
442505

443-
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
444-
defer cancel()
506+
mt.ClearEvents()
445507

446-
// Iterate twice to force a getMore
447-
cursor.Next(ctx)
448-
cursor.Next(ctx)
508+
assert.False(mt, cur.Next(ctx))
449509

450-
cmd := mt.GetStartedEvent().Command
510+
require.Error(mt, cur.Err(), "expected error from cursor.Next")
511+
assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
451512

452-
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
453-
require.NoError(mt, err)
513+
getMoreEvts := []*event.CommandStartedEvent{}
514+
for _, evt := range mt.GetAllStartedEvents() {
515+
if evt.CommandName == "getMore" {
516+
getMoreEvts = append(getMoreEvts, evt)
517+
}
518+
}
454519

455-
got, ok := maxTimeMSRaw.AsInt64OK()
456-
require.True(mt, ok)
520+
// It's possible that three getMore events are called: 100ms, 70ms, and
521+
// then some small leftover of remaining time (e.g. 20µs).
522+
require.GreaterOrEqual(mt, len(getMoreEvts), 2)
457523

458-
assert.LessOrEqual(mt, got, int64(50))
459-
})
524+
// The first getMore should have a maxTimeMS of <= 100ms but greater
525+
// than 71ms, indicating that the maxAwaitTimeMS was used.
526+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS))
527+
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound))
528+
529+
// The second getMore should have a maxTimeMS of <=71, indicating that we
530+
// are using the time remaining in the context rather than the
531+
// maxAwaitTimeMS.
532+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound))
533+
})
534+
}
460535
}
461536

537+
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
462538
func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
463539
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
464540

mongo/client_bulk_write.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,11 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
476476
return err
477477
}
478478
var cursor *Cursor
479-
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry)
479+
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry,
480+
481+
// This op doesn't return a cursor to the user, so setting the client
482+
// timeout should be a no-op.
483+
withCursorOptionClientTimeout(mb.client.timeout))
480484
if err != nil {
481485
return err
482486
}

mongo/collection.go

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
10921092
if err != nil {
10931093
return nil, wrapErrors(err)
10941094
}
1095-
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess)
1095+
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess,
1096+
1097+
// The only way the server will return a tailable/awaitData cursor for an
1098+
// aggregate operation is for the first stage in the pipeline to
1099+
// be $changeStream, this is the only time maxAwaitTimeMS should be applied.
1100+
// For this reason, we pass the client timeout to the cursor.
1101+
withCursorOptionClientTimeout(a.client.timeout))
10961102
return cursor, wrapErrors(err)
10971103
}
10981104

@@ -1567,7 +1573,9 @@ func (coll *Collection) find(
15671573
if err != nil {
15681574
return nil, wrapErrors(err)
15691575
}
1570-
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess)
1576+
1577+
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess,
1578+
withCursorOptionClientTimeout(coll.client.timeout))
15711579
}
15721580

15731581
func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions {

0 commit comments

Comments
 (0)