Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 157 additions & 81 deletions internal/integration/cursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,116 +349,192 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
return got
}

func TestCursor_tailableAwaitData(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
mt.Helper()

cappedOpts := options.CreateCollection().SetCapped(true).
SetSizeInBytes(1024 * 64)
initCollection(mt, &coll)
cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}},
options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait))
require.NoError(mt, err, "Find error: %v", err)

// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
CollectionCreateOptions(cappedOpts)
return cur
}

mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)
func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
mt.Helper()

// Create a 30ms failpoint for getMore.
mt.SetFailPoint(failpoint.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: failpoint.Mode{
Times: 1,
},
Data: failpoint.Data{
FailCommands: []string{"getMore"},
BlockConnection: true,
BlockTimeMS: 30,
},
})
initCollection(mt, &coll)
opts := options.Aggregate()
pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}},
{{"$match", bson.D{
{"operationType", "insert"},
{"fullDocment.__nomatch", 1},
}}},
}

// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
// cursor type.
opts := options.Find().
SetBatchSize(1).
SetMaxAwaitTime(100 * time.Millisecond).
SetCursorType(options.TailableAwait)
cursor, err := coll.Aggregate(ctx, pipeline, opts)
require.NoError(mt, err, "Aggregate error: %v", err)

cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
require.NoError(mt, err)
return cursor
}

defer cursor.Close(context.Background())
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
mt.Helper()

// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
// getMore loop should run at least two times: the first getMore will block
// for 30ms on the getMore and then an additional 100ms for the
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
// left on the timeout.
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
defer cancel()
initCollection(mt, &coll)

// Iterate twice to force a getMore
cursor.Next(ctx)
cur, err := coll.Database().RunCommandCursor(ctx, bson.D{
{"find", coll.Name()},
{"filter", bson.D{{"__nomatch", 1}}},
{"tailable", true},
{"awaitData", true},
{"batchSize", int32(1)},
})
require.NoError(mt, err, "RunCommandCursor error: %v", err)

mt.ClearEvents()
cursor.Next(ctx)
return cur
}

require.Error(mt, cursor.Err(), "expected error from cursor.Next")
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the
// server more opportunities to respond with an empty batch before a
// client-side timeout.
func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
// These values reflect what is used in the unified spec tests, see
// DRIVERS-2868.
const timeoutMS = 200
const maxAwaitTimeMS = 100
const blockTimeMS = 30
const getMoreBound = 71

// TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS.
baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet}

type testCase struct {
name string
factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor
opTimeout bool
topologies []mtest.TopologyKind
}

// Collect all started events to find the getMore commands.
startedEvents := mt.GetAllStartedEvents()
cases := []testCase{
// TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec
// tests for tailable/awaitData cursors and so these tests can be removed
// once the driver supports timeoutMode.
{
name: "find client-level timeout",
factory: tadcFindFactory,
topologies: baseTopologies,
opTimeout: false,
},
{
name: "find operation-level timeout",
factory: tadcFindFactory,
topologies: baseTopologies,
opTimeout: true,
},

var getMoreStartedEvents []*event.CommandStartedEvent
for _, evt := range startedEvents {
if evt.CommandName == "getMore" {
getMoreStartedEvents = append(getMoreStartedEvents, evt)
}
}
// There is no analogue to the tailable/awaitData cursor unified spec tests
// for aggregate and runCommand.
{
name: "aggregate with changeStream client-level timeout",
factory: tadcAggregateFactory,
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
opTimeout: false,
},
{
name: "aggregate with changeStream operation-level timeout",
factory: tadcAggregateFactory,
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
opTimeout: true,
},
{
name: "runCommandCursor client-level timeout",
factory: tadcRunCommandCursorFactory,
topologies: baseTopologies,
opTimeout: false,
},
{
name: "runCommandCursor operation-level timeout",
factory: tadcRunCommandCursorFactory,
topologies: baseTopologies,
opTimeout: true,
},
}

// The first getMore should have a maxTimeMS of <= 100ms.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2"))

