Skip to content

Commit e33eaa9

Browse files
cont.
1 parent 42852c5 commit e33eaa9

File tree

3 files changed

+103
-1
lines changed

3 files changed

+103
-1
lines changed

internal/integration/cursor_test.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,74 @@ func TestCursor_tailableAwaitData(t *testing.T) {
394394
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
395395
})
396396

397+
// Use a 200ms timeout that caps the lifetime of cursor.Next. The underlying
398+
// getMore loop should run at least two times: the first getMore will block
399+
// for 30ms on the getMore and then an additional 100ms for the
400+
// maxAwaitTimeMS. The second getMore will then use the remaining ~70ms
401+
// left on the timeout.
402+
clientMtOpts := mtOpts.ClientOptions(options.Client().SetTimeout(200 * time.Millisecond))
403+
404+
mt.RunOpts("apply remaining client-level timeoutMS if less than maxAwaitTimeMS", clientMtOpts, func(mt *mtest.T) {
405+
initCollection(mt, mt.Coll)
406+
407+
// Create a 30ms failpoint for getMore.
408+
mt.SetFailPoint(failpoint.FailPoint{
409+
ConfigureFailPoint: "failCommand",
410+
Mode: failpoint.Mode{
411+
Times: 1,
412+
},
413+
Data: failpoint.Data{
414+
FailCommands: []string{"getMore"},
415+
BlockConnection: true,
416+
BlockTimeMS: 30,
417+
},
418+
})
419+
420+
// Create a find cursor with a 100ms maxAwaitTimeMS and a tailable awaitData
421+
// cursor type.
422+
opts := options.Find().
423+
SetMaxAwaitTime(100 * time.Millisecond).
424+
SetCursorType(options.TailableAwait)
425+
426+
cursor, err := mt.Coll.Find(context.Background(), bson.D{{"x", 1}}, opts)
427+
require.NoError(mt, err)
428+
429+
defer cursor.Close(context.Background())
430+
431+
// Iterate twice to force a getMore
432+
cursor.Next(context.Background())
433+
434+
// We expect 2 calls to getMore. Since batchSize=1 the first call will
435+
mt.ClearEvents()
436+
cursor.Next(context.Background())
437+
438+
m := make(map[string]any)
439+
440+
err = cursor.Decode(&m)
441+
require.NoError(t, err, "expected to decode a document, got error: %v", err)
442+
443+
require.Error(mt, cursor.Err(), "expected error from cursor.Next")
444+
assert.ErrorIs(mt, cursor.Err(), context.DeadlineExceeded, "expected context deadline exceeded error")
445+
446+
// Collect all started events to find the getMore commands.
447+
startedEvents := mt.GetAllStartedEvents()
448+
449+
var getMoreStartedEvents []*event.CommandStartedEvent
450+
for _, evt := range startedEvents {
451+
if evt.CommandName == "getMore" {
452+
getMoreStartedEvents = append(getMoreStartedEvents, evt)
453+
}
454+
}
455+
456+
// The first getMore should have a maxTimeMS of <= 100ms.
457+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[0]), int64(100))
458+
459+
// The second getMore should have a maxTimeMS of <=71, indicating that we
460+
// are using the time remaining in the context rather than the
461+
// maxAwaitTimeMS.
462+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreStartedEvents[1]), int64(71))
463+
})
464+
397465
mtOpts.Topologies(mtest.ReplicaSet, mtest.Sharded, mtest.LoadBalanced, mtest.Single)
398466

399467
mt.RunOpts("apply maxAwaitTimeMS if less than remaining timeout", mtOpts, func(mt *mtest.T) {

mongo/collection.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1567,7 +1567,9 @@ func (coll *Collection) find(
15671567
if err != nil {
15681568
return nil, wrapErrors(err)
15691569
}
1570-
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess)
1570+
1571+
return newCursorWithSession(bc, coll.bsonOpts, coll.registry, sess,
1572+
withCursorOptionClientTimeout(coll.client.timeout))
15711573
}
15721574

15731575
func newFindArgsFromFindOneArgs(args *options.FindOneOptions) *options.FindOptions {

mongo/cursor.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,25 @@ type Cursor struct {
3737
bsonOpts *options.BSONOptions
3838
registry *bson.Registry
3939
clientSession *session.Client
40+
clientTimeout *time.Duration
4041

4142
err error
4243
}
4344

45+
type cursorOptions struct {
46+
clientTimeout *time.Duration
47+
}
48+
49+
type cursorOption func(*cursorOptions)
50+
51+
func withCursorOptionClientTimeout(dur *time.Duration) cursorOption {
52+
return func(opts *cursorOptions) {
53+
if dur != nil {
54+
opts.clientTimeout = dur
55+
}
56+
}
57+
}
58+
4459
func newCursor(
4560
bc batchCursor,
4661
bsonOpts *options.BSONOptions,
@@ -54,18 +69,26 @@ func newCursorWithSession(
5469
bsonOpts *options.BSONOptions,
5570
registry *bson.Registry,
5671
clientSession *session.Client,
72+
opts ...cursorOption,
5773
) (*Cursor, error) {
5874
if registry == nil {
5975
registry = defaultRegistry
6076
}
6177
if bc == nil {
6278
return nil, errors.New("batch cursor must not be nil")
6379
}
80+
81+
cursorOpts := &cursorOptions{}
82+
for _, opt := range opts {
83+
opt(cursorOpts)
84+
}
85+
6486
c := &Cursor{
6587
bc: bc,
6688
bsonOpts: bsonOpts,
6789
registry: registry,
6890
clientSession: clientSession,
91+
clientTimeout: cursorOpts.clientTimeout,
6992
}
7093
if bc.ID() == 0 {
7194
c.closeImplicitSession()
@@ -177,6 +200,15 @@ func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
177200
ctx = context.Background()
178201
}
179202

203+
// If the context does not have a deadline we defer to a client-level timeout,
204+
// if one is set.
205+
if _, ok := ctx.Deadline(); !ok && c.clientTimeout != nil {
206+
var cancel context.CancelFunc
207+
ctx, cancel = context.WithTimeout(context.Background(), *c.clientTimeout)
208+
209+
defer cancel()
210+
}
211+
180212
// To avoid unnecessary socket timeouts, we attempt to short-circuit tailable
181213
// awaitData "getMore" operations by ensuring that the maxAwaitTimeMS is less
182214
// than the operation timeout.

0 commit comments

Comments
 (0)