Skip to content

Commit 97b9c71

Browse files
committed
plug retryer into compare.go
1 parent 3425929 commit 97b9c71

File tree

1 file changed

+49
-32
lines changed

1 file changed

+49
-32
lines changed

internal/verifier/compare.go

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import (
55
"context"
66
"time"
77

8+
"github.com/10gen/migration-verifier/internal/retry"
89
"github.com/10gen/migration-verifier/internal/types"
910
"github.com/pkg/errors"
1011
"go.mongodb.org/mongo-driver/bson"
1112
"go.mongodb.org/mongo-driver/mongo"
1213
"golang.org/x/exp/slices"
13-
"golang.org/x/sync/errgroup"
1414
)
1515

1616
const readTimeout = 10 * time.Minute
@@ -24,31 +24,31 @@ func (verifier *Verifier) FetchAndCompareDocuments(
2424
types.ByteCount,
2525
error,
2626
) {
27-
// This function spawns three threads: one to read from the source,
28-
// another to read from the destination, and a third one to receive the
29-
// docs from the other 2 threads and compare them. It’s done this way,
30-
// rather than fetch-everything-then-compare, to minimize memory usage.
31-
errGroup, groupCtx := errgroup.WithContext(givenCtx)
32-
33-
srcChannel, dstChannel := verifier.getFetcherChannels(groupCtx, errGroup, task)
27+
srcChannel, dstChannel, readSrcCallback, readDstCallback := verifier.getFetcherChannelsAndCallbacks(task)
3428

3529
results := []VerificationResult{}
3630
var docCount types.DocumentCount
3731
var byteCount types.ByteCount
3832

39-
errGroup.Go(func() error {
40-
var err error
41-
results, docCount, byteCount, err = verifier.compareDocsFromChannels(
42-
groupCtx,
43-
task,
44-
srcChannel,
45-
dstChannel,
46-
)
47-
48-
return err
49-
})
33+
retryer := retry.New(retry.DefaultDurationLimit)
34+
35+
err := retryer.Run(
36+
givenCtx,
37+
verifier.logger,
38+
readSrcCallback,
39+
readDstCallback,
40+
func(ctx context.Context, _ *retry.FuncInfo) error {
41+
var err error
42+
results, docCount, byteCount, err = verifier.compareDocsFromChannels(
43+
ctx,
44+
task,
45+
srcChannel,
46+
dstChannel,
47+
)
5048

51-
err := errGroup.Wait()
49+
return err
50+
},
51+
)
5252

5353
return results, docCount, byteCount, err
5454
}
@@ -254,15 +254,18 @@ func simpleTimerReset(t *time.Timer, dur time.Duration) {
254254
t.Reset(dur)
255255
}
256256

257-
func (verifier *Verifier) getFetcherChannels(
258-
ctx context.Context,
259-
errGroup *errgroup.Group,
257+
func (verifier *Verifier) getFetcherChannelsAndCallbacks(
260258
task *VerificationTask,
261-
) (<-chan bson.Raw, <-chan bson.Raw) {
259+
) (
260+
<-chan bson.Raw,
261+
<-chan bson.Raw,
262+
func(context.Context, *retry.FuncInfo) error,
263+
func(context.Context, *retry.FuncInfo) error,
264+
) {
262265
srcChannel := make(chan bson.Raw)
263266
dstChannel := make(chan bson.Raw)
264267

265-
errGroup.Go(func() error {
268+
readSrcCallback := func(ctx context.Context, state *retry.FuncInfo) error {
266269
cursor, err := verifier.getDocumentsCursor(
267270
ctx,
268271
verifier.srcClientCollection(task),
@@ -272,8 +275,10 @@ func (verifier *Verifier) getFetcherChannels(
272275
)
273276

274277
if err == nil {
278+
state.NoteSuccess()
279+
275280
err = errors.Wrap(
276-
iterateCursorToChannel(ctx, cursor, srcChannel),
281+
iterateCursorToChannel(ctx, state, cursor, srcChannel),
277282
"failed to read source documents",
278283
)
279284
} else {
@@ -284,9 +289,9 @@ func (verifier *Verifier) getFetcherChannels(
284289
}
285290

286291
return err
287-
})
292+
}
288293

289-
errGroup.Go(func() error {
294+
readDstCallback := func(ctx context.Context, state *retry.FuncInfo) error {
290295
cursor, err := verifier.getDocumentsCursor(
291296
ctx,
292297
verifier.dstClientCollection(task),
@@ -296,8 +301,10 @@ func (verifier *Verifier) getFetcherChannels(
296301
)
297302

298303
if err == nil {
304+
state.NoteSuccess()
305+
299306
err = errors.Wrap(
300-
iterateCursorToChannel(ctx, cursor, dstChannel),
307+
iterateCursorToChannel(ctx, state, cursor, dstChannel),
301308
"failed to read destination documents",
302309
)
303310
} else {
@@ -308,16 +315,26 @@ func (verifier *Verifier) getFetcherChannels(
308315
}
309316

310317
return err
311-
})
318+
}
312319

313-
return srcChannel, dstChannel
320+
return srcChannel, dstChannel, readSrcCallback, readDstCallback
314321
}
315322

316-
func iterateCursorToChannel(ctx context.Context, cursor *mongo.Cursor, writer chan<- bson.Raw) error {
323+
func iterateCursorToChannel(
324+
ctx context.Context,
325+
state *retry.FuncInfo,
326+
cursor *mongo.Cursor,
327+
writer chan<- bson.Raw,
328+
) error {
317329
for cursor.Next(ctx) {
330+
state.NoteSuccess()
318331
writer <- slices.Clone(cursor.Current)
319332
}
320333

334+
if cursor.Err() == nil {
335+
state.NoteSuccess()
336+
}
337+
321338
close(writer)
322339

323340
return errors.Wrap(cursor.Err(), "failed to iterate cursor")

0 commit comments

Comments
 (0)