diff --git a/internal/partitions/partition.go b/internal/partitions/partition.go index 5c9d507b..024458e8 100644 --- a/internal/partitions/partition.go +++ b/internal/partitions/partition.go @@ -137,7 +137,7 @@ func (p *Partition) FindCmd( // (e.g. use the partitions on the source to read the destination for verification) // If the passed-in buildinfo indicates a mongodb version < 5.0, type bracketing is not used. // filterAndPredicates is a slice of filter criteria that's used to construct the "filter" field in the find option. -func (p *Partition) GetFindOptions(buildInfo *bson.M, filterAndPredicates bson.A) bson.D { +func (p *Partition) GetFindOptions(buildInfo *util.BuildInfo, filterAndPredicates bson.A) bson.D { if p == nil { if len(filterAndPredicates) > 0 { return bson.D{{"filter", bson.D{{"$and", filterAndPredicates}}}} @@ -160,16 +160,9 @@ func (p *Partition) GetFindOptions(buildInfo *bson.M, filterAndPredicates bson.A allowTypeBracketing := false if buildInfo != nil { allowTypeBracketing = true - versionArray, ok := (*buildInfo)["versionArray"].(bson.A) - //bson values are int32 or int64, never int. - if ok { - majorVersion, ok := versionArray[0].(int32) - if ok { - allowTypeBracketing = majorVersion < 5 - } else { - majorVersion64, _ := versionArray[0].(int64) - allowTypeBracketing = majorVersion64 < 5 - } + + if buildInfo.VersionArray != nil { + allowTypeBracketing = buildInfo.VersionArray[0] < 5 } } if !allowTypeBracketing { diff --git a/internal/partitions/partition_test.go b/internal/partitions/partition_test.go index e0cae05c..53d41a0a 100644 --- a/internal/partitions/partition_test.go +++ b/internal/partitions/partition_test.go @@ -77,43 +77,38 @@ func (suite *UnitTestSuite) TestVersioning() { filter := getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilter, filter) - // 6.0 (int64) - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int64(6), int64(0), int64(0), int64(0)}}, nil) - filter = getFilterFromFindOptions(findOptions) - suite.Require().Equal(expectedFilter, filter) - // 6.0 - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int32(6), int32(0), int32(0), int32(0)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{6, 0, 0}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilter, filter) // 5.3.0.9 - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int32(5), int32(3), int32(0), int32(9)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{5, 3, 0, 9}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilter, filter) // 7.1.3.5 - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int32(7), int32(1), int32(3), int32(5)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{7, 1, 3, 5}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilter, filter) // 4.4 (int64) - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int64(4), int64(4), int64(0), int64(0)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 4, 0, 0}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilterWithTypeBracketing, filter) // 4.4 - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int32(4), int32(4), int32(0), int32(0)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 4, 0, 0}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilterWithTypeBracketing, filter) // 4.2 - findOptions = partition.GetFindOptions(&bson.M{"versionArray": bson.A{int32(4), int32(2), int32(0), int32(0)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{VersionArray: []int{4, 2, 0, 0}}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilterWithTypeBracketing, filter) // No version array -- assume old, require type bracketing. - findOptions = partition.GetFindOptions(&bson.M{"notVersionArray": bson.A{6, int32(0), int32(0), int32(0)}}, nil) + findOptions = partition.GetFindOptions(&util.BuildInfo{}, nil) filter = getFilterFromFindOptions(findOptions) suite.Require().Equal(expectedFilterWithTypeBracketing, filter) } diff --git a/internal/util/buildinfo.go b/internal/util/buildinfo.go new file mode 100644 index 00000000..2338e7dc --- /dev/null +++ b/internal/util/buildinfo.go @@ -0,0 +1,31 @@ +package util + +import ( + "context" + + "github.com/10gen/migration-verifier/mbson" + "github.com/pkg/errors" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" +) + +type BuildInfo struct { + VersionArray []int +} + +func GetBuildInfo(ctx context.Context, client *mongo.Client) (BuildInfo, error) { + commandResult := client.Database("admin").RunCommand(ctx, bson.D{{"buildinfo", 1}}) + + rawResp, err := commandResult.Raw() + if err != nil { + return BuildInfo{}, errors.Wrap(err, "failed to fetch build info") + } + + bi := BuildInfo{} + _, err = mbson.RawLookup(rawResp, &bi.VersionArray, "versionArray") + if err != nil { + return BuildInfo{}, errors.Wrap(err, "failed to decode build info version array") + } + + return bi, nil +} diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go index 372de6bc..dcd8122a 100644 --- a/internal/verifier/change_stream.go +++ b/internal/verifier/change_stream.go @@ -293,6 +293,10 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { SetMaxAwaitTime(1 * time.Second). SetFullDocument(options.UpdateLookup) + if verifier.srcBuildInfo.VersionArray[0] >= 6 { + opts = opts.SetCustomPipeline(bson.M{"showExpandedEvents": true}) + } + savedResumeToken, err := verifier.loadChangeStreamResumeToken(ctx) if err != nil { return errors.Wrap(err, "failed to load persisted change stream resume token") diff --git a/internal/verifier/change_stream_test.go b/internal/verifier/change_stream_test.go index 71424e0f..da2153d8 100644 --- a/internal/verifier/change_stream_test.go +++ b/internal/verifier/change_stream_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/10gen/migration-verifier/internal/util" + "github.com/10gen/migration-verifier/mslices" "github.com/pkg/errors" "github.com/samber/lo" "github.com/stretchr/testify/require" @@ -35,6 +37,12 @@ func TestChangeStreamFilter(t *testing.T) { // terminates that verifier, updates the source cluster, starts a new // verifier with change stream, and confirms that things look as they should. func (suite *IntegrationTestSuite) TestChangeStreamResumability() { + suite.Require().NoError( + suite.srcMongoClient. + Database(suite.DBNameForTest()). + CreateCollection(suite.Context(), "testColl"), + ) + func() { verifier1 := suite.BuildVerifier() ctx, cancel := context.WithCancel(context.Background()) @@ -43,7 +51,7 @@ func (suite *IntegrationTestSuite) TestChangeStreamResumability() { suite.Require().NoError(err) }() - ctx, cancel := context.WithCancel(context.Background()) + ctx, cancel := context.WithCancel(suite.Context()) defer cancel() _, err := suite.srcMongoClient. @@ -219,19 +227,26 @@ func (suite *IntegrationTestSuite) TestNoStartAtTime() { } func (suite *IntegrationTestSuite) TestWithChangeEventsBatching() { - verifier := suite.BuildVerifier() + ctx := suite.Context() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() + db := suite.srcMongoClient.Database(suite.DBNameForTest()) + coll1 := db.Collection("testColl1") + coll2 := db.Collection("testColl2") + + for _, coll := range mslices.Of(coll1, coll2) { + suite.Require().NoError(db.CreateCollection(ctx, coll.Name())) + } + + verifier := suite.BuildVerifier() suite.Require().NoError(verifier.StartChangeStream(ctx)) - _, err := suite.srcMongoClient.Database("testDb").Collection("testColl1").InsertOne(ctx, bson.D{{"_id", 1}}) + _, err := coll1.InsertOne(ctx, bson.D{{"_id", 1}}) suite.Require().NoError(err) - _, err = suite.srcMongoClient.Database("testDb").Collection("testColl1").InsertOne(ctx, bson.D{{"_id", 2}}) + _, err = coll1.InsertOne(ctx, bson.D{{"_id", 2}}) suite.Require().NoError(err) - _, err = suite.srcMongoClient.Database("testDb").Collection("testColl2").InsertOne(ctx, bson.D{{"_id", 1}}) + _, err = coll2.InsertOne(ctx, bson.D{{"_id", 1}}) suite.Require().NoError(err) var rechecks []bson.M @@ -245,6 +260,7 @@ func (suite *IntegrationTestSuite) TestWithChangeEventsBatching() { 500*time.Millisecond, "the verifier should flush a recheck doc after a batch", ) + } func (suite *IntegrationTestSuite) TestManyInsertsBeforeWritesOff() { @@ -304,3 +320,40 @@ func (suite *IntegrationTestSuite) testInsertsBeforeWritesOff(docsCount int) { suite.Assert().Equal(docsCount, totalFailed, "all source docs should be missing") } + +func (suite *IntegrationTestSuite) TestCreateForbidden() { + ctx := suite.Context() + buildInfo, err := util.GetBuildInfo(ctx, suite.srcMongoClient) + suite.Require().NoError(err) + + if buildInfo.VersionArray[0] < 6 { + suite.T().Skipf("This test requires server v6+. (Found: %v)", buildInfo.VersionArray) + } + + verifier := suite.BuildVerifier() + + // start verifier + verifierRunner := RunVerifierCheck(suite.Context(), suite.T(), verifier) + + // wait for generation 0 to end + verifierRunner.AwaitGenerationEnd() + + db := suite.srcMongoClient.Database(suite.DBNameForTest()) + coll := db.Collection("mycoll") + suite.Require().NoError( + db.CreateCollection(ctx, coll.Name()), + ) + + // The error from the create event will come either at WritesOff + // or when we finalize the change stream. + err = verifier.WritesOff(ctx) + if err == nil { + err = verifierRunner.Await() + } + + suite.Require().Error(err, "should detect forbidden create event") + + eventErr := UnknownEventError{} + suite.Require().ErrorAs(err, &eventErr) + suite.Assert().Equal("create", eventErr.Event.OpType) +} diff --git a/internal/verifier/check.go b/internal/verifier/check.go index 0ae007c5..6be24116 100644 --- a/internal/verifier/check.go +++ b/internal/verifier/check.go @@ -236,6 +236,9 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any // paying attention. Also, this should not matter too much because any failures will be // caught again on the next iteration. if verifier.writesOff { + verifier.logger.Debug(). + Msg("Waiting for change stream to end.") + // It's necessary to wait for the change stream to finish before incrementing the // generation number, or the last changes will not be checked. verifier.mux.Unlock() diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go index bdf71fc7..7a6b0d39 100644 --- a/internal/verifier/migration_verifier.go +++ b/internal/verifier/migration_verifier.go @@ -89,8 +89,8 @@ type Verifier struct { metaClient *mongo.Client srcClient *mongo.Client dstClient *mongo.Client - srcBuildInfo *bson.M - dstBuildInfo *bson.M + srcBuildInfo *util.BuildInfo + dstBuildInfo *util.BuildInfo numWorkers int failureDisplaySize int64 @@ -312,10 +312,16 @@ func (verifier *Verifier) SetSrcURI(ctx context.Context, uri string) error { var err error verifier.srcClient, err = mongo.Connect(ctx, opts) if err != nil { - return err + return errors.Wrapf(err, "failed to connect to source %#q", uri) } - verifier.srcBuildInfo, err = getBuildInfo(ctx, verifier.srcClient) - return err + + buildInfo, err := util.GetBuildInfo(ctx, verifier.srcClient) + if err != nil { + return errors.Wrap(err, "failed to read source build info") + } + + verifier.srcBuildInfo = &buildInfo + return nil } func (verifier *Verifier) SetDstURI(ctx context.Context, uri string) error { @@ -323,10 +329,16 @@ func (verifier *Verifier) SetDstURI(ctx context.Context, uri string) error { var err error verifier.dstClient, err = mongo.Connect(ctx, opts) if err != nil { - return err + return errors.Wrapf(err, "failed to connect to destination %#q", uri) } - verifier.dstBuildInfo, err = getBuildInfo(ctx, verifier.dstClient) - return err + + buildInfo, err := util.GetBuildInfo(ctx, verifier.dstClient) + if err != nil { + return errors.Wrap(err, "failed to read destination build info") + } + + verifier.dstBuildInfo = &buildInfo + return nil } func (verifier *Verifier) SetServerPort(port int) { @@ -457,7 +469,7 @@ func (verifier *Verifier) maybeAppendGlobalFilterToPredicates(predicates bson.A) return append(predicates, verifier.globalFilter) } -func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, buildInfo *bson.M, +func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, buildInfo *util.BuildInfo, startAtTs *primitive.Timestamp, task *VerificationTask) (*mongo.Cursor, error) { var findOptions bson.D runCommandOptions := options.RunCmd() @@ -1510,16 +1522,3 @@ func (verifier *Verifier) getNamespaces(ctx context.Context, fieldName string) ( } return namespaces, nil } - -func getBuildInfo(ctx context.Context, client *mongo.Client) (*bson.M, error) { - commandResult := client.Database("admin").RunCommand(ctx, bson.D{{"buildinfo", 1}}) - if commandResult.Err() != nil { - return nil, commandResult.Err() - } - var buildInfoMap bson.M - err := commandResult.Decode(&buildInfoMap) - if err != nil { - return nil, err - } - return &buildInfoMap, nil -}