Skip to content

Commit 724d9e4

Browse files
GODRIVER-3473 Short-cicruit cursor.next() on invalid timeouts (#2135)
1 parent 0c85ece commit 724d9e4

File tree

11 files changed

+389
-184
lines changed

11 files changed

+389
-184
lines changed

internal/integration/collection_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2028,16 +2028,16 @@ func TestCollection(t *testing.T) {
20282028
})
20292029
}
20302030

2031-
func initCollection(mt *mtest.T, coll *mongo.Collection) {
2032-
mt.Helper()
2031+
func initCollection(tb testing.TB, coll *mongo.Collection) {
2032+
tb.Helper()
20332033

20342034
var docs []interface{}
20352035
for i := 1; i <= 5; i++ {
20362036
docs = append(docs, bson.D{{"x", int32(i)}})
20372037
}
20382038

20392039
_, err := coll.InsertMany(context.Background(), docs)
2040-
assert.Nil(mt, err, "InsertMany error for initial data: %v", err)
2040+
assert.NoError(tb, err, "InsertMany error for initial data: %v", err)
20412041
}
20422042

20432043
func testAggregateWithOptions(mt *mtest.T, createIndex bool, opts options.Lister[options.AggregateOptions]) {

internal/integration/cursor_test.go

Lines changed: 217 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"time"
1515

1616
"go.mongodb.org/mongo-driver/v2/bson"
17+
"go.mongodb.org/mongo-driver/v2/event"
1718
"go.mongodb.org/mongo-driver/v2/internal/assert"
1819
"go.mongodb.org/mongo-driver/v2/internal/failpoint"
1920
"go.mongodb.org/mongo-driver/v2/internal/integration/mtest"
@@ -304,77 +305,248 @@ func TestCursor(t *testing.T) {
304305
batchSize = sizeVal.Int32()
305306
assert.Equal(mt, int32(4), batchSize, "expected batchSize 4, got %v", batchSize)
306307
})
308+
}
307309

308-
tailableAwaitDataCursorOpts := mtest.NewOptions().MinServerVersion("4.4").
309-
Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
310+
func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
311+
mt.Helper()
310312

311-
mt.RunOpts("tailable awaitData cursor", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
312-
mt.Run("apply remaining timeoutMS if less than maxAwaitTimeMS", func(mt *mtest.T) {
313-
initCollection(mt, mt.Coll)
314-
mt.ClearEvents()
313+
maxTimeMSRaw, err := evt.Command.LookupErr("maxTimeMS")
314+
require.NoError(mt, err)
315315

316-
// Create a find cursor
317-
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(100 * time.Millisecond)
316+
got, ok := maxTimeMSRaw.AsInt64OK()
317+
require.True(mt, ok)
318318

319-
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
320-
require.NoError(mt, err)
319+
return got
320+
}
321321

322-
_ = mt.GetStartedEvent() // Empty find from started list.
322+
func TestCursor_tailableAwaitData(t *testing.T) {
323+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
323324

324-
defer cursor.Close(context.Background())
325+
cappedOpts := options.CreateCollection().SetCapped(true).
326+
SetSizeInBytes(1024 * 64)
325327

326-
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
327-
defer cancel()
328+
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
329+
mtOpts := mtest.NewOptions().MinServerVersion("4.4").
330+
Topologies(mtest.ReplicaSet, mtest.LoadBalanced, mtest.Single).
331+
CollectionCreateOptions(cappedOpts)
328332

329-
// Iterate twice to force a getMore
330-
cursor.Next(ctx)
331-
cursor.Next(ctx)
333+
mt.RunOpts("apply remaining timeoutMS if less than maxAwaitTimeMS", mtOpts, func(mt *mtest.T) {
334+
initCollection(mt, mt.Coll)
332335

333-
cmd := mt.GetStartedEvent().Command
336+
// Create a 30ms failpoint for getMore.
337+
mt.SetFailPoint(failpoint.FailPoint{
338+
ConfigureFailPoint: "failCommand",
339+
Mode: failpoint.Mode{
340+
Times: 1,
341+
},
342+
Data: failpoint.Data{
343+
FailCommands: []string{"getMore"},
344+
BlockConnection: true,
345+
BlockTimeMS: 30,
346+
},
347+
})
334348

335-
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
336-
require.NoError(mt, err)
349+
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
350+
// cursor type.
351+
opts := options.Find().
352+
SetBatchSize(1).
353+
SetMaxAwaitTime(100 * time.Millisecond).
354+
SetCursorType(options.TailableAwait)
337355

338-
got, ok := maxTimeMSRaw.AsInt64OK()
339-
require.True(mt, ok)
356+
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 2}}, opts)
357+
require.NoError(mt, err)
340358

341-
assert.LessOrEqual(mt, got, int64(50))
342-
})
359+
defer cursor.Close(context.Background())
343360

