Skip to content

Commit 9230ddf

Browse files
committed
REP-5329 make retryer callbacks take a context
1 parent 0ae907f commit 9230ddf

File tree

7 files changed

+24
-22
lines changed

7 files changed

+24
-22
lines changed

internal/partitions/partitions.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ func GetSizeAndDocumentCount(ctx context.Context, logger *logger.Logger, retryer
324324
Capped bool `bson:"capped"`
325325
}{}
326326

327-
err := retryer.Run(ctx, logger, func(ri *retry.Info) error {
327+
err := retryer.Run(ctx, logger, func(ctx context.Context, ri *retry.Info) error {
328328
ri.Log(logger.Logger, "collStats", "source", srcDB.Name(), collName, "Retrieving collection size and document count.")
329329
request := bson.D{
330330
{"aggregate", collName},
@@ -395,7 +395,7 @@ func GetDocumentCountAfterFiltering(ctx context.Context, logger *logger.Logger,
395395
}
396396
pipeline = append(pipeline, bson.D{{"$count", "numFilteredDocs"}})
397397

398-
err := retryer.Run(ctx, logger, func(ri *retry.Info) error {
398+
err := retryer.Run(ctx, logger, func(ctx context.Context, ri *retry.Info) error {
399399
ri.Log(logger.Logger, "count", "source", srcDB.Name(), collName, "Counting filtered documents.")
400400
request := bson.D{
401401
{"aggregate", collName},
@@ -488,7 +488,7 @@ func getOuterIDBound(
488488
}...)
489489

490490
// Get one document containing only the smallest or largest _id value in the collection.
491-
err := retryer.Run(ctx, subLogger, func(ri *retry.Info) error {
491+
err := retryer.Run(ctx, subLogger, func(ctx context.Context, ri *retry.Info) error {
492492
ri.Log(subLogger.Logger, "aggregate", "source", srcDB.Name(), collName, fmt.Sprintf("getting %s _id partition bound", minOrMaxBound))
493493
cursor, cmdErr :=
494494
srcDB.RunCommandCursor(ctx, bson.D{
@@ -577,7 +577,7 @@ func getMidIDBounds(
577577
// Get a cursor for the $sample and $bucketAuto aggregation.
578578
var midIDBounds []interface{}
579579
agRetryer := retryer.WithErrorCodes(util.SampleTooManyDuplicates)
580-
err := agRetryer.Run(ctx, logger, func(ri *retry.Info) error {
580+
err := agRetryer.Run(ctx, logger, func(ctx context.Context, ri *retry.Info) error {
581581
ri.Log(logger.Logger, "aggregate", "source", srcDB.Name(), collName, "Retrieving mid _id partition bounds using $sample.")
582582
cursor, cmdErr :=
583583
srcDB.RunCommandCursor(ctx, bson.D{

internal/retry/retry.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import (
99
"github.com/10gen/migration-verifier/internal/util"
1010
)
1111

12+
type RetryCallback = func(context.Context, *Info) error
13+
1214
// Run retries f() whenever a transient error happens, up to the retryer's
1315
// configured duration limit.
1416
//
@@ -24,7 +26,7 @@ import (
2426
// This returns an error if the duration limit is reached, or if f() returns a
2527
// non-transient error.
2628
func (r *Retryer) Run(
27-
ctx context.Context, logger *logger.Logger, f func(*Info) error,
29+
ctx context.Context, logger *logger.Logger, f RetryCallback,
2830
) error {
2931
return r.runRetryLoop(ctx, logger, f)
3032
}
@@ -33,7 +35,7 @@ func (r *Retryer) Run(
3335
func (r *Retryer) runRetryLoop(
3436
ctx context.Context,
3537
logger *logger.Logger,
36-
f func(*Info) error,
38+
f RetryCallback,
3739
) error {
3840
var err error
3941

@@ -44,7 +46,7 @@ func (r *Retryer) runRetryLoop(
4446
sleepTime := minSleepTime
4547

4648
for {
47-
err = f(ri)
49+
err = f(ctx, ri)
4850

4951
// If f() returned a transient error, sleep and increase the sleep
5052
// time for the next retry, maxing out at the maxSleepTime.

internal/retry/retryer_test.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ func (suite *UnitTestSuite) TestRetryer() {
1515

1616
suite.Run("with a function that immediately succeeds", func() {
1717
attemptNumber := -1
18-
f := func(ri *Info) error {
18+
f := func(_ context.Context, ri *Info) error {
1919
attemptNumber = ri.GetAttemptNumber()
2020
return nil
2121
}
@@ -24,7 +24,7 @@ func (suite *UnitTestSuite) TestRetryer() {
2424
suite.NoError(err)
2525
suite.Equal(0, attemptNumber)
2626

27-
f2 := func(ri *Info) error {
27+
f2 := func(_ context.Context, ri *Info) error {
2828
attemptNumber = ri.GetAttemptNumber()
2929
return nil
3030
}
@@ -36,7 +36,7 @@ func (suite *UnitTestSuite) TestRetryer() {
3636

3737
suite.Run("with a function that succeeds after two attempts", func() {
3838
attemptNumber := -1
39-
f := func(ri *Info) error {
39+
f := func(_ context.Context, ri *Info) error {
4040
attemptNumber = ri.GetAttemptNumber()
4141
if attemptNumber < 2 {
4242
return mongo.CommandError{
@@ -52,7 +52,7 @@ func (suite *UnitTestSuite) TestRetryer() {
5252
suite.Equal(2, attemptNumber)
5353

5454
attemptNumber = -1
55-
f2 := func(ri *Info) error {
55+
f2 := func(_ context.Context, ri *Info) error {
5656
attemptNumber = ri.GetAttemptNumber()
5757
if attemptNumber < 2 {
5858
return mongo.CommandError{
@@ -77,7 +77,7 @@ func (suite *UnitTestSuite) TestRetryerDurationLimitIsZero() {
7777
Labels: []string{"NetworkError"},
7878
Name: "NetworkError",
7979
}
80-
f := func(ri *Info) error {
80+
f := func(_ context.Context, ri *Info) error {
8181
attemptNumber = ri.attemptNumber
8282
return cmdErr
8383
}
@@ -103,7 +103,7 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() {
103103
// 1) Not calling IterationSuccess() means f will not be retried, since the
104104
// durationLimit is exceeded
105105
noSuccessIterations := 0
106-
f1 := func(ri *Info) error {
106+
f1 := func(_ context.Context, ri *Info) error {
107107
// Artificially advance how much time was taken.
108108
ri.lastResetTime = ri.lastResetTime.Add(-2 * ri.durationLimit)
109109

@@ -126,7 +126,7 @@ func (suite *UnitTestSuite) TestRetryerDurationReset() {
126126
// 2) Calling IterationSuccess() means f will run more than once because the
127127
// duration should be reset.
128128
successIterations := 0
129-
f2 := func(ri *Info) error {
129+
f2 := func(_ context.Context, ri *Info) error {
130130
// Artificially advance how much time was taken.
131131
ri.lastResetTime = ri.lastResetTime.Add(-2 * ri.durationLimit)
132132

@@ -152,7 +152,7 @@ func (suite *UnitTestSuite) TestCancelViaContext() {
152152
counter := 0
153153
var wg sync.WaitGroup
154154
wg.Add(1)
155-
f := func(_ *Info) error {
155+
f := func(_ context.Context, _ *Info) error {
156156
counter++
157157
if counter == 1 {
158158
return errors.New("not master")
@@ -184,7 +184,7 @@ func (suite *UnitTestSuite) TestRetryerAdditionalErrorCodes() {
184184
}
185185

186186
var attemptNumber int
187-
f := func(ri *Info) error {
187+
f := func(_ context.Context, ri *Info) error {
188188
attemptNumber = ri.GetAttemptNumber()
189189
if attemptNumber == 0 {
190190
return customError

internal/uuidutil/get_uuid.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func GetCollectionUUID(ctx context.Context, logger *logger.Logger, retryer retry
4747
err := retryer.Run(
4848
ctx,
4949
logger,
50-
func(ri *retry.Info) error {
50+
func(_ context.Context, ri *retry.Info) error {
5151
ri.Log(logger.Logger, "ListCollectionSpecifications", db.Name(), collName, "Getting collection UUID.", "")
5252
var driverErr error
5353
collSpecs, driverErr = db.ListCollectionSpecifications(ctx, filter, opts)

internal/verifier/change_stream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error {
398398
err := retryer.Run(
399399
ctx,
400400
verifier.logger,
401-
func(ri *retry.Info) error {
401+
func(ctx context.Context, ri *retry.Info) error {
402402
srcChangeStream, startTs, err := verifier.createChangeStream(ctx)
403403
if err != nil {
404404
if parentThreadWaiting {

internal/verifier/clustertime.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ func GetNewClusterTime(
3434
err := retryer.Run(
3535
ctx,
3636
logger,
37-
func(_ *retry.Info) error {
37+
func(ctx context.Context, _ *retry.Info) error {
3838
var err error
3939
clusterTime, err = runAppendOplogNote(
4040
ctx,
@@ -56,7 +56,7 @@ func GetNewClusterTime(
5656
err = retryer.Run(
5757
ctx,
5858
logger,
59-
func(_ *retry.Info) error {
59+
func(ctx context.Context, _ *retry.Info) error {
6060
var err error
6161
_, err = runAppendOplogNote(
6262
ctx,

internal/verifier/recheck.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,9 @@ func (verifier *Verifier) insertRecheckDocs(
136136
err := retryer.Run(
137137
groupCtx,
138138
verifier.logger,
139-
func(_ *retry.Info) error {
139+
func(retryCtx context.Context, _ *retry.Info) error {
140140
_, err := verifier.verificationDatabase().Collection(recheckQueue).BulkWrite(
141-
groupCtx,
141+
retryCtx,
142142
models,
143143
options.BulkWrite().SetOrdered(false),
144144
)

0 commit comments

Comments
 (0)