diff --git a/internal/integration/cursor_test.go b/internal/integration/cursor_test.go index 092bcc48a5..94acd222c5 100644 --- a/internal/integration/cursor_test.go +++ b/internal/integration/cursor_test.go @@ -349,116 +349,192 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 { return got } -func TestCursor_tailableAwaitData(t *testing.T) { - mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) +func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { + mt.Helper() - cappedOpts := options.CreateCollection().SetCapped(true). - SetSizeInBytes(1024 * 64) + initCollection(mt, &coll) + cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}}, + options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait)) + require.NoError(mt, err, "Find error: %v", err) - // TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS. - mtOpts := mtest.NewOptions().MinServerVersion("4.4"). - Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single). - CollectionCreateOptions(cappedOpts) + return cur +} - mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) { - initCollection(mt, mt.Coll) +func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { + mt.Helper() - // Create a 30ms failpoint for getMore. - mt.SetFailPoint(failpoint.FailPoint{ - ConfigureFailPoint: "failCommand", - Mode: failpoint.Mode{ - Times: 1, - }, - Data: failpoint.Data{ - FailCommands: []string{"getMore"}, - BlockConnection: true, - BlockTimeMS: 30, - }, - }) + initCollection(mt, &coll) + opts := options.Aggregate() + pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}}, + {{"$match", bson.D{ + {"operationType", "insert"}, + {"fullDocment.__nomatch", 1}, + }}}, + } - // Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData - // cursor type. - opts := options.Find(). - SetBatchSize(1). 
- SetMaxAwaitTime(100 * time.Millisecond). - SetCursorType(options.TailableAwait) + cursor, err := coll.Aggregate(ctx, pipeline, opts) + require.NoError(mt, err, "Aggregate error: %v", err) - cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts) - require.NoError(mt, err) + return cursor +} - defer cursor.Close(context.Background()) +func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor { + mt.Helper() - // Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying - // getMore loop should run at least two times: the first getMore will block - // for 30ms on the getMore and then an additional 100ms for the - // maxAwaitTimeMS. The second getMore will then use the remaining ~70ms - // left on the timeout. - ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) - defer cancel() + initCollection(mt, &coll) - // Iterate twice to force a getMore - cursor.Next(ctx) + cur, err := coll.Database().RunCommandCursor(ctx, bson.D{ + {"find", coll.Name()}, + {"filter", bson.D{{"__nomatch", 1}}}, + {"tailable", true}, + {"awaitData", true}, + {"batchSize", int32(1)}, + }) + require.NoError(mt, err, "RunCommandCursor error: %v", err) - mt.ClearEvents() - cursor.Next(ctx) + return cur +} - require.Error(mt, cursor.Err(), "expected error from cursor.Next") - assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error") +// For tailable awaitData cursors, the maxTimeMS for a getMore should be +// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the +// server more opportunities to respond with an empty batch before a +// client-side timeout. +func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) { + // These values reflect what is used in the unified spec tests, see + // DRIVERS-2868. 
+ const timeoutMS = 200 + const maxAwaitTimeMS = 100 + const blockTimeMS = 30 + const getMoreBound = 71 + + // TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS. + baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet} + + type testCase struct { + name string + factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor + opTimeout bool + topologies []mtest.TopologyKind + } - // Collect all started events to find the getMore commands. - startedEvents := mt.GetAllStartedEvents() + cases := []testCase{ + // TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec + // tests for tailable/awaitData cursors and so these tests can be removed + // once the driver supports timeoutMode. + { + name: "find client-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: false, + }, + { + name: "find operation-level timeout", + factory: tadcFindFactory, + topologies: baseTopologies, + opTimeout: true, + }, - var getMoreStartedEvents []*event.CommandStartedEvent - for _, evt := range startedEvents { - if evt.CommandName == "getMore" { - getMoreStartedEvents = append(getMoreStartedEvents, evt) - } - } + // There is no analogue to tailable/awaitData cursor unified spec tests for + // aggregate and runCommand. 
+ { + name: "aggregate with changeStream client-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: false, + }, + { + name: "aggregate with changeStream operation-level timeout", + factory: tadcAggregateFactory, + topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced}, + opTimeout: true, + }, + { + name: "runCommandCursor client-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: false, + }, + { + name: "runCommandCursor operation-level timeout", + factory: tadcRunCommandCursorFactory, + topologies: baseTopologies, + opTimeout: true, + }, + } - // The first getMore should have a maxTimeMS of <= 100ms. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100)) + mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2")) - // The second getMore should have a maxTimeMS of <=71, indicating that we - // are using the time remaining in the context rather than the - // maxAwaitTimeMS. - assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71)) - }) + for _, tc := range cases { + // Reset the collection between test cases to avoid leaking timeouts + // between tests. + cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64) + caseOpts := mtest.NewOptions(). + CollectionCreateOptions(cappedOpts). + Topologies(tc.topologies...). 
+ CreateClient(true) - mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single) + if !tc.opTimeout { + caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond)) + } - mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) { - initCollection(mt, mt.Coll) - mt.ClearEvents() + mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) { + mt.SetFailPoint(failpoint.FailPoint{ + ConfigureFailPoint: "failCommand", + Mode: failpoint.Mode{Times: 1}, + Data: failpoint.Data{ + FailCommands: []string{"getMore"}, + BlockConnection: true, + BlockTimeMS: int32(blockTimeMS), + }, + }) + + ctx := context.Background() - // Create a find cursor - opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond) + var cancel context.CancelFunc + if tc.opTimeout { + ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond) + defer cancel() + } - cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts) - require.NoError(mt, err) + cur := tc.factory(ctx, mt, *mt.Coll) + defer func() { assert.NoError(mt, cur.Close(context.Background())) }() - _ = mt.GetStartedEvent() // Empty find from started list. 
+ require.NoError(mt, cur.Err()) - defer cursor.Close(context.Background()) + cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond) - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() + mt.ClearEvents() - // Iterate twice to force a getMore - cursor.Next(ctx) - cursor.Next(ctx) + assert.False(mt, cur.Next(ctx)) - cmd := mt.GetStartedEvent().Command + require.Error(mt, cur.Err(), "expected error from cursor.Next") + assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error") - maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS") - require.NoError(mt, err) + getMoreEvts := []*event.CommandStartedEvent{} + for _, evt := range mt.GetAllStartedEvents() { + if evt.CommandName == "getMore" { + getMoreEvts = append(getMoreEvts, evt) + } + } - got, ok := maxTimeMSRaw.AsInt64OK() - require.True(mt, ok) + // It's possible that three getMore events are called: 100ms, 70ms, and + // then some small leftover of remaining time (e.g. 20µs). + require.GreaterOrEqual(mt, len(getMoreEvts), 2) - assert.LessOrEqual(mt, got, int64(50)) - }) + // The first getMore should have a maxTimeMS of <= 100ms but greater + // than 71ms, indicating that the maxAwaitTimeMS was used. + assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS)) + assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound)) + + // The second getMore should have a maxTimeMS of <=71, indicating that we + // are using the time remaining in the context rather than the + // maxAwaitTimeMS. 
+ assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound)) + }) + } } +// For tailable awaitData cursors, a getMore should be short-circuited when maxAwaitTimeMS is not less than the operation timeout. func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) { mt := mtest.New(t, mtest.NewOptions().CreateClient(false)) diff --git a/mongo/client_bulk_write.go b/mongo/client_bulk_write.go index 310fdbc301..27c3ad3ce4 100644 --- a/mongo/client_bulk_write.go +++ b/mongo/client_bulk_write.go @@ -476,7 +476,11 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum return err } var cursor *Cursor - cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry) + cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry, + + // This op doesn't return a cursor to the user, so setting the client + // timeout should be a no-op. + withCursorOptionClientTimeout(mb.client.timeout)) if err != nil { return err } diff --git a/mongo/collection.go b/mongo/collection.go index 9da9240c5e..ef4188d67b 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -1092,7 +1092,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption if err != nil { return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess) + cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess, + + // The only way the server will return a tailable/awaitData cursor for an + // aggregate operation is for the first stage in the pipeline to + // be $changeStream; this is the only time maxAwaitTimeMS should be applied. + // For this reason, we pass the client timeout to the cursor. 
+ withCursorOptionClientTimeout(a.client.timeout)) return cursor, wrapErrors(err) } @@ -1567,7 +1573,9 @@ func (coll *Collection) find( if err != nil { return nil, wrapErrors(err) } - return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess) + + return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess, + withCursorOptionClientTimeout(coll.client.timeout)) } func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions { diff --git a/mongo/cursor.go b/mongo/cursor.go index 743622e88b..bb77f4a21e 100644 --- a/mongo/cursor.go +++ b/mongo/cursor.go @@ -31,22 +31,41 @@ type Cursor struct { // to Next or TryNext. If continued access is required, a copy must be made. Current bson.Raw - bc batchCursor - batch *bsoncore.Iterator - batchLength int - bsonOpts *options.BSONOptions - registry *bson.Registry - clientSession *session.Client + bc batchCursor + batch *bsoncore.Iterator + batchLength int + bsonOpts *options.BSONOptions + registry *bson.Registry + clientSession *session.Client + clientTimeout time.Duration + hasClientTimeout bool err error } +type cursorOptions struct { + clientTimeout time.Duration + hasClientTimeout bool +} + +type cursorOption func(*cursorOptions) + +func withCursorOptionClientTimeout(dur *time.Duration) cursorOption { + return func(opts *cursorOptions) { + if dur != nil && *dur > 0 { + opts.clientTimeout = *dur + opts.hasClientTimeout = true + } + } +} + func newCursor( bc batchCursor, bsonOpts *options.BSONOptions, registry *bson.Registry, + opts ...cursorOption, ) (*Cursor, error) { - return newCursorWithSession(bc, bsonOpts, registry, nil) + return newCursorWithSession(bc, bsonOpts, registry, nil, opts...) 
} func newCursorWithSession( @@ -54,6 +73,7 @@ func newCursorWithSession( bsonOpts *options.BSONOptions, registry *bson.Registry, clientSession *session.Client, + opts ...cursorOption, ) (*Cursor, error) { if registry == nil { registry = defaultRegistry @@ -61,11 +81,19 @@ func newCursorWithSession( if bc == nil { return nil, errors.New("batch cursor must not be nil") } + + cursorOpts := &cursorOptions{} + for _, opt := range opts { + opt(cursorOpts) + } + c := &Cursor{ - bc: bc, - bsonOpts: bsonOpts, - registry: registry, - clientSession: clientSession, + bc: bc, + bsonOpts: bsonOpts, + registry: registry, + clientSession: clientSession, + clientTimeout: cursorOpts.clientTimeout, + hasClientTimeout: cursorOpts.hasClientTimeout, } if bc.ID() == 0 { c.closeImplicitSession() @@ -140,11 +168,17 @@ func NewCursorFromDocuments(documents []any, preloadedErr error, registry *bson. // ID returns the ID of this cursor, or 0 if the cursor has been closed or exhausted. func (c *Cursor) ID() int64 { return c.bc.ID() } -// Next gets the next document for this cursor. It returns true if there were no errors and the cursor has not been -// exhausted. +// Next gets the next document for this cursor. It returns true if there were no +// errors and the cursor has not been exhausted. +// +// Next blocks until a document is available or an error occurs. If the context +// expires, the cursor's error will be set to ctx.Err(). In case of an error, +// Next will return false. // -// Next blocks until a document is available or an error occurs. If the context expires, the cursor's error will -// be set to ctx.Err(). In case of an error, Next will return false. +// If MaxAwaitTime is set, the operation will be bound by the Context's +// deadline. If the context does not have a deadline, the operation will be +// bound by the client-level timeout, if one is set. If MaxAwaitTime is greater +// than the user-provided timeout, Next will return false. 
// // If Next returns false, subsequent calls will also return false. func (c *Cursor) Next(ctx context.Context) bool { @@ -177,6 +211,15 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool { ctx = context.Background() } + // If the context does not have a deadline we defer to a client-level timeout, + // if one is set. + if _, ok := ctx.Deadline(); !ok && c.hasClientTimeout { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, c.clientTimeout) + + defer cancel() + } + // To avoid unnecessary socket timeouts, we attempt to short-circuit tailable // awaitData "getMore" operations by ensuring that the maxAwaitTimeMS is less // than the operation timeout. diff --git a/mongo/database.go b/mongo/database.go index 0d2690c4bd..e42097bb8a 100644 --- a/mongo/database.go +++ b/mongo/database.go @@ -306,7 +306,8 @@ func (db *Database) RunCommandCursor( closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess, + withCursorOptionClientTimeout(db.client.timeout)) return cursor, wrapErrors(err) } @@ -511,7 +512,8 @@ func (db *Database) ListCollections( closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess) + cursor, err := newCursorWithSession(bc, db.bsonOpts, db.registry, sess, + withCursorOptionClientTimeout(db.client.timeout)) return cursor, wrapErrors(err) } diff --git a/mongo/index_view.go b/mongo/index_view.go index 4e692fbc5c..f8529c169e 100644 --- a/mongo/index_view.go +++ b/mongo/index_view.go @@ -131,7 +131,12 @@ func (iv IndexView) List(ctx context.Context, opts ...options.Lister[options.Lis closeImplicitSession(sess) return nil, wrapErrors(err) } - cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess) + cursor, err := newCursorWithSession(bc, iv.coll.bsonOpts, iv.coll.registry, sess, 
+ + // This value is included for completeness, but a server will never return + // a tailable awaitData cursor from a listIndexes operation. + withCursorOptionClientTimeout(iv.coll.client.timeout)) + return cursor, wrapErrors(err) }