diff --git a/internal/partitions/partitions.go b/internal/partitions/partitions.go index 278f875b..01419808 100644 --- a/internal/partitions/partitions.go +++ b/internal/partitions/partitions.go @@ -119,7 +119,7 @@ const ( func PartitionCollectionWithSize( ctx context.Context, uuidEntry *uuidutil.NamespaceAndUUID, - retryer retry.Retryer, + retryer *retry.Retryer, srcClient *mongo.Client, replicatorList []Replicator, subLogger *logger.Logger, @@ -137,7 +137,7 @@ func PartitionCollectionWithSize( partitions, docCount, byteCount, err := PartitionCollectionWithParameters( ctx, uuidEntry, - &retryer, + retryer, srcClient, replicatorList, defaultSampleRate, @@ -153,7 +153,7 @@ func PartitionCollectionWithSize( return PartitionCollectionWithParameters( ctx, uuidEntry, - &retryer, + retryer, srcClient, replicatorList, defaultSampleRate, diff --git a/internal/retry/retry.go b/internal/retry/retry.go index 362d3f00..470da5b5 100644 --- a/internal/retry/retry.go +++ b/internal/retry/retry.go @@ -44,9 +44,9 @@ type RetryCallback = func(context.Context, *FuncInfo) error // This returns an error if the duration limit is reached, or if f() returns a // non-transient error. func (r *Retryer) Run( - ctx context.Context, logger *logger.Logger, f ...RetryCallback, + ctx context.Context, logger *logger.Logger, funcs ...RetryCallback, ) error { - return r.runRetryLoop(ctx, logger, f) + return r.runRetryLoop(ctx, logger, funcs) } // runRetryLoop contains the core logic for the retry loops. @@ -74,8 +74,18 @@ func (r *Retryer) runRetryLoop( sleepTime := minSleepTime for { + if beforeFunc, hasBefore := r.before.Get(); hasBefore { + beforeFunc() + } + eg, egCtx := errgroup.WithContext(ctx) for i, curFunc := range funcs { + if curFunc == nil { + panic("curFunc should be non-nil") + } + if funcinfos[i] == nil { + panic(fmt.Sprintf("funcinfos[%d] should be non-nil", i)) + } eg.Go(func() error { err := curFunc(egCtx, funcinfos[i]) diff --git a/internal/retry/retryer.go b/internal/retry/retryer.go index 15ba15d9..6c269113 100644 --- a/internal/retry/retryer.go +++ b/internal/retry/retryer.go @@ -2,24 +2,27 @@ package retry import ( "time" + + "github.com/10gen/migration-verifier/option" ) // Retryer handles retrying operations that fail because of network failures. type Retryer struct { retryLimit time.Duration retryRandomly bool + before option.Option[func()] additionalErrorCodes []int } // New returns a new retryer. -func New(retryLimit time.Duration) Retryer { +func New(retryLimit time.Duration) *Retryer { return NewWithRandomlyRetries(retryLimit, false) } // NewWithRandomlyRetries returns a new retryer, but allows the option of setting the // retryRandomly field. -func NewWithRandomlyRetries(retryLimit time.Duration, retryRandomly bool) Retryer { - return Retryer{ +func NewWithRandomlyRetries(retryLimit time.Duration, retryRandomly bool) *Retryer { + return &Retryer{ retryLimit: retryLimit, retryRandomly: retryRandomly, } @@ -29,9 +32,21 @@ func NewWithRandomlyRetries(retryLimit time.Duration, retryRandomly bool) Retrye // this method. This allows for a single function to customize the codes it // wants to retry on. Note that if the Retryer already has additional custom // error codes set, these are _replaced_ when this method is called. -func (r Retryer) WithErrorCodes(codes ...int) Retryer { - r2 := r +func (r *Retryer) WithErrorCodes(codes ...int) *Retryer { + r2 := *r r2.additionalErrorCodes = codes - return r2 + return &r2 +} + +// WithBefore sets a callback that always runs before any retryer callback. +// +// This is useful if there are multiple callbacks and you need to reset some +// condition before each retryer iteration. (In the single-callback case it’s +// largely redundant.) +func (r *Retryer) WithBefore(todo func()) *Retryer { + r2 := *r + r2.before = option.Some(todo) + + return &r2 } diff --git a/internal/uuidutil/get_uuid.go b/internal/uuidutil/get_uuid.go index a8a7e576..86996933 100644 --- a/internal/uuidutil/get_uuid.go +++ b/internal/uuidutil/get_uuid.go @@ -27,7 +27,7 @@ type NamespaceAndUUID struct { CollName string } -func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, retryer retry.Retryer, db *mongo.Database, collName string) (*NamespaceAndUUID, error) { +func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, db *mongo.Database, collName string) (*NamespaceAndUUID, error) { binaryUUID, uuidErr := GetCollectionUUID(ctx, logger, retryer, db, collName) if uuidErr != nil { return nil, uuidErr @@ -39,7 +39,7 @@ func GetCollectionNamespaceAndUUID(ctx context.Context, logger *logger.Logger, r }, nil } -func GetCollectionUUID(ctx context.Context, logger *logger.Logger, retryer retry.Retryer, db *mongo.Database, collName string) (*primitive.Binary, error) { +func GetCollectionUUID(ctx context.Context, logger *logger.Logger, retryer *retry.Retryer, db *mongo.Database, collName string) (*primitive.Binary, error) { filter := bson.D{{"name", collName}} opts := options.ListCollections().SetNameOnly(false) diff --git a/internal/verifier/compare.go b/internal/verifier/compare.go index e619255c..802d0bae 100644 --- a/internal/verifier/compare.go +++ b/internal/verifier/compare.go @@ -25,7 +25,8 @@ func (verifier *Verifier) FetchAndCompareDocuments( types.ByteCount, error, ) { - srcChannel, dstChannel, readSrcCallback, readDstCallback := verifier.getFetcherChannelsAndCallbacks(task) + var srcChannel, dstChannel <-chan bson.Raw + var readSrcCallback, readDstCallback func(context.Context, *retry.FuncInfo) error results := []VerificationResult{} var docCount types.DocumentCount @@ -33,23 +34,31 @@ func (verifier *Verifier) FetchAndCompareDocuments( retryer := retry.New(retry.DefaultDurationLimit) - err := retryer.Run( - givenCtx, - verifier.logger, - readSrcCallback, - readDstCallback, - func(ctx context.Context, _ *retry.FuncInfo) error { - var err error - results, docCount, byteCount, err = verifier.compareDocsFromChannels( - ctx, - task, - srcChannel, - dstChannel, - ) + err := retryer. + WithBefore(func() { + srcChannel, dstChannel, readSrcCallback, readDstCallback = verifier.getFetcherChannelsAndCallbacks(task) + }). + Run( + givenCtx, + verifier.logger, + func(ctx context.Context, fi *retry.FuncInfo) error { + return readSrcCallback(ctx, fi) + }, + func(ctx context.Context, fi *retry.FuncInfo) error { + return readDstCallback(ctx, fi) + }, + func(ctx context.Context, _ *retry.FuncInfo) error { + var err error + results, docCount, byteCount, err = verifier.compareDocsFromChannels( + ctx, + task, + srcChannel, + dstChannel, + ) - return err - }, - ) + return err + }, + ) return results, docCount, byteCount, err } diff --git a/internal/verifier/mongos_refresh.go b/internal/verifier/mongos_refresh.go index 5b12d6aa..ec8fa70f 100644 --- a/internal/verifier/mongos_refresh.go +++ b/internal/verifier/mongos_refresh.go @@ -137,7 +137,7 @@ func RefreshAllMongosInstances( func getAnyExistingShardConnectionStr( ctx context.Context, l *logger.Logger, - r retry.Retryer, + r *retry.Retryer, client *mongo.Client, ) (string, error) { res, err := runListShards(ctx, l, r, client) @@ -169,7 +169,7 @@ func getAnyExistingShardConnectionStr( func runListShards( ctx context.Context, l *logger.Logger, - r retry.Retryer, + r *retry.Retryer, client *mongo.Client, ) (*mongo.SingleResult, error) { var res *mongo.SingleResult