344-
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", tailableAwaitDataCursorOpts, func(mt *mtest.T) {
345-
initCollection(mt, mt.Coll)
346-
mt.ClearEvents()
361+
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
362+
// getMore loop should run at least two times: the first getMore will block
363+
// for 30ms on the getMore and then an additional 100ms for the
364+
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
365+
// left on the timeout.
366+
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
367+
defer cancel()
347368

348-
// Create a find cursor
349-
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
369+
// Iterate twice to force a getMore
370+
cursor.Next(ctx)
350371

351-
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
352-
require.NoError(mt, err)
372+
mt.ClearEvents()
373+
cursor.Next(ctx)
353374

354-
_ = mt.GetStartedEvent() // Empty find from started list.
375+
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
376+
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
355377

356-
defer cursor.Close(context.Background())
378+
// Collect all started events to find the getMore commands.
379+
startedEvents := mt.GetAllStartedEvents()
357380

358-
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
359-
defer cancel()
381+
var getMoreStartedEvents []*event.CommandStartedEvent
382+
for _, evt := range startedEvents {
383+
if evt.CommandName == "getMore" {
384+
getMoreStartedEvents = append(getMoreStartedEvents, evt)
385+
}
386+
}
360387

361-
// Iterate twice to force a getMore
362-
cursor.Next(ctx)
363-
cursor.Next(ctx)
388+
// The first getMore should have a maxTimeMS of <= 100ms.
389+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
364390

365-
cmd := mt.GetStartedEvent().Command
391+
// The second getMore should have a maxTimeMS of <=71, indicating that we
392+
// are using the time remaining in the context rather than the
393+
// maxAwaitTimeMS.
394+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
395+
})
366396

367-
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
368-
require.NoError(mt, err)
397+
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
369398

370-
got, ok := maxTimeMSRaw.AsInt64OK()
371-
require.True(mt, ok)
399+
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {
400+
initCollection(mt, mt.Coll)
401+
mt.ClearEvents()
372402

373-
assert.LessOrEqual(mt, got, int64(50))
374-
})
403+
// Create a find cursor
404+
opts := options.Find().SetBatchSize(1).SetMaxAwaitTime(50 * time.Millisecond)
405+
406+
cursor, err := mt.Coll.Find(context.Background(), bson.D{}, opts)
407+
require.NoError(mt, err)
408+
409+
_ = mt.GetStartedEvent() // Empty find from started list.
410+
411+
defer cursor.Close(context.Background())
412+
413+
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
414+
defer cancel()
415+
416+
// Iterate twice to force a getMore
417+
cursor.Next(ctx)
418+
cursor.Next(ctx)
419+
420+
cmd := mt.GetStartedEvent().Command
421+
422+
maxTimeMSRaw, err := cmd.LookupErr("maxTimeMS")
423+
require.NoError(mt, err)
424+
425+
got, ok := maxTimeMSRaw.AsInt64OK()
426+
require.True(mt, ok)
427+
428+
assert.LessOrEqual(mt, got, int64(50))
375429
})
376430
}
377431

