From b9da7c2cad427833c1db9f7cb53ebb553283bf9c Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 22 Aug 2025 16:13:19 -0600 Subject: [PATCH 1/3] Add cursorOptions to constructors to pipe client timeout --- mongo/client_bulk_write.go | 6 +++- mongo/collection.go | 12 +++++-- mongo/cursor.go | 73 ++++++++++++++++++++++++++++++-------- mongo/database.go | 6 ++-- mongo/index_view.go | 7 +++- 5 files changed, 83 insertions(+), 21 deletions(-) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 310fdbc301..27c3ad3ce4 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -476,7 +476,11 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum return err } var cursor *Cursor - cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry) + cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry, + + // This op doesn't return a cursor to the user, so setting the client + // timeout should be a no-op. + withCursorOptionClientTimeout(mb.client.timeout)) if err != nil { return err } diff --git a/mongo/collection.go b/mongo/collection.go index 9da9240c5e..ef4188d67b 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1092,7 +1092,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption if err != nil { return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess) + cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess, + + // The only way the server will return a tailable/awaitData cursor for an + // aggregate operation is for the first stage in the pipeline to + // be $changeStream, this is the only time maxAwaitTimeMS should be applied. + // For this reason, we pass the client timeout to the cursor. + withCursorOptionClientTimeout(a.client.timeout)) return cursor, wrapErrors(err) } @@ -1567,7 +1573,9 @@ func (coll *Collection) find( if err != nil { return nil, wrapErrors(err) } - return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess) + + return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess, + withCursorOptionClientTimeout(coll.client.timeout)) } func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions { diff --git a/mongo/cursor.go b/mongo/cursor.go index 743622e88b..bb77f4a21e 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -31,22 +31,41 @@ type Cursor struct { // to Next or TryNext. If continued access is required, a copy must be made. Current bson.Raw - bc batchCursor - batch *bsoncore.Iterator - batchLength int - bsonOpts *options.BSONOptions - registry *bson.Registry - clientSession *session.Client + bc batchCursor + batch *bsoncore.Iterator + batchLength int + bsonOpts *options.BSONOptions + registry *bson.Registry + clientSession *session.Client + clientTimeout time.Duration + hasClientTimeout bool err error } +type cursorOptions struct { + clientTimeout time.Duration + hasClientTimeout bool +} + +type cursorOption func(*cursorOptions) + +func withCursorOptionClientTimeout(dur *time.Duration) cursorOption { + return func(opts *cursorOptions) { + if dur != nil && *dur > 0 { + opts.clientTimeout = *dur + opts.hasClientTimeout = true + } + } +} + func newCursor( bc batchCursor, bsonOpts *options.BSONOptions, registry *bson.Registry, + opts ...cursorOption, ) (*Cursor, error) { - return newCursorWithSession(bc, bsonOpts, registry, nil) + return newCursorWithSession(bc, bsonOpts, registry, nil, opts...) } func newCursorWithSession( @@ -54,6 +73,7 @@ func newCursorWithSession( bsonOpts *options.BSONOptions, registry *bson.Registry, clientSession *session.Client, + opts ...cursorOption, ) (*Cursor, error) { if registry == nil { registry = defaultRegistry @@ -61,11 +81,19 @@ func newCursorWithSession( if bc == nil { return nil, errors.New("batch cursor must not be nil") } + + cursorOpts := &cursorOptions{} + for _, opt := range opts { + opt(cursorOpts) + } + c := &Cursor{ - bc: bc, - bsonOpts: bsonOpts, - registry: registry, - clientSession: clientSession, + bc: bc, + bsonOpts: bsonOpts, + registry: registry, + clientSession: clientSession, + clientTimeout: cursorOpts.clientTimeout, + hasClientTimeout: cursorOpts.hasClientTimeout, } if bc.ID() == 0 { c.closeImplicitSession() @@ -140,11 +168,17 @@ func NewCursorFromDocuments(documents []any, preloadedErr error, registry *bson. // ID returns the ID of this cursor, or 0 if the cursor has been closed or exhausted. func (c *Cursor) ID() int64 { return c.bc.ID() } -// Next gets the next document for this cursor. It returns true if there were no errors and the cursor has not been -// exhausted. +// Next gets the next document for this cursor. It returns true if there were no +// errors and the cursor has not been exhausted. +// +// Next blocks until a document is available or an error occurs. If the context +// expires, the cursor's error will be set to ctx.Err(). In case of an error, +// Next will return false. // -// Next blocks until a document is available or an error occurs. If the context expires, the cursor's error will -// be set to ctx.Err(). In case of an error, Next will return false. +// If MaxAwaitTime is set, the operation will be bound by the Context's +// deadline. If the context does not have a deadline, the operation will be +// bound by the client-level timeout, if one is set. If MaxAwaitTime is greater +// than the user-provided timeout, Next will return false. // // If Next returns false, subsequent calls will also return false. func (c *Cursor) Next(ctx context.Context) bool { @@ -177,6 +211,15 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { ctx = context.Background() } + // If the context does not have a deadline we defer to a client-level timeout, + // if one is set. + if _, ok := ctx.Deadline(); !ok && c.hasClientTimeout { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.clientTimeout) + + defer cancel() + } + // To avoid unnecessary socket timeouts, we attempt to short-circuit tailable // awaitData "getMore" operations by ensuring that the maxAwaitTimeMS is less // than the operation timeout. diff --git a/mongo/database.go b/mongo/database.go index 0d2690c4bd..e42097bb8a 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -306,7 +306,8 @@ func (db *Database) RunCommandCursor( closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess, + withCursorOptionClientTimeout(db.client.timeout)) return cursor, wrapErrors(err) } @@ -511,7 +512,8 @@ func (db *Database) ListCollections( closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess, + withCursorOptionClientTimeout(db.client.timeout)) return cursor, wrapErrors(err) } diff --git a/mongo/index_view.go b/mongo/index_view.go index 4e692fbc5c..f8529c169e 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -131,7 +131,12 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess) + cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess, + + // This value is included for completeness, but a server will never return + // a tailable awaitData cursor from a listIndexes operation. + withCursorOptionClientTimeout(iv.coll.client.timeout)) + return cursor, wrapErrors(err) } From 085d3bc681f79572ca798038b2f76bc5d1bf0275 Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 22 Aug 2025 16:13:53 -0600 Subject: [PATCH 2/3] Update tailable/awaitData cursor tests to include client-level timeout --- internal/integration/cursor_test.go | 236 ++++++++++++++++++---------- 1 file changed, 155 insertions(+), 81 deletions(-) diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 8bba46df53..647cd5b126 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -319,116 +319,190 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 { return got } -func TestCursor_tailableAwaitData(t *testing.T) { - mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) +func tadcFindFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { + mt.Helper() - cappedOpts := options.CreateCollection().SetCapped(true). - SetSizeInBytes(1024 * 64) + initCollection(mt, mt.Coll) + cur, err := mt.Coll.Find(ctx, bson.D{{"x", 1}}, + options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait)) + require.NoError(mt, err, "Find error: %v", err) - // TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS. - mtOpts := mtest.NewOptions().MinServerVersion("4.4"). - Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single). - CollectionCreateOptions(cappedOpts) + return cur, func() error { return cur.Close(context.Background()) } +} - mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) { - initCollection(mt, mt.Coll) +func tadcAggregateFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { + mt.Helper() - // Create a 30ms failpoint for getMore. - mt.SetFailPoint(failpoint.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: failpoint.Mode{ - Times: 1, - }, - Data: failpoint.Data{ - FailCommands: []string{"getMore"}, - BlockConnection: true, - BlockTimeMS: 30, - }, - }) + initCollection(mt, mt.Coll) - // Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData - // cursor type. - opts := options.Find(). - SetBatchSize(1). - SetMaxAwaitTime(100 * time.Millisecond). - SetCursorType(options.TailableAwait) + opts := options.Aggregate().SetMaxAwaitTime(100 * time.Millisecond) + pipe := mongo.Pipeline{{{"$changeStream", bson.D{}}}} - cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts) - require.NoError(mt, err) + cursor, err := mt.Coll.Aggregate(ctx, pipe, opts) + require.NoError(mt, err, "Aggregate error: %v", err) - defer cursor.Close(context.Background()) + return cursor, func() error { return cursor.Close(context.Background()) } +} - // Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying - // getMore loop should run at least two times: the first getMore will block - // for 30ms on the getMore and then an additional 100ms for the - // maxAwaitTimeMS. The second getMore will then use the remaining ~70ms - // left on the timeout. - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() +func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { + mt.Helper() - // Iterate twice to force a getMore - cursor.Next(ctx) + initCollection(mt, mt.Coll) - mt.ClearEvents() - cursor.Next(ctx) + cur, err := mt.DB.RunCommandCursor(ctx, bson.D{ + {"find", mt.Coll.Name()}, + {"filter", bson.D{{"x", 1}}}, + {"tailable", true}, + {"awaitData", true}, + {"batchSize", int32(1)}, + }) + require.NoError(mt, err, "RunCommandCursor error: %v", err) - require.Error(mt, cursor.Err(), "expected error from cursor.Next") - assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error") + return cur, func() error { return cur.Close(context.Background()) } +} - // Collect all started events to find the getMore commands. - startedEvents := mt.GetAllStartedEvents() +// For tailable awaitData cursors, the maxTimeMS for a getMore should be +// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the +// server more opportunities to respond with an empty batch before a +// client-side timeout. +func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { + const timeout = 2000 * time.Millisecond - var getMoreStartedEvents []*event.CommandStartedEvent - for _, evt := range startedEvents { - if evt.CommandName == "getMore" { - getMoreStartedEvents = append(getMoreStartedEvents, evt) - } - } + // Setup mtest instance. + mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) - // The first getMore should have a maxTimeMS of <= 100ms. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100)) + cappedOpts := options.CreateCollection().SetCapped(true). + SetSizeInBytes(1024 * 64) - // The second getMore should have a maxTimeMS of <=71, indicating that we - // are using the time remaining in the context rather than the - // maxAwaitTimeMS. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71)) - }) + // TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS. + baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet} + + type testCase struct { + name string + factory func(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) + opTimeout bool + topologies []mtest.TopologyKind + + // Operations that insert a document into the collection will require that + // an initial batch be consumed to ensure that the getMore is sent in + // subsequent Next calls. + consumeFirstBatch bool + } + + cases := []testCase{ + { + name: "find client-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: false, + consumeFirstBatch: true, + }, + { + name: "find operation-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: true, + consumeFirstBatch: true, + }, + { + name: "aggregate with $changeStream client-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: false, + consumeFirstBatch: false, + }, + { + name: "aggregate with $changeStream operation-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: true, + consumeFirstBatch: false, + }, + { + name: "runCommandCursor client-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: false, + consumeFirstBatch: true, + }, + { + name: "runCommandCursor operation-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: true, + consumeFirstBatch: true, + }, + } - mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) + mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts) - mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) { - initCollection(mt, mt.Coll) - mt.ClearEvents() + for _, tc := range cases { + caseOpts := mtOpts + caseOpts = caseOpts.Topologies(tc.topologies...) - // Create a find cursor - opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + if !tc.opTimeout { + caseOpts = mtOpts.ClientOptions(options.Client().SetTimeout(timeout)) + } - cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) - require.NoError(mt, err) + mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) { + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{Times: 1}, + Data: failpoint.Data{ + FailCommands: []string{"getMore"}, + BlockConnection: true, + BlockTimeMS: 300, + }, + }) - _ = mt.GetStartedEvent() // Empty find from started list. + ctx := context.Background() - defer cursor.Close(context.Background()) + var cancel context.CancelFunc + if tc.opTimeout { + ctx, cancel = context.WithTimeout(ctx, timeout) + defer cancel() + } - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + cur, cleanup := tc.factory(ctx, mt) + defer func() { assert.NoError(mt, cleanup()) }() - // Iterate twice to force a getMore - cursor.Next(ctx) - cursor.Next(ctx) + require.NoError(mt, cur.Err()) - cmd := mt.GetStartedEvent().Command + cur.SetMaxAwaitTime(1000 * time.Millisecond) - maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") - require.NoError(mt, err) + if tc.consumeFirstBatch { + assert.True(mt, cur.Next(ctx)) // consume first batch item + } - got, ok := maxTimeMSRaw.AsInt64OK() - require.True(mt, ok) + mt.ClearEvents() + assert.False(mt, cur.Next(ctx)) - assert.LessOrEqual(mt, got, int64(50)) - }) + require.Error(mt, cur.Err(), "expected error from cursor.Next") + assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error") + + getMoreEvts := []*event.CommandStartedEvent{} + for _, evt := range mt.GetAllStartedEvents() { + if evt.CommandName == "getMore" { + getMoreEvts = append(getMoreEvts, evt) + } + } + + require.Len(mt, getMoreEvts, 2) + + // The first getMore should have a maxTimeMS of <= 100ms but greater + // than 71ms, indicating that the maxAwaitTimeMS was used. + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(1000)) + assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(710)) + + // The second getMore should have a maxTimeMS of <=71, indicating that we + // are using the time remaining in the context rather than the + // maxAwaitTimeMS. + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(710)) + }) + } } +// For tailable awaitData cursors, the maxTimeMS for a getMore should be func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) { mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) From 5c404fb885293e5990723f7f69ff9be05227ceeb Mon Sep 17 00:00:00 2001 From: Preston Vasquez Date: Fri, 29 Aug 2025 13:43:47 -0600 Subject: [PATCH 3/3] Decouple teardown from tadc factories --- internal/integration/cursor_test.go | 160 ++++++++++++++-------------- 1 file changed, 81 insertions(+), 79 deletions(-) diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 647cd5b126..215119f440 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -319,46 +319,50 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 { return got } -func tadcFindFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { +func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { mt.Helper() - initCollection(mt, mt.Coll) - cur, err := mt.Coll.Find(ctx, bson.D{{"x", 1}}, + initCollection(mt, &coll) + cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}}, options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait)) require.NoError(mt, err, "Find error: %v", err) - return cur, func() error { return cur.Close(context.Background()) } + return cur } -func tadcAggregateFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { +func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { mt.Helper() - initCollection(mt, mt.Coll) - - opts := options.Aggregate().SetMaxAwaitTime(100 * time.Millisecond) - pipe := mongo.Pipeline{{{"$changeStream", bson.D{}}}} + initCollection(mt, &coll) + opts := options.Aggregate() + pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}}, + {{"$match", bson.D{ + {"operationType", "insert"}, + {"fullDocment.__nomatch", 1}, + }}}, + } - cursor, err := mt.Coll.Aggregate(ctx, pipe, opts) + cursor, err := coll.Aggregate(ctx, pipeline, opts) require.NoError(mt, err, "Aggregate error: %v", err) - return cursor, func() error { return cursor.Close(context.Background()) } + return cursor } -func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) { +func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { mt.Helper() - initCollection(mt, mt.Coll) + initCollection(mt, &coll) - cur, err := mt.DB.RunCommandCursor(ctx, bson.D{ - {"find", mt.Coll.Name()}, - {"filter", bson.D{{"x", 1}}}, + cur, err := coll.Database().RunCommandCursor(ctx, bson.D{ + {"find", coll.Name()}, + {"filter", bson.D{{"__nomatch", 1}}}, {"tailable", true}, {"awaitData", true}, {"batchSize", int32(1)}, }) require.NoError(mt, err, "RunCommandCursor error: %v", err) - return cur, func() error { return cur.Close(context.Background()) } + return cur } // For tailable awaitData cursors, the maxTimeMS for a getMore should be @@ -366,82 +370,81 @@ func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T) (*mongo.Curso // server more opportunities to respond with an empty batch before a // client-side timeout. func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { - const timeout = 2000 * time.Millisecond - - // Setup mtest instance. - mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) - - cappedOpts := options.CreateCollection().SetCapped(true). - SetSizeInBytes(1024 * 64) - - // TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS. + // These values reflect what is used in the unified spec tests, see + // DRIVERS-2868. + const timeoutMS = 200 + const maxAwaitTimeMS = 100 + const blockTimeMS = 30 + const getMoreBound = 71 + + // TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS. baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet} type testCase struct { name string - factory func(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) + factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor opTimeout bool topologies []mtest.TopologyKind - - // Operations that insert a document into the collection will require that - // an initial batch be consumed to ensure that the getMore is sent in - // subsequent Next calls. - consumeFirstBatch bool } cases := []testCase{ + // TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec + // tests for tailable/awaitData cursors and so these tests can be removed + // once the driver supports timeoutMode. { - name: "find client-level timeout", - factory: tadcFindFactory, - topologies: baseTopologies, - opTimeout: false, - consumeFirstBatch: true, + name: "find client-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: false, }, { - name: "find operation-level timeout", - factory: tadcFindFactory, - topologies: baseTopologies, - opTimeout: true, - consumeFirstBatch: true, + name: "find operation-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: true, }, + + // There is no analogue to tailable/awaiData cursor unified spec tests for + // aggregate and runnCommand. { - name: "aggregate with $changeStream client-level timeout", - factory: tadcAggregateFactory, - topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, - opTimeout: false, - consumeFirstBatch: false, + name: "aggregate with changeStream client-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: false, }, { - name: "aggregate with $changeStream operation-level timeout", - factory: tadcAggregateFactory, - topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, - opTimeout: true, - consumeFirstBatch: false, + name: "aggregate with changeStream operation-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: true, }, { - name: "runCommandCursor client-level timeout", - factory: tadcRunCommandCursorFactory, - topologies: baseTopologies, - opTimeout: false, - consumeFirstBatch: true, + name: "runCommandCursor client-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: false, }, { - name: "runCommandCursor operation-level timeout", - factory: tadcRunCommandCursorFactory, - topologies: baseTopologies, - opTimeout: true, - consumeFirstBatch: true, + name: "runCommandCursor operation-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: true, }, } - mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts) + mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2")) for _, tc := range cases { - caseOpts := mtOpts - caseOpts = caseOpts.Topologies(tc.topologies...) + // Reset the collection between test cases to avoid leaking timeouts + // between tests. + cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64) + caseOpts := mtest.NewOptions(). + CollectionCreateOptions(cappedOpts). + Topologies(tc.topologies...). + CreateClient(true) if !tc.opTimeout { - caseOpts = mtOpts.ClientOptions(options.Client().SetTimeout(timeout)) + caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond)) } mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) { @@ -451,7 +454,7 @@ func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { Data: failpoint.Data{ FailCommands: []string{"getMore"}, BlockConnection: true, - BlockTimeMS: 300, + BlockTimeMS: int32(blockTimeMS), }, }) @@ -459,22 +462,19 @@ func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { var cancel context.CancelFunc if tc.opTimeout { - ctx, cancel = context.WithTimeout(ctx, timeout) + ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond) defer cancel() } - cur, cleanup := tc.factory(ctx, mt) - defer func() { assert.NoError(mt, cleanup()) }() + cur := tc.factory(ctx, mt, *mt.Coll) + defer func() { assert.NoError(mt, cur.Close(context.Background())) }() require.NoError(mt, cur.Err()) - cur.SetMaxAwaitTime(1000 * time.Millisecond) - - if tc.consumeFirstBatch { - assert.True(mt, cur.Next(ctx)) // consume first batch item - } + cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond) mt.ClearEvents() + assert.False(mt, cur.Next(ctx)) require.Error(mt, cur.Err(), "expected error from cursor.Next") @@ -487,17 +487,19 @@ func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { } } - require.Len(mt, getMoreEvts, 2) + // It's possible that three getMore events are called: 100ms, 70ms, and + // then some small leftover of remaining time (e.g. 20µs). + require.GreaterOrEqual(mt, len(getMoreEvts), 2) // The first getMore should have a maxTimeMS of <= 100ms but greater // than 71ms, indicating that the maxAwaitTimeMS was used. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(1000)) - assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(710)) + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS)) + assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound)) // The second getMore should have a maxTimeMS of <=71, indicating that we // are using the time remaining in the context rather than the // maxAwaitTimeMS. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(710)) + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound)) }) } }