// The second getMore should have a maxTimeMS of <=71, indicating that we
// are using the time remaining in the context rather than the
// maxAwaitTimeMS.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
})
for _, tc := range cases {
// Reset the collection between test cases to avoid leaking timeouts
// between tests.
cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64)
caseOpts := mtest.NewOptions().
CollectionCreateOptions(cappedOpts).
Topologies(tc.topologies...).
CreateClient(true)

mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
if !tc.opTimeout {
caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond))
}

mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
initCollection(mt, mt.Coll)
mt.ClearEvents()
mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) {
mt.SetFailPoint(failpoint.FailPoint{
ConfigureFailPoint: "failCommand",
Mode: failpoint.Mode{Times: 1},
Data: failpoint.Data{
FailCommands: []string{"getMore"},
BlockConnection: true,
BlockTimeMS: int32(blockTimeMS),
},
})

ctx := context.Background()

// Create a find cursor
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
var cancel context.CancelFunc
if tc.opTimeout {
ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond)
defer cancel()
}

cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
require.NoError(mt, err)
cur := tc.factory(ctx, mt, *mt.Coll)
defer func() { assert.NoError(mt, cur.Close(context.Background())) }()

_ = mt.GetStartedEvent() // Empty find from started list.
require.NoError(mt, cur.Err())

defer cursor.Close(context.Background())
cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond)

ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
mt.ClearEvents()

// Iterate twice to force a getMore
cursor.Next(ctx)
cursor.Next(ctx)
assert.False(mt, cur.Next(ctx))

cmd := mt.GetStartedEvent().Command
require.Error(mt, cur.Err(), "expected error from cursor.Next")
assert.ErrorIs(mt, cur.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")

maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
require.NoError(mt, err)
getMoreEvts := []*event.CommandStartedEvent{}
for _, evt := range mt.GetAllStartedEvents() {
if evt.CommandName == "getMore" {
getMoreEvts = append(getMoreEvts, evt)
}
}

got, ok := maxTimeMSRaw.AsInt64OK()
require.True(mt, ok)
// It's possible that three getMore commands are sent: 100ms, 70ms, and
// then some small leftover of remaining time (e.g. 20µs).
require.GreaterOrEqual(mt, len(getMoreEvts), 2)

assert.LessOrEqual(mt, got, int64(50))
})
// The first getMore should have a maxTimeMS of <= 100ms but greater
// than 71ms, indicating that the maxAwaitTimeMS was used.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS))
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound))

// The second getMore should have a maxTimeMS of <=71, indicating that we
// are using the time remaining in the context rather than the
// maxAwaitTimeMS.
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound))
})
}
}

// For tailable awaitData cursors, the maxTimeMS for a getMore should be
func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))

Expand Down
6 changes: 5 additions & 1 deletion mongo/client_bulk_write.go
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,11 @@ func (mb *modelBatches) processResponse(ctx context.Context, resp bsoncore.Docum
return err
}
var cursor *Cursor
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry)
cursor, err = newCursor(bCursor, mb.client.bsonOpts, mb.client.registry,

// This op doesn't return a cursor to the user, so setting the client
// timeout should be a no-op.
withCursorOptionClientTimeout(mb.client.timeout))
if err != nil {
return err
}
Expand Down
12 changes: 10 additions & 2 deletions mongo/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,13 @@ func aggregate(a aggregateParams, opts ...options.Lister[options.AggregateOption
if err != nil {
return nil, wrapErrors(err)
}
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess)
cursor, err := newCursorWithSession(bc, a.client.bsonOpts, a.registry, sess,

// The only way the server will return a tailable/awaitData cursor for an
// aggregate operation is for the first stage in the pipeline to be
// $changeStream; this is the only time maxAwaitTimeMS should be applied.
// For this reason, we pass the client timeout to the cursor.
withCursorOptionClientTimeout(a.client.timeout))
return cursor, wrapErrors(err)
}

Expand Down Expand Up @@ -1567,7 +1573,9 @@ func (coll *Collection) find(
if err != nil {
return nil, wrapErrors(err)
}
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess)

return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess,
withCursorOptionClientTimeout(coll.client.timeout))
}

func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions {
Expand Down
Loading
Loading