432+
func TestCursor_tailableAwaitData_ShortCircuitingGetMore(t *testing.T) {
433+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
434+
435+
cappedOpts := options.CreateCollection().SetCapped(true).
436+
SetSizeInBytes(1024 * 64)
437+
438+
mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts)
439+
tests := []struct {
440+
name string
441+
deadline time.Duration
442+
maxAwaitTime time.Duration
443+
wantShortCircuit bool
444+
}{
445+
{
446+
name: "maxAwaitTime less than operation timeout",
447+
deadline: 200 * time.Millisecond,
448+
maxAwaitTime: 100 * time.Millisecond,
449+
wantShortCircuit: false,
450+
},
451+
{
452+
name: "maxAwaitTime equal to operation timeout",
453+
deadline: 200 * time.Millisecond,
454+
maxAwaitTime: 200 * time.Millisecond,
455+
wantShortCircuit: true,
456+
},
457+
{
458+
name: "maxAwaitTime greater than operation timeout",
459+
deadline: 200 * time.Millisecond,
460+
maxAwaitTime: 300 * time.Millisecond,
461+
wantShortCircuit: true,
462+
},
463+
}
464+
465+
for _, tt := range tests {
466+
mt.Run(tt.name, func(mt *mtest.T) {
467+
mt.RunOpts("find", mtOpts, func(mt *mtest.T) {
468+
initCollection(mt, mt.Coll)
469+
470+
// Create a find cursor
471+
opts := options.Find().
472+
SetBatchSize(1).
473+
SetMaxAwaitTime(tt.maxAwaitTime).
474+
SetCursorType(options.TailableAwait)
475+
476+
ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
477+
defer cancel()
478+
479+
cur, err := mt.Coll.Find(ctx, bson.D{{Key: "x", Value: 3}}, opts)
480+
require.NoError(mt, err, "Find error: %v", err)
481+
482+
// Close to return the session to the pool.
483+
defer cur.Close(context.Background())
484+
485+
ok := cur.Next(ctx)
486+
if tt.wantShortCircuit {
487+
assert.False(mt, ok, "expected Next to return false, got true")
488+
assert.EqualError(t, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
489+
} else {
490+
assert.True(mt, ok, "expected Next to return true, got false")
491+
assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err())
492+
}
493+
})
494+
495+
mt.RunOpts("aggregate", mtOpts, func(mt *mtest.T) {
496+
initCollection(mt, mt.Coll)
497+
498+
// Create a find cursor
499+
opts := options.Aggregate().
500+
SetBatchSize(1).
501+
SetMaxAwaitTime(tt.maxAwaitTime)
502+
503+
ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
504+
defer cancel()
505+
506+
cur, err := mt.Coll.Aggregate(ctx, []bson.D{}, opts)
507+
require.NoError(mt, err, "Aggregate error: %v", err)
508+
509+
// Close to return the session to the pool.
510+
defer cur.Close(context.Background())
511+
512+
ok := cur.Next(ctx)
513+
if tt.wantShortCircuit {
514+
assert.False(mt, ok, "expected Next to return false, got true")
515+
assert.EqualError(t, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
516+
} else {
517+
assert.True(mt, ok, "expected Next to return true, got false")
518+
assert.NoError(mt, cur.Err(), "expected no error, got %v", cur.Err())
519+
}
520+
})
521+
522+
// The $changeStream stage is only supported on replica sets.
523+
watchOpts := mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded)
524+
mt.RunOpts("watch", watchOpts, func(mt *mtest.T) {
525+
initCollection(mt, mt.Coll)
526+
527+
// Create a find cursor
528+
opts := options.ChangeStream().SetMaxAwaitTime(tt.maxAwaitTime)
529+
530+
ctx, cancel := context.WithTimeout(context.Background(), tt.deadline)
531+
defer cancel()
532+
533+
cur, err := mt.Coll.Watch(ctx, []bson.D{}, opts)
534+
require.NoError(mt, err, "Watch error: %v", err)
535+
536+
// Close to return the session to the pool.
537+
defer cur.Close(context.Background())
538+
539+
if tt.wantShortCircuit {
540+
ok := cur.Next(ctx)
541+
542+
assert.False(mt, ok, "expected Next to return false, got true")
543+
assert.EqualError(mt, cur.Err(), "MaxAwaitTime must be less than the operation timeout")
544+
}
545+
})
546+
})
547+
}
548+
}
549+
378550
type tryNextCursor interface {
379551
TryNext(context.Context) bool
380552
Err() error

internal/mongoutil/mongoutil.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
package mongoutil
88

99
import (
10+
"context"
1011
"reflect"
12+
"time"
1113

1214
"go.mongodb.org/mongo-driver/v2/mongo/options"
1315
)
@@ -83,3 +85,17 @@ func HostsFromURI(uri string) ([]string, error) {
8385

8486
return opts.Hosts, nil
8587
}
88+
89+
// TimeoutWithinContext will return true if the provided timeout is nil or if
90+
// it is less than the context deadline. If the context does not have a
91+
// deadline, it will return true.
92+
func TimeoutWithinContext(ctx context.Context, timeout time.Duration) bool {
93+
deadline, ok := ctx.Deadline()
94+
if !ok {
95+
return true
96+
}
97+
98+
ctxTimeout := time.Until(deadline)
99+
100+
return ctxTimeout <= 0 || timeout < ctxTimeout
101+
}

0 commit comments

Comments
 (0)