diff --git a/internal/partitions/partition.go b/internal/partitions/partition.go
index 024458e8..d19f150f 100644
--- a/internal/partitions/partition.go
+++ b/internal/partitions/partition.go
@@ -137,7 +137,7 @@ func (p *Partition) FindCmd(
 // (e.g. use the partitions on the source to read the destination for verification)
 // If the passed-in buildinfo indicates a mongodb version < 5.0, type bracketing is not used.
 // filterAndPredicates is a slice of filter criteria that's used to construct the "filter" field in the find option.
-func (p *Partition) GetFindOptions(buildInfo *util.BuildInfo, filterAndPredicates bson.A) bson.D {
+func (p *Partition) GetFindOptions(clusterInfo *util.ClusterInfo, filterAndPredicates bson.A) bson.D {
     if p == nil {
         if len(filterAndPredicates) > 0 {
             return bson.D{{"filter", bson.D{{"$and", filterAndPredicates}}}}
@@ -158,11 +158,11 @@ func (p *Partition) GetFindOptions(buildInfo *util.BuildInfo, filterAndPredicate
     // For non-capped collections, the cursor should use the ID filter and the _id index.
     // Get the bounded query filter from the partition to be used in the Find command.
     allowTypeBracketing := false
-    if buildInfo != nil {
+    if clusterInfo != nil {
         allowTypeBracketing = true
-        if buildInfo.VersionArray != nil {
-            allowTypeBracketing = buildInfo.VersionArray[0] < 5
+        if clusterInfo.VersionArray != nil {
+            allowTypeBracketing = clusterInfo.VersionArray[0] < 5
         }
     }
 
     if !allowTypeBracketing {
diff --git a/internal/partitions/partition_test.go b/internal/partitions/partition_test.go
index 53d41a0a..5f847ba4 100644
--- a/internal/partitions/partition_test.go
+++ b/internal/partitions/partition_test.go
@@ -78,37 +78,37 @@ func (suite *UnitTestSuite) TestVersioning() {
     suite.Require().Equal(expectedFilter, filter)
 
     // 6.0
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{6, 0, 0}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{6, 0, 0}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilter, filter)
 
     // 5.3.0.9
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{5, 3, 0, 9}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{5, 3, 0, 9}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilter, filter)
 
     // 7.1.3.5
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{7, 1, 3, 5}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{7, 1, 3, 5}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilter, filter)
 
     // 4.4 (int64)
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 4, 0, 0}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{4, 4, 0, 0}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilterWithTypeBracketing, filter)
 
     // 4.4
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 4, 0, 0}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{4, 4, 0, 0}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilterWithTypeBracketing, filter)
 
     // 4.2
-    findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 2, 0, 0}}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{VersionArray: []int{4, 2, 0, 0}}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilterWithTypeBracketing, filter)
 
     // No version array -- assume old, require type bracketing.
-    findOptions = partition.GetFindOptions(&util.BuildInfo{}, nil)
+    findOptions = partition.GetFindOptions(&util.ClusterInfo{}, nil)
     filter = getFilterFromFindOptions(findOptions)
     suite.Require().Equal(expectedFilterWithTypeBracketing, filter)
 }
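The partition hunks above swap `util.BuildInfo` for `util.ClusterInfo` but keep the version gate itself: servers whose major version is below 5 get the type-bracketed filter, and a missing version array is treated as an old server. A minimal sketch of that decision, with the helper name `allowsTypeBracketing` invented for illustration (the real logic lives inline in `GetFindOptions`):

```go
package main

import "fmt"

// allowsTypeBracketing mirrors the gate in GetFindOptions: no cluster info at
// all disables bracketing, cluster info without a version array assumes an old
// server, and otherwise only servers older than 5.0 use the bracketed filter.
func allowsTypeBracketing(versionArray []int, haveClusterInfo bool) bool {
	if !haveClusterInfo {
		return false // corresponds to the nil-*ClusterInfo branch
	}
	if versionArray == nil {
		return true // the "No version array -- assume old" case from the test
	}
	return versionArray[0] < 5
}

func main() {
	fmt.Println(allowsTypeBracketing([]int{6, 0, 0}, true))    // false: expectedFilter in the tests
	fmt.Println(allowsTypeBracketing([]int{4, 4, 0, 0}, true)) // true: expectedFilterWithTypeBracketing
	fmt.Println(allowsTypeBracketing(nil, true))               // true: assume old, require bracketing
}
```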
diff --git a/internal/util/buildinfo.go b/internal/util/buildinfo.go
deleted file mode 100644
index 2338e7dc..00000000
--- a/internal/util/buildinfo.go
+++ /dev/null
@@ -1,31 +0,0 @@
-package util
-
-import (
-    "context"
-
-    "github.com/10gen/migration-verifier/mbson"
-    "github.com/pkg/errors"
-    "go.mongodb.org/mongo-driver/bson"
-    "go.mongodb.org/mongo-driver/mongo"
-)
-
-type BuildInfo struct {
-    VersionArray []int
-}
-
-func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) {
-    commandResult := client.Database("admin").RunCommand(ctx, bson.D{{"buildinfo", 1}})
-
-    rawResp, err := commandResult.Raw()
-    if err != nil {
-        return BuildInfo{}, errors.Wrap(err, "failed to fetch build info")
-    }
-
-    bi := BuildInfo{}
-    _, err = mbson.RawLookup(rawResp, &bi.VersionArray, "versionArray")
-    if err != nil {
-        return BuildInfo{}, errors.Wrap(err, "failed to decode build info version array")
-    }
-
-    return bi, nil
-}
diff --git a/internal/util/clusterinfo.go b/internal/util/clusterinfo.go
new file mode 100644
index 00000000..980a374c
--- /dev/null
+++ b/internal/util/clusterinfo.go
@@ -0,0 +1,78 @@
+package util
+
+import (
+    "context"
+
+    "github.com/10gen/migration-verifier/mbson"
+    "github.com/pkg/errors"
+    "github.com/samber/lo"
+    "go.mongodb.org/mongo-driver/bson"
+    "go.mongodb.org/mongo-driver/mongo"
+)
+
+type ClusterTopology string
+
+type ClusterInfo struct {
+    VersionArray []int
+    Topology     ClusterTopology
+}
+
+const (
+    TopologySharded ClusterTopology = "sharded"
+    TopologyReplset ClusterTopology = "replset"
+)
+
+func GetClusterInfo(ctx context.Context, client *mongo.Client) (ClusterInfo, error) {
+    va, err := getVersionArray(ctx, client)
+    if err != nil {
+        return ClusterInfo{}, errors.Wrap(err, "failed to fetch version array")
+    }
+
+    topology, err := getTopology(ctx, client)
+    if err != nil {
+        return ClusterInfo{}, errors.Wrap(err, "failed to determine topology")
+    }
+
+    return ClusterInfo{
+        VersionArray: va,
+        Topology:     topology,
+    }, nil
+}
+
+func getVersionArray(ctx context.Context, client *mongo.Client) ([]int, error) {
+    commandResult := client.Database("admin").RunCommand(ctx, bson.D{{"buildinfo", 1}})
+
+    rawResp, err := commandResult.Raw()
+    if err != nil {
+        return nil, errors.Wrapf(err, "failed to run %#q", "buildinfo")
+    }
+
+    var va []int
+    _, err = mbson.RawLookup(rawResp, &va, "versionArray")
+    if err != nil {
+        return nil, errors.Wrap(err, "failed to decode build info version array")
+    }
+
+    return va, nil
+}
+
+func getTopology(ctx context.Context, client *mongo.Client) (ClusterTopology, error) {
+    resp := client.Database("admin").RunCommand(
+        ctx,
+        bson.D{{"hello", 1}},
+    )
+
+    hello := struct {
+        Msg string
+    }{}
+
+    if err := resp.Decode(&hello); err != nil {
+        return "", errors.Wrapf(
+            err,
+            "failed to decode %#q response",
+            "hello",
+        )
+    }
+
+    return lo.Ternary(hello.Msg == "isdbgrid", TopologySharded, TopologyReplset), nil
+}
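`GetClusterInfo` now reports a coarse topology alongside the version array, derived from the `hello` response (`msg: "isdbgrid"` identifies a mongos). A hedged usage sketch — the URI is a placeholder — showing how a caller can branch on the result, much as `uri.go` does later in this diff:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/10gen/migration-verifier/internal/util"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
)

func main() {
	ctx := context.Background()

	// Placeholder URI; any reachable cluster works.
	client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017"))
	if err != nil {
		log.Fatal(err)
	}
	defer client.Disconnect(ctx)

	info, err := util.GetClusterInfo(ctx, client)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Println("major version:", info.VersionArray[0])

	if info.Topology == util.TopologySharded {
		fmt.Println("mongos-fronted (sharded) cluster: extra refresh handling applies")
	} else {
		fmt.Println("replica set")
	}
}
```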
diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go
index 483af491..fe1548d3 100644
--- a/internal/verifier/change_stream.go
+++ b/internal/verifier/change_stream.go
@@ -315,7 +315,7 @@ func (verifier *Verifier) createChangeStream(
         SetMaxAwaitTime(1 * time.Second).
         SetFullDocument(options.UpdateLookup)
 
-    if verifier.srcBuildInfo.VersionArray[0] >= 6 {
+    if verifier.srcClusterInfo.VersionArray[0] >= 6 {
         opts = opts.SetCustomPipeline(bson.M{"showExpandedEvents": true})
     }
 
diff --git a/internal/verifier/change_stream_test.go b/internal/verifier/change_stream_test.go
index f4a6f42d..1b1aca0a 100644
--- a/internal/verifier/change_stream_test.go
+++ b/internal/verifier/change_stream_test.go
@@ -416,7 +416,7 @@ func (suite *IntegrationTestSuite) testInsertsBeforeWritesOff(docsCount int) {
 func (suite *IntegrationTestSuite) TestCreateForbidden() {
     ctx := suite.Context()
 
-    buildInfo, err := util.GetBuildInfo(ctx, suite.srcMongoClient)
+    buildInfo, err := util.GetClusterInfo(ctx, suite.srcMongoClient)
     suite.Require().NoError(err)
 
     if buildInfo.VersionArray[0] < 6 {
diff --git a/internal/verifier/compare.go b/internal/verifier/compare.go
index d6a50324..911051ad 100644
--- a/internal/verifier/compare.go
+++ b/internal/verifier/compare.go
@@ -266,7 +266,7 @@ func (verifier *Verifier) getFetcherChannels(
         cursor, err := verifier.getDocumentsCursor(
             ctx,
             verifier.srcClientCollection(task),
-            verifier.srcBuildInfo,
+            verifier.srcClusterInfo,
             verifier.srcStartAtTs,
             task,
         )
@@ -290,7 +290,7 @@ func (verifier *Verifier) getFetcherChannels(
         cursor, err := verifier.getDocumentsCursor(
             ctx,
             verifier.dstClientCollection(task),
-            verifier.dstBuildInfo,
+            verifier.dstClusterInfo,
             nil, //startAtTs
             task,
         )
diff --git a/internal/verifier/integration_test_suite.go b/internal/verifier/integration_test_suite.go
index 33e3a336..805d84b8 100644
--- a/internal/verifier/integration_test_suite.go
+++ b/internal/verifier/integration_test_suite.go
@@ -161,7 +161,7 @@ func (suite *IntegrationTestSuite) BuildVerifier() *Verifier {
     qfilter := QueryFilter{Namespace: "keyhole.dealers"}
     task := VerificationTask{QueryFilter: qfilter}
 
-    verifier := NewVerifier(VerifierSettings{})
+    verifier := NewVerifier(VerifierSettings{}, "stderr")
     //verifier.SetStartClean(true)
     verifier.SetNumWorkers(3)
     verifier.SetGenerationPauseDelayMillis(0)
@@ -183,7 +183,6 @@ func (suite *IntegrationTestSuite) BuildVerifier() *Verifier {
         verifier.SetMetaURI(ctx, suite.metaConnStr),
         "should set metadata connection string",
     )
-    verifier.SetLogger("stderr")
     verifier.SetMetaDBName(metaDBName)
 
     suite.Require().NoError(verifier.srcClientCollection(&task).Drop(ctx))
diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go
index 9b056a51..b6a04d6b 100644
--- a/internal/verifier/migration_verifier.go
+++ b/internal/verifier/migration_verifier.go
@@ -90,8 +90,8 @@ type Verifier struct {
     metaClient *mongo.Client
     srcClient  *mongo.Client
     dstClient  *mongo.Client
-    srcBuildInfo *util.BuildInfo
-    dstBuildInfo *util.BuildInfo
+    srcClusterInfo *util.ClusterInfo
+    dstClusterInfo *util.ClusterInfo
 
     numWorkers         int
     failureDisplaySize int64
@@ -187,13 +187,18 @@ type VerifierSettings struct {
 }
 
 // NewVerifier creates a new Verifier
-func NewVerifier(settings VerifierSettings) *Verifier {
+func NewVerifier(settings VerifierSettings, logPath string) *Verifier {
     readConcern := settings.ReadConcernSetting
     if readConcern == "" {
         readConcern = ReadConcernMajority
     }
 
+    logger, logWriter := getLoggerAndWriter(logPath)
+
     return &Verifier{
+        logger: logger,
+        writer: logWriter,
+
         phase:          Idle,
         numWorkers:     NumWorkers,
         readPreference: readpref.Primary(),
@@ -311,40 +316,6 @@ func (verifier *Verifier) AddMetaIndexes(ctx context.Context) error {
     return err
 }
 
-func (verifier *Verifier) SetSrcURI(ctx context.Context, uri string) error {
-    opts := verifier.getClientOpts(uri)
-    var err error
-    verifier.srcClient, err = mongo.Connect(ctx, opts)
-    if err != nil {
-        return errors.Wrapf(err, "failed to connect to source %#q", uri)
-    }
-
-    buildInfo, err := util.GetBuildInfo(ctx, verifier.srcClient)
-    if err != nil {
-        return errors.Wrap(err, "failed to read source build info")
-    }
-
-    verifier.srcBuildInfo = &buildInfo
-    return nil
-}
-
-func (verifier *Verifier) SetDstURI(ctx context.Context, uri string) error {
-    opts := verifier.getClientOpts(uri)
-    var err error
-    verifier.dstClient, err = mongo.Connect(ctx, opts)
-    if err != nil {
-        return errors.Wrapf(err, "failed to connect to destination %#q", uri)
-    }
-
-    buildInfo, err := util.GetBuildInfo(ctx, verifier.dstClient)
-    if err != nil {
-        return errors.Wrap(err, "failed to read destination build info")
-    }
-
-    verifier.dstBuildInfo = &buildInfo
-    return nil
-}
-
 func (verifier *Verifier) SetServerPort(port int) {
     verifier.port = port
 }
@@ -366,10 +337,6 @@ func (verifier *Verifier) SetPartitionSizeMB(partitionSizeMB uint32) {
     verifier.partitionSizeInBytes = int64(partitionSizeMB) * 1024 * 1024
 }
 
-func (verifier *Verifier) SetLogger(logPath string) {
-    verifier.logger, verifier.writer = getLoggerAndWriter(logPath)
-}
-
 func (verifier *Verifier) SetSrcNamespaces(arg []string) {
     verifier.srcNamespaces = arg
 }
@@ -473,7 +440,7 @@ func (verifier *Verifier) maybeAppendGlobalFilterToPredicates(predicates bson.A)
     return append(predicates, verifier.globalFilter)
 }
 
-func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, buildInfo *util.BuildInfo,
+func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, clusterInfo *util.ClusterInfo,
     startAtTs *primitive.Timestamp, task *VerificationTask) (*mongo.Cursor, error) {
     var findOptions bson.D
     runCommandOptions := options.RunCmd()
@@ -486,7 +453,7 @@ func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mo
             bson.E{"filter", bson.D{{"$and", andPredicates}}},
         }
     } else {
-        findOptions = task.QueryFilter.Partition.GetFindOptions(buildInfo, verifier.maybeAppendGlobalFilterToPredicates(andPredicates))
+        findOptions = task.QueryFilter.Partition.GetFindOptions(clusterInfo, verifier.maybeAppendGlobalFilterToPredicates(andPredicates))
     }
     if verifier.readPreference.Mode() != readpref.PrimaryMode {
         runCommandOptions = runCommandOptions.SetReadPreference(verifier.readPreference)
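With the constructor change above, the log destination is wired in when the `Verifier` is built instead of via the removed `SetLogger` setter, so `verifier.logger` is always initialized. A small sketch of the updated call pattern (mirroring the test and `main` updates elsewhere in this diff; `"stderr"` can be replaced with a file path):

```go
package main

import "github.com/10gen/migration-verifier/internal/verifier"

func main() {
	// Old pattern (removed in this diff):
	//   v := verifier.NewVerifier(verifier.VerifierSettings{})
	//   v.SetLogger("stderr")
	//
	// New pattern: the log path is a constructor argument, and NewVerifier
	// calls getLoggerAndWriter internally.
	v := verifier.NewVerifier(verifier.VerifierSettings{}, "stderr")
	v.SetNumWorkers(3)
}
```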
diff --git a/internal/verifier/migration_verifier_bench_test.go b/internal/verifier/migration_verifier_bench_test.go
index 9732ded4..9568e53d 100644
--- a/internal/verifier/migration_verifier_bench_test.go
+++ b/internal/verifier/migration_verifier_bench_test.go
@@ -54,7 +54,7 @@ func BenchmarkGeneric(t *testing.B) {
     fmt.Printf("Running with %s as the meta db name. Specify META_DB_NAME= to change\n", metaDBName)
     // fmt.Printf("Running with %s as the namespace. Specify META_DB_NAME= to change\n", metaDBName)
 
-    verifier := NewVerifier(VerifierSettings{})
+    verifier := NewVerifier(VerifierSettings{}, "stderr")
     verifier.SetNumWorkers(numWorkers)
     verifier.SetGenerationPauseDelayMillis(0)
     verifier.SetWorkerSleepDelayMillis(0)
@@ -71,7 +71,6 @@ func BenchmarkGeneric(t *testing.B) {
     if err != nil {
         t.Fatal(err)
     }
-    verifier.SetLogger("stderr")
     verifier.SetMetaDBName(metaDBName)
     err = verifier.verificationTaskCollection().Drop(context.Background())
     if err != nil {
diff --git a/internal/verifier/migration_verifier_test.go b/internal/verifier/migration_verifier_test.go
index a8fae37d..46449c8b 100644
--- a/internal/verifier/migration_verifier_test.go
+++ b/internal/verifier/migration_verifier_test.go
@@ -456,7 +456,7 @@ func (suite *IntegrationTestSuite) TestFailedVerificationTaskInsertions() {
 func TestVerifierCompareDocs(t *testing.T) {
     id := rand.Intn(1000)
 
-    verifier := NewVerifier(VerifierSettings{})
+    verifier := NewVerifier(VerifierSettings{}, "stderr")
     verifier.SetIgnoreBSONFieldOrder(true)
 
     type compareTest struct {
diff --git a/internal/verifier/mongos_refresh.go b/internal/verifier/mongos_refresh.go
new file mode 100644
index 00000000..33c31131
--- /dev/null
+++ b/internal/verifier/mongos_refresh.go
@@ -0,0 +1,185 @@
+package verifier
+
+import (
+    "context"
+
+    "github.com/10gen/migration-verifier/internal/logger"
+    "github.com/10gen/migration-verifier/internal/retry"
+    "github.com/10gen/migration-verifier/mmongo"
+    "github.com/pkg/errors"
+    "go.mongodb.org/mongo-driver/bson"
+    "go.mongodb.org/mongo-driver/mongo"
+    "go.mongodb.org/mongo-driver/mongo/options"
+    "go.mongodb.org/mongo-driver/mongo/readconcern"
+)
+
+const UnauthorizedErrCode = 13
+
+// RefreshAllMongosInstances prevents data corruption from SERVER-32198, which can cause reads and writes to be
+// accepted by the wrong shard (this is caused by a mongos not knowing the collection is sharded and the shard not
+// knowing the collection is sharded). This method relies on the verifier rejecting 4.4 source SRV connection strings
+// (see the `auditor` package for more details).
+//
+// This method must only be called on a sharded cluster, otherwise it returns a "no such command: 'listShards'" error.
+//
+// Note: this is a reimplementation of MaybeRefreshAllSourceMongosInstances() in mongosync.
+func RefreshAllMongosInstances(
+    ctx context.Context,
+    l *logger.Logger,
+    clientOpts *options.ClientOptions,
+) error {
+    hosts := clientOpts.Hosts
+    l.Info().
+        Strs("hosts", hosts).
+        Msgf("Refreshing all %d mongos instance(s) on the source.", len(hosts))
+
+    r := retry.New(retry.DefaultDurationLimit)
+
+    for _, host := range hosts {
+        singleHostClientOpts := *clientOpts
+
+        // Only connect to one host at a time.
+        singleHostClientOpts.SetHosts([]string{host})
+
+        // Only open 1 connection to each mongos to reduce the risk of overwhelming the source cluster.
+        singleHostClientOpts.SetMaxConnecting(1)
+
+        singleHostClient, err := mongo.Connect(ctx, &singleHostClientOpts)
+        if err != nil {
+            return errors.Wrapf(err, "failed to connect to mongos host %#q", host)
+        }
+
+        shardConnStr, err := getAnyExistingShardConnectionStr(
+            ctx,
+            l,
+            r,
+            singleHostClient,
+        )
+        if err != nil {
+            return err
+        }
+
+        err = r.RunForTransientErrorsOnly(
+            ctx,
+            l,
+            func(ri *retry.Info) error {
+                // Query a collection on the config server with linearizable read concern to advance the config
+                // server primary's majority-committed optime. This populates the $configOpTime.
+                opts := options.Database().SetReadConcern(readconcern.Linearizable())
+                err := singleHostClient.
+                    Database("admin", opts).
+                    Collection("system.version").
+                    FindOne(
+                        ctx,
+                        bson.D{{"_id", "featureCompatibilityVersion"}},
+                    ).
+                    Err()
+                if err != nil {
+                    return errors.Wrap(err, "failed to query the config server")
+                }
+
+                // Run `addShard` on an existing shard to force the mongos' ShardRegistry to refresh. This combined
+                // with the previous step guarantees that all shards are known to the mongos.
+                err = singleHostClient.
+                    Database("admin").
+                    RunCommand(ctx, bson.D{{"addShard", shardConnStr}}).
+                    Err()
+                if err != nil {
+                    // TODO (REP-3952): Do this error check using the `shared` package.
+                    if mmongo.ErrorHasCode(err, UnauthorizedErrCode) {
+                        return errors.New(
+                            "missing privileges to refresh mongos instances on the source; please restart " +
+                                "migration-verifier with a URI that includes the `clusterManager` role",
+                        )
+                    }
+                    return errors.Wrap(
+                        err,
+                        "failed to execute `addShard` to force the mongos' ShardRegistry to refresh",
+                    )
+                }
+
+                // We could alternatively run `flushRouterConfig: <db>` for each db, but that requires a
+                // listDatabases call. We should _never_ run `flushRouterConfig: <db>.<collection>` because that
+                // would cause the mongos to no longer know whether the collection is sharded or not. See this
+                // document: https://docs.google.com/document/d/1C0EG2Qx2ECZbUsaNdGDTY-5JK0NISeo5_UT9oMG1dps/edit
+                // for more information.
+                err = singleHostClient.
+                    Database("admin").
+                    RunCommand(ctx, bson.D{{"flushRouterConfig", 1}}).
+                    Err()
+                if err != nil {
+                    return errors.Wrap(err, "failed to flush the mongos config")
+                }
+
+                return nil
+            })
+
+        if err != nil {
+            return err
+        }
+
+        if err = singleHostClient.Disconnect(ctx); err != nil {
+            return errors.Wrap(err, "failed to gracefully disconnect from the mongos")
+        }
+    }
+
+    l.Info().
+        Strs("hosts", hosts).
+        Msgf("Successfully refreshed all %d mongos instance(s) on the source.", len(hosts))
+    return nil
+}
+
+// getAnyExistingShardConnectionStr will return the shard connection string of
+// a shard in the current cluster. If the cluster is not sharded,
+// an empty string and error will be returned.
+//
+// Note: this is a reimplementation of a method of the same name in mongosync.
+func getAnyExistingShardConnectionStr(
+    ctx context.Context,
+    l *logger.Logger,
+    r retry.Retryer,
+    client *mongo.Client,
+) (string, error) {
+    res, err := runListShards(ctx, l, r, client)
+    if err != nil {
+        return "", err
+    }
+
+    doc, err := res.Raw()
+    if err != nil {
+        return "", err
+    }
+
+    rawHost, lookupErr := doc.LookupErr("shards", "0", "host")
+    if lookupErr != nil {
+        return "", lookupErr
+    }
+
+    shardConnStr, ok := rawHost.StringValueOK()
+    if !ok {
+        return "", errors.New("failed to convert rawHost to string")
+    }
+
+    return shardConnStr, nil
+}
+
+// runListShards returns the mongo.SingleResult from running the listShards command.
+//
+// Note: this is a reimplementation of a method of the same name in mongosync.
+func runListShards(
+    ctx context.Context,
+    l *logger.Logger,
+    r retry.Retryer,
+    client *mongo.Client,
+) (*mongo.SingleResult, error) {
+    var res *mongo.SingleResult
+    err := r.RunForTransientErrorsOnly(
+        ctx,
+        l,
+        func(_ *retry.Info) error {
+            res = client.Database("admin").RunCommand(ctx, bson.D{{"listShards", 1}})
+            return res.Err()
+        },
+    )
+    return res, err
+}
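`RefreshAllMongosInstances` above always sends the whole-router form of `flushRouterConfig`. For reference, a sketch of the three command shapes its in-code comment contrasts — the database and collection names here are placeholders — where only the first is actually issued, and the per-collection form is avoided because it makes the mongos forget whether that collection is sharded:

```go
package main

import "go.mongodb.org/mongo-driver/bson"

func main() {
	// Whole-router refresh: what RefreshAllMongosInstances sends.
	wholeRouter := bson.D{{"flushRouterConfig", 1}}

	// Per-database refresh: would require a prior listDatabases call to enumerate "mydb".
	perDatabase := bson.D{{"flushRouterConfig", "mydb"}}

	// Per-collection refresh: deliberately not used (see the comment in mongos_refresh.go).
	perCollection := bson.D{{"flushRouterConfig", "mydb.coll"}}

	_, _, _ = wholeRouter, perDatabase, perCollection
}
```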
diff --git a/internal/verifier/uri.go b/internal/verifier/uri.go
new file mode 100644
index 00000000..25470af4
--- /dev/null
+++ b/internal/verifier/uri.go
@@ -0,0 +1,103 @@
+package verifier
+
+import (
+    "context"
+
+    "github.com/10gen/migration-verifier/internal/util"
+    "github.com/pkg/errors"
+    "go.mongodb.org/mongo-driver/mongo"
+    "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
+)
+
+func (verifier *Verifier) SetSrcURI(ctx context.Context, uri string) error {
+    opts := verifier.getClientOpts(uri)
+    var err error
+    verifier.srcClient, err = mongo.Connect(ctx, opts)
+    if err != nil {
+        return errors.Wrapf(err, "failed to connect to source %#q", uri)
+    }
+
+    clusterInfo, err := util.GetClusterInfo(ctx, verifier.srcClient)
+    if err != nil {
+        return errors.Wrap(err, "failed to read source build info")
+    }
+
+    verifier.srcClusterInfo = &clusterInfo
+
+    if clusterInfo.Topology == util.TopologySharded {
+        err := RefreshAllMongosInstances(
+            ctx,
+            verifier.logger,
+            opts,
+        )
+
+        if err != nil {
+            return errors.Wrap(
+                err,
+                "failed to refresh source mongos instances",
+            )
+        }
+    }
+
+    return checkURIAgainstServerVersion(uri, clusterInfo)
+}
+
+func (verifier *Verifier) SetDstURI(ctx context.Context, uri string) error {
+    opts := verifier.getClientOpts(uri)
+    var err error
+    verifier.dstClient, err = mongo.Connect(ctx, opts)
+    if err != nil {
+        return errors.Wrapf(err, "failed to connect to destination %#q", uri)
+    }
+
+    clusterInfo, err := util.GetClusterInfo(ctx, verifier.dstClient)
+    if err != nil {
+        return errors.Wrap(err, "failed to read destination build info")
+    }
+
+    verifier.dstClusterInfo = &clusterInfo
+
+    if clusterInfo.Topology == util.TopologySharded {
+        err := RefreshAllMongosInstances(
+            ctx,
+            verifier.logger,
+            opts,
+        )
+
+        if err != nil {
+            return errors.Wrap(
+                err,
+                "failed to refresh destination mongos instances",
+            )
+        }
+    }
+
+    return checkURIAgainstServerVersion(uri, clusterInfo)
+}
+
+func checkURIAgainstServerVersion(uri string, bi util.ClusterInfo) error {
+    if bi.VersionArray[0] >= 5 {
+        return nil
+    }
+
+    cs, err := connstring.ParseAndValidate(uri)
+
+    if err != nil {
+        return errors.Wrap(err, "failed to parse and validate connection string")
+    }
+    if cs == nil {
+        panic("parsed and validated connection string (" + uri + ") must not be nil")
+    }
+
+    // migration-verifier disallows SRV strings for pre-v5 clusters for the
+    // same reason as mongosync's embedded verifier: mongoses can be added
+    // dynamically, which means they could avoid the critical router-flush that
+    // SERVER-32198 necessitates for pre-v5 clusters.
+    if cs.Scheme == connstring.SchemeMongoDBSRV {
+        return errors.Errorf(
+            "SRV connection string is forbidden for pre-v5 clusters",
+        )
+    }
+
+    return nil
+}
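`checkURIAgainstServerVersion` only inspects the URI scheme once it knows the cluster is pre-5.0. A standalone sketch of that behavior, with placeholder URIs and an invented helper name; note that `connstring.ParseAndValidate` resolves SRV records, so a real SRV URI needs a resolvable DNS name:

```go
package main

import (
	"fmt"

	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
)

// isForbiddenSRV mirrors checkURIAgainstServerVersion: pre-5.0 clusters
// (major version < 5) may not use mongodb+srv:// connection strings.
func isForbiddenSRV(uri string, majorVersion int) (bool, error) {
	if majorVersion >= 5 {
		return false, nil
	}
	cs, err := connstring.ParseAndValidate(uri)
	if err != nil {
		return false, err
	}
	return cs.Scheme == connstring.SchemeMongoDBSRV, nil
}

func main() {
	// 6.0 cluster: the scheme is never inspected.
	fmt.Println(isForbiddenSRV("mongodb+srv://cluster0.example.net/", 6)) // false <nil>

	// Pre-5.0 cluster with a plain mongodb:// URI: allowed.
	fmt.Println(isForbiddenSRV("mongodb://host1:27017,host2:27017/", 4)) // false <nil>

	// A pre-5.0 cluster with an SRV URI would return true (or a DNS error for
	// a placeholder hostname like the one above).
}
```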
diff --git a/main/migration_verifier.go b/main/migration_verifier.go
index 8863bbba..f9209b46 100644
--- a/main/migration_verifier.go
+++ b/main/migration_verifier.go
@@ -199,7 +199,10 @@ func handleArgs(ctx context.Context, cCtx *cli.Context) (*verifier.Verifier, err
         verifierSettings.ReadConcernSetting = verifier.ReadConcernIgnore
     }
 
-    v := verifier.NewVerifier(verifierSettings)
+    logPath := cCtx.String(logPath)
+
+    v := verifier.NewVerifier(verifierSettings, logPath)
+
     err := v.SetSrcURI(ctx, cCtx.String(srcURI))
     if err != nil {
         return nil, err
@@ -232,8 +235,7 @@ func handleArgs(ctx context.Context, cCtx *cli.Context) (*verifier.Verifier, err
     }
     v.SetStartClean(cCtx.Bool(startClean))
 
-    logPath := cCtx.String(logPath)
-    v.SetLogger(logPath)
+
     if cCtx.Bool(verifyAll) {
         if len(cCtx.StringSlice(srcNamespace)) > 0 || len(cCtx.StringSlice(dstNamespace)) > 0 {
             return nil, errors.Errorf("Setting both verifyAll and explicit namespaces is not supported")
diff --git a/mmongo/error.go b/mmongo/error.go
new file mode 100644
index 00000000..6fd9f65a
--- /dev/null
+++ b/mmongo/error.go
@@ -0,0 +1,14 @@
+package mmongo
+
+import (
+    "github.com/pkg/errors"
+    "go.mongodb.org/mongo-driver/mongo"
+)
+
+// ErrorHasCode returns true if (and only if) this error is a
+// mongo.ServerError that contains the given error code.
+func ErrorHasCode[T ~int](err error, code T) bool {
+    var serverError mongo.ServerError
+
+    return errors.As(err, &serverError) && serverError.HasErrorCode(int(code))
+}
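`mmongo.ErrorHasCode` is generic over any integer-kind type so call sites can pass typed error-code constants (such as `UnauthorizedErrCode` above) without casting. A hedged usage sketch with a hand-built `mongo.CommandError` standing in for a real server error:

```go
package main

import (
	"fmt"

	"github.com/10gen/migration-verifier/mmongo"
	"go.mongodb.org/mongo-driver/mongo"
)

// ErrCode is a locally defined integer type, only to show that the ~int
// constraint accepts typed constants as well as plain ints.
type ErrCode int

const unauthorized ErrCode = 13

func main() {
	// mongo.CommandError implements mongo.ServerError, so it works as a
	// stand-in for an error returned by a real command.
	var err error = mongo.CommandError{Code: 13, Message: "not authorized"}

	fmt.Println(mmongo.ErrorHasCode(err, unauthorized))  // true
	fmt.Println(mmongo.ErrorHasCode(err, ErrCode(8000))) // false
}
```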