Skip to content

Commit 5c404fb

Browse files
Decouple teardown from tadc factories
1 parent 085d3bc commit 5c404fb

File tree

1 file changed

+81
-79
lines changed

1 file changed

+81
-79
lines changed

internal/integration/cursor_test.go

Lines changed: 81 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -319,129 +319,132 @@ func parseMaxAwaitTime(mt *mtest.T, evt *event.CommandStartedEvent) int64 {
319319
return got
320320
}
321321

322-
func tadcFindFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) {
322+
func tadcFindFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
323323
mt.Helper()
324324

325-
initCollection(mt, mt.Coll)
326-
cur, err := mt.Coll.Find(ctx, bson.D{{"x", 1}},
325+
initCollection(mt, &coll)
326+
cur, err := coll.Find(ctx, bson.D{{"__nomatch", 1}},
327327
options.Find().SetBatchSize(1).SetCursorType(options.TailableAwait))
328328
require.NoError(mt, err, "Find error: %v", err)
329329

330-
return cur, func() error { return cur.Close(context.Background()) }
330+
return cur
331331
}
332332

333-
func tadcAggregateFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) {
333+
func tadcAggregateFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
334334
mt.Helper()
335335

336-
initCollection(mt, mt.Coll)
337-
338-
opts := options.Aggregate().SetMaxAwaitTime(100 * time.Millisecond)
339-
pipe := mongo.Pipeline{{{"$changeStream", bson.D{}}}}
336+
initCollection(mt, &coll)
337+
opts := options.Aggregate()
338+
pipeline := mongo.Pipeline{{{"$changeStream", bson.D{{"fullDocument", "default"}}}},
339+
{{"$match", bson.D{
340+
{"operationType", "insert"},
341+
{"fullDocment.__nomatch", 1},
342+
}}},
343+
}
340344

341-
cursor, err := mt.Coll.Aggregate(ctx, pipe, opts)
345+
cursor, err := coll.Aggregate(ctx, pipeline, opts)
342346
require.NoError(mt, err, "Aggregate error: %v", err)
343347

344-
return cursor, func() error { return cursor.Close(context.Background()) }
348+
return cursor
345349
}
346350

347-
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error) {
351+
func tadcRunCommandCursorFactory(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor {
348352
mt.Helper()
349353

350-
initCollection(mt, mt.Coll)
354+
initCollection(mt, &coll)
351355

352-
cur, err := mt.DB.RunCommandCursor(ctx, bson.D{
353-
{"find", mt.Coll.Name()},
354-
{"filter", bson.D{{"x", 1}}},
356+
cur, err := coll.Database().RunCommandCursor(ctx, bson.D{
357+
{"find", coll.Name()},
358+
{"filter", bson.D{{"__nomatch", 1}}},
355359
{"tailable", true},
356360
{"awaitData", true},
357361
{"batchSize", int32(1)},
358362
})
359363
require.NoError(mt, err, "RunCommandCursor error: %v", err)
360364

361-
return cur, func() error { return cur.Close(context.Background()) }
365+
return cur
362366
}
363367

364368
// For tailable awaitData cursors, the maxTimeMS for a getMore should be
365369
// min(maxAwaitTimeMS, remaining timeoutMS - minRoundTripTime) to allow the
366370
// server more opportunities to respond with an empty batch before a
367371
// client-side timeout.
368372
func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
369-
const timeout = 2000 * time.Millisecond
370-
371-
// Setup mtest instance.
372-
mt := mtest.New(t, mtest.NewOptions().CreateClient(false))
373-
374-
cappedOpts := options.CreateCollection().SetCapped(true).
375-
SetSizeInBytes(1024 * 64)
376-
377-
// TODO(SERVER-96344): mongos doesn't honor a failpoint's full blockTimeMS.
373+
// These values reflect what is used in the unified spec tests, see
374+
// DRIVERS-2868.
375+
const timeoutMS = 200
376+
const maxAwaitTimeMS = 100
377+
const blockTimeMS = 30
378+
const getMoreBound = 71
379+
380+
// TODO(GODRIVER-3328): mongos doesn't honor a failpoint's full blockTimeMS.
378381
baseTopologies := []mtest.TopologyKind{mtest.Single, mtest.LoadBalanced, mtest.ReplicaSet}
379382

380383
type testCase struct {
381384
name string
382-
factory func(ctx context.Context, mt *mtest.T) (*mongo.Cursor, func() error)
385+
factory func(ctx context.Context, mt *mtest.T, coll mongo.Collection) *mongo.Cursor
383386
opTimeout bool
384387
topologies []mtest.TopologyKind
385-
386-
// Operations that insert a document into the collection will require that
387-
// an initial batch be consumed to ensure that the getMore is sent in
388-
// subsequent Next calls.
389-
consumeFirstBatch bool
390388
}
391389

392390
cases := []testCase{
391+
// TODO(GODRIVER-2944): "find" cursors are tested in the CSOT unified spec
392+
// tests for tailable/awaitData cursors and so these tests can be removed
393+
// once the driver supports timeoutMode.
393394
{
394-
name: "find client-level timeout",
395-
factory: tadcFindFactory,
396-
topologies: baseTopologies,
397-
opTimeout: false,
398-
consumeFirstBatch: true,
395+
name: "find client-level timeout",
396+
factory: tadcFindFactory,
397+
topologies: baseTopologies,
398+
opTimeout: false,
399399
},
400400
{
401-
name: "find operation-level timeout",
402-
factory: tadcFindFactory,
403-
topologies: baseTopologies,
404-
opTimeout: true,
405-
consumeFirstBatch: true,
401+
name: "find operation-level timeout",
402+
factory: tadcFindFactory,
403+
topologies: baseTopologies,
404+
opTimeout: true,
406405
},
406+
407+
// There is no analogue to tailable/awaiData cursor unified spec tests for
408+
// aggregate and runnCommand.
407409
{
408-
name: "aggregate with $changeStream client-level timeout",
409-
factory: tadcAggregateFactory,
410-
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
411-
opTimeout: false,
412-
consumeFirstBatch: false,
410+
name: "aggregate with changeStream client-level timeout",
411+
factory: tadcAggregateFactory,
412+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
413+
opTimeout: false,
413414
},
414415
{
415-
name: "aggregate with $changeStream operation-level timeout",
416-
factory: tadcAggregateFactory,
417-
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
418-
opTimeout: true,
419-
consumeFirstBatch: false,
416+
name: "aggregate with changeStream operation-level timeout",
417+
factory: tadcAggregateFactory,
418+
topologies: []mtest.TopologyKind{mtest.ReplicaSet, mtest.LoadBalanced},
419+
opTimeout: true,
420420
},
421421
{
422-
name: "runCommandCursor client-level timeout",
423-
factory: tadcRunCommandCursorFactory,
424-
topologies: baseTopologies,
425-
opTimeout: false,
426-
consumeFirstBatch: true,
422+
name: "runCommandCursor client-level timeout",
423+
factory: tadcRunCommandCursorFactory,
424+
topologies: baseTopologies,
425+
opTimeout: false,
427426
},
428427
{
429-
name: "runCommandCursor operation-level timeout",
430-
factory: tadcRunCommandCursorFactory,
431-
topologies: baseTopologies,
432-
opTimeout: true,
433-
consumeFirstBatch: true,
428+
name: "runCommandCursor operation-level timeout",
429+
factory: tadcRunCommandCursorFactory,
430+
topologies: baseTopologies,
431+
opTimeout: true,
434432
},
435433
}
436434

437-
mtOpts := mtest.NewOptions().CollectionCreateOptions(cappedOpts)
435+
mt := mtest.New(t, mtest.NewOptions().CreateClient(false).MinServerVersion("4.2"))
438436

439437
for _, tc := range cases {
440-
caseOpts := mtOpts
441-
caseOpts = caseOpts.Topologies(tc.topologies...)
438+
// Reset the collection between test cases to avoid leaking timeouts
439+
// between tests.
440+
cappedOpts := options.CreateCollection().SetCapped(true).SetSizeInBytes(1024 * 64)
441+
caseOpts := mtest.NewOptions().
442+
CollectionCreateOptions(cappedOpts).
443+
Topologies(tc.topologies...).
444+
CreateClient(true)
442445

443446
if !tc.opTimeout {
444-
caseOpts = mtOpts.ClientOptions(options.Client().SetTimeout(timeout))
447+
caseOpts = caseOpts.ClientOptions(options.Client().SetTimeout(timeoutMS * time.Millisecond))
445448
}
446449

447450
mt.RunOpts(tc.name, caseOpts, func(mt *mtest.T) {
@@ -451,30 +454,27 @@ func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
451454
Data: failpoint.Data{
452455
FailCommands: []string{"getMore"},
453456
BlockConnection: true,
454-
BlockTimeMS: 300,
457+
BlockTimeMS: int32(blockTimeMS),
455458
},
456459
})
457460

458461
ctx := context.Background()
459462

460463
var cancel context.CancelFunc
461464
if tc.opTimeout {
462-
ctx, cancel = context.WithTimeout(ctx, timeout)
465+
ctx, cancel = context.WithTimeout(ctx, timeoutMS*time.Millisecond)
463466
defer cancel()
464467
}
465468

466-
cur, cleanup := tc.factory(ctx, mt)
467-
defer func() { assert.NoError(mt, cleanup()) }()
469+
cur := tc.factory(ctx, mt, *mt.Coll)
470+
defer func() { assert.NoError(mt, cur.Close(context.Background())) }()
468471

469472
require.NoError(mt, cur.Err())
470473

471-
cur.SetMaxAwaitTime(1000 * time.Millisecond)
472-
473-
if tc.consumeFirstBatch {
474-
assert.True(mt, cur.Next(ctx)) // consume first batch item
475-
}
474+
cur.SetMaxAwaitTime(maxAwaitTimeMS * time.Millisecond)
476475

477476
mt.ClearEvents()
477+
478478
assert.False(mt, cur.Next(ctx))
479479

480480
require.Error(mt, cur.Err(), "expected error from cursor.Next")
@@ -487,17 +487,19 @@ func TestCursor_tailableAwaitData_applyRemainingTimeout(t *testing.T) {
487487
}
488488
}
489489

490-
require.Len(mt, getMoreEvts, 2)
490+
// It's possible that three getMore events are called: 100ms, 70ms, and
491+
// then some small leftover of remaining time (e.g. 20µs).
492+
require.GreaterOrEqual(mt, len(getMoreEvts), 2)
491493

492494
// The first getMore should have a maxTimeMS of <= 100ms but greater
493495
// than 71ms, indicating that the maxAwaitTimeMS was used.
494-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(1000))
495-
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(710))
496+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(maxAwaitTimeMS))
497+
assert.Greater(mt, parseMaxAwaitTime(mt, getMoreEvts[0]), int64(getMoreBound))
496498

497499
// The second getMore should have a maxTimeMS of <=71, indicating that we
498500
// are using the time remaining in the context rather than the
499501
// maxAwaitTimeMS.
500-
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(710))
502+
assert.LessOrEqual(mt, parseMaxAwaitTime(mt, getMoreEvts[1]), int64(getMoreBound))
501503
})
502504
}
503505
}

0 commit comments

Comments
 (0)