diff --git a/internal/verifier/change_stream.go b/internal/verifier/change_stream.go index 13ea7918..60c282a1 100644 --- a/internal/verifier/change_stream.go +++ b/internal/verifier/change_stream.go @@ -6,6 +6,7 @@ import ( "time" "github.com/10gen/migration-verifier/internal/keystring" + "github.com/10gen/migration-verifier/internal/logger" "github.com/10gen/migration-verifier/internal/retry" "github.com/10gen/migration-verifier/internal/util" "github.com/pkg/errors" @@ -52,8 +53,79 @@ func (uee UnknownEventError) Error() string { return fmt.Sprintf("received event with unknown optype: %+v", uee.Event) } +type ChangeStreamReader struct { + readerType whichCluster + + lastChangeEventTime *primitive.Timestamp + logger *logger.Logger + namespaces []string + + metaDB *mongo.Database + watcherClient *mongo.Client + clusterInfo util.ClusterInfo + + changeStreamRunning bool + changeEventBatchChan chan []ParsedEvent + writesOffTsChan chan primitive.Timestamp + errChan chan error + doneChan chan struct{} + + startAtTs *primitive.Timestamp +} + +func (verifier *Verifier) initializeChangeStreamReaders() { + verifier.srcChangeStreamReader = &ChangeStreamReader{ + readerType: src, + logger: verifier.logger, + namespaces: verifier.srcNamespaces, + metaDB: verifier.metaClient.Database(verifier.metaDBName), + watcherClient: verifier.srcClient, + clusterInfo: *verifier.srcClusterInfo, + changeStreamRunning: false, + changeEventBatchChan: make(chan []ParsedEvent), + writesOffTsChan: make(chan primitive.Timestamp), + errChan: make(chan error), + doneChan: make(chan struct{}), + } + verifier.dstChangeStreamReader = &ChangeStreamReader{ + readerType: dst, + logger: verifier.logger, + namespaces: verifier.dstNamespaces, + metaDB: verifier.metaClient.Database(verifier.metaDBName), + watcherClient: verifier.dstClient, + clusterInfo: *verifier.dstClusterInfo, + changeStreamRunning: false, + changeEventBatchChan: make(chan []ParsedEvent), + writesOffTsChan: make(chan primitive.Timestamp), 
+ errChan: make(chan error), + doneChan: make(chan struct{}), + } +} + +// StartChangeEventHandler starts a goroutine that handles change event batches from the reader. +// It needs to be started after the reader starts. +func (verifier *Verifier) StartChangeEventHandler(ctx context.Context, reader *ChangeStreamReader) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + case batch, more := <-reader.changeEventBatchChan: + if !more { + verifier.logger.Debug().Msgf("Change Event Batch Channel has been closed by %s, returning...", reader) + return nil + } + verifier.logger.Trace().Msgf("Verifier is handling a change event batch from %s: %v", reader, batch) + err := verifier.HandleChangeStreamEvents(ctx, batch, reader.readerType) + if err != nil { + reader.errChan <- err + return err + } + } + } +} + // HandleChangeStreamEvents performs the necessary work for change stream events after receiving a batch. -func (verifier *Verifier) HandleChangeStreamEvents(ctx context.Context, batch []ParsedEvent) error { +func (verifier *Verifier) HandleChangeStreamEvents(ctx context.Context, batch []ParsedEvent, eventOrigin whichCluster) error { if len(batch) == 0 { return nil } @@ -64,11 +136,6 @@ func (verifier *Verifier) HandleChangeStreamEvents(ctx context.Context, batch [] dataSizes := make([]int, len(batch)) for i, changeEvent := range batch { - if changeEvent.ClusterTime != nil && - (verifier.lastChangeEventTime == nil || - verifier.lastChangeEventTime.Before(*changeEvent.ClusterTime)) { - verifier.lastChangeEventTime = changeEvent.ClusterTime - } switch changeEvent.OpType { case "delete": fallthrough @@ -80,8 +147,34 @@ func (verifier *Verifier) HandleChangeStreamEvents(ctx context.Context, batch [] if err := verifier.eventRecorder.AddEvent(&changeEvent); err != nil { return errors.Wrapf(err, "failed to augment stats with change event (%+v)", changeEvent) } - dbNames[i] = changeEvent.Ns.DB - collNames[i] = changeEvent.Ns.Coll + + var srcDBName, srcCollName 
string + + // Recheck Docs are keyed by source namespaces. + // We need to retrieve the source namespaces if change events are from the destination. + switch eventOrigin { + case dst: + if verifier.nsMap.Len() == 0 { + // Namespace is not remapped. Source namespace is the same as the destination. + srcDBName = changeEvent.Ns.DB + srcCollName = changeEvent.Ns.Coll + } else { + dstNs := fmt.Sprintf("%s.%s", changeEvent.Ns.DB, changeEvent.Ns.Coll) + srcNs, exist := verifier.nsMap.GetSrcNamespace(dstNs) + if !exist { + return errors.Errorf("no source namespace corresponding to the destination namespace %s", dstNs) + } + srcDBName, srcCollName = SplitNamespace(srcNs) + } + case src: + srcDBName = changeEvent.Ns.DB + srcCollName = changeEvent.Ns.Coll + default: + return errors.Errorf("unknown event origin: %s", eventOrigin) + } + + dbNames[i] = srcDBName + collNames[i] = srcCollName docIDs[i] = changeEvent.DocKey.ID + if changeEvent.FullDocument == nil { @@ -116,17 +209,16 @@ func (verifier *Verifier) HandleChangeStreamEvents(ctx context.Context, batch [] // and omit fullDocument, but $bsonSize was new in MongoDB 4.4, and we still // want to verify migrations from 4.2. fullDocument is unlikely to be a // bottleneck anyway. -func (verifier *Verifier) GetChangeStreamFilter() (pipeline mongo.Pipeline) { - - if len(verifier.srcNamespaces) == 0 { +func (csr *ChangeStreamReader) GetChangeStreamFilter() (pipeline mongo.Pipeline) { + if len(csr.namespaces) == 0 { pipeline = mongo.Pipeline{ {{"$match", bson.D{ - {"ns.db", bson.D{{"$ne", verifier.metaDBName}}}, + {"ns.db", bson.D{{"$ne", csr.metaDB.Name()}}}, }}}, } } else { filter := []bson.D{} - for _, ns := range verifier.srcNamespaces { + for _, ns := range csr.namespaces { db, coll := SplitNamespace(ns) filter = append(filter, bson.D{ {"ns", bson.D{ @@ -158,7 +250,7 @@ func (verifier *Verifier) GetChangeStreamFilter() (pipeline mongo.Pipeline) { // the verifier will enqueue rechecks from those post-writesOff events.
This // is unideal but shouldn’t impede correctness since post-writesOff events // shouldn’t really happen anyway by definition. -func (verifier *Verifier) readAndHandleOneChangeEventBatch( +func (csr *ChangeStreamReader) readAndHandleOneChangeEventBatch( ctx context.Context, ri *retry.FuncInfo, cs *mongo.ChangeStream, @@ -185,6 +277,15 @@ func (verifier *Verifier) readAndHandleOneChangeEventBatch( return errors.Wrapf(err, "failed to decode change event to %T", changeEventBatch[eventsRead]) } + // This only logs in tests. + csr.logger.Trace().Interface("event", changeEventBatch[eventsRead]).Msgf("%s received a change event", csr) + + if changeEventBatch[eventsRead].ClusterTime != nil && + (csr.lastChangeEventTime == nil || + csr.lastChangeEventTime.Before(*changeEventBatch[eventsRead].ClusterTime)) { + csr.lastChangeEventTime = changeEventBatch[eventsRead].ClusterTime + } + eventsRead++ } @@ -194,15 +295,11 @@ func (verifier *Verifier) readAndHandleOneChangeEventBatch( return nil } - err := verifier.HandleChangeStreamEvents(ctx, changeEventBatch) - if err != nil { - return errors.Wrap(err, "failed to handle change events") - } - + csr.changeEventBatchChan <- changeEventBatch return nil } -func (verifier *Verifier) iterateChangeStream( +func (csr *ChangeStreamReader) iterateChangeStream( ctx context.Context, ri *retry.FuncInfo, cs *mongo.ChangeStream, @@ -214,7 +311,7 @@ func (verifier *Verifier) iterateChangeStream( return nil } - err := verifier.persistChangeStreamResumeToken(ctx, cs) + err := csr.persistChangeStreamResumeToken(ctx, cs) if err == nil { lastPersistedTime = time.Now() } @@ -232,17 +329,18 @@ func (verifier *Verifier) iterateChangeStream( case <-ctx.Done(): return ctx.Err() - // If the changeStreamEnderChan has a message, the user has indicated that - // source writes are ended. This means we should exit rather than continue - // reading the change stream since there should be no more events. 
- case writesOffTs := <-verifier.changeStreamWritesOffTsChan: - verifier.logger.Debug(). + // If the writesOffTsChan has a message, the user has indicated that + // source writes are ended and the migration tool is finished / committed. + // This means we should exit rather than continue reading the change stream + // since there should be no more events. + case writesOffTs := <-csr.writesOffTsChan: + csr.logger.Debug(). Interface("writesOffTimestamp", writesOffTs). - Msg("Change stream thread received writesOff timestamp. Finalizing change stream.") + Msgf("%s thread received writesOff timestamp. Finalizing change stream.", csr) gotwritesOffTimestamp = true - // Read all change events until the source reports no events. + // Read change events until the stream reaches the writesOffTs. // (i.e., the `getMore` call returns empty) for { var curTs primitive.Timestamp @@ -254,15 +352,15 @@ // writesOffTs never refers to a real event, // so we can stop once curTs >= writesOffTs. if !curTs.Before(writesOffTs) { - verifier.logger.Debug(). + csr.logger.Debug(). Interface("currentTimestamp", curTs). Interface("writesOffTimestamp", writesOffTs). - Msg("Change stream has reached the writesOff timestamp. Shutting down.") + Msgf("%s has reached the writesOff timestamp. 
Shutting down.", csr) break } - err = verifier.readAndHandleOneChangeEventBatch(ctx, ri, cs) + err = csr.readAndHandleOneChangeEventBatch(ctx, ri, cs) if err != nil { return err @@ -270,7 +368,7 @@ func (verifier *Verifier) iterateChangeStream( } default: - err = verifier.readAndHandleOneChangeEventBatch(ctx, ri, cs) + err = csr.readAndHandleOneChangeEventBatch(ctx, ri, cs) if err == nil { err = persistResumeTokenIfNeeded() @@ -282,24 +380,22 @@ func (verifier *Verifier) iterateChangeStream( } if gotwritesOffTimestamp { - verifier.mux.Lock() - verifier.changeStreamRunning = false - if verifier.lastChangeEventTime != nil { - verifier.srcStartAtTs = verifier.lastChangeEventTime + csr.changeStreamRunning = false + if csr.lastChangeEventTime != nil { + csr.startAtTs = csr.lastChangeEventTime } - verifier.mux.Unlock() // since we have started Recheck, we must signal that we have // finished the change stream changes so that Recheck can continue. - verifier.changeStreamDoneChan <- struct{}{} + csr.doneChan <- struct{}{} break } } - infoLog := verifier.logger.Info() - if verifier.lastChangeEventTime == nil { + infoLog := csr.logger.Info() + if csr.lastChangeEventTime == nil { infoLog = infoLog.Str("lastEventTime", "none") } else { - infoLog = infoLog.Interface("lastEventTime", *verifier.lastChangeEventTime) + infoLog = infoLog.Interface("lastEventTime", *csr.lastChangeEventTime) } infoLog.Msg("Change stream is done.") @@ -307,34 +403,34 @@ func (verifier *Verifier) iterateChangeStream( return nil } -func (verifier *Verifier) createChangeStream( +func (csr *ChangeStreamReader) createChangeStream( ctx context.Context, ) (*mongo.ChangeStream, primitive.Timestamp, error) { - pipeline := verifier.GetChangeStreamFilter() + pipeline := csr.GetChangeStreamFilter() opts := options.ChangeStream(). SetMaxAwaitTime(1 * time.Second). 
SetFullDocument(options.UpdateLookup) - if verifier.srcClusterInfo.VersionArray[0] >= 6 { + if csr.clusterInfo.VersionArray[0] >= 6 { opts = opts.SetCustomPipeline(bson.M{"showExpandedEvents": true}) } - savedResumeToken, err := verifier.loadChangeStreamResumeToken(ctx) + savedResumeToken, err := csr.loadChangeStreamResumeToken(ctx) if err != nil { return nil, primitive.Timestamp{}, errors.Wrap(err, "failed to load persisted change stream resume token") } - csStartLogEvent := verifier.logger.Info() + csStartLogEvent := csr.logger.Info() if savedResumeToken != nil { logEvent := csStartLogEvent. - Stringer("resumeToken", savedResumeToken) + Stringer(csr.resumeTokenDocID(), savedResumeToken) ts, err := extractTimestampFromResumeToken(savedResumeToken) if err == nil { logEvent = addTimestampToLogEvent(ts, logEvent) } else { - verifier.logger.Warn(). + csr.logger.Warn(). Err(err). Msg("Failed to extract timestamp from persisted resume token.") } @@ -343,25 +439,25 @@ func (verifier *Verifier) createChangeStream( opts = opts.SetStartAfter(savedResumeToken) } else { - csStartLogEvent.Msg("Starting change stream from current source cluster time.") + csStartLogEvent.Msgf("Starting change stream from current %s cluster time.", csr.readerType) } - sess, err := verifier.srcClient.StartSession() + sess, err := csr.watcherClient.StartSession() if err != nil { return nil, primitive.Timestamp{}, errors.Wrap(err, "failed to start session") } sctx := mongo.NewSessionContext(ctx, sess) - srcChangeStream, err := verifier.srcClient.Watch(sctx, pipeline, opts) + changeStream, err := csr.watcherClient.Watch(sctx, pipeline, opts) if err != nil { return nil, primitive.Timestamp{}, errors.Wrap(err, "failed to open change stream") } - err = verifier.persistChangeStreamResumeToken(ctx, srcChangeStream) + err = csr.persistChangeStreamResumeToken(ctx, changeStream) if err != nil { return nil, primitive.Timestamp{}, err } - startTs, err := 
extractTimestampFromResumeToken(srcChangeStream.ResumeToken()) + startTs, err := extractTimestampFromResumeToken(changeStream.ResumeToken()) if err != nil { return nil, primitive.Timestamp{}, errors.Wrap(err, "failed to extract timestamp from change stream's resume token") } @@ -378,11 +474,11 @@ func (verifier *Verifier) createChangeStream( startTs = clusterTime } - return srcChangeStream, startTs, nil + return changeStream, startTs, nil } // StartChangeStream starts the change stream. -func (verifier *Verifier) StartChangeStream(ctx context.Context) error { +func (csr *ChangeStreamReader) StartChangeStream(ctx context.Context) error { // This channel holds the first change stream creation's result, whether // success or failure. Rather than using a Result we could make separate // Timestamp and error channels, but the single channel is cleaner since @@ -390,6 +486,10 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { initialCreateResultChan := make(chan mo.Result[primitive.Timestamp]) go func() { + // Closing changeEventBatchChan at the end of change stream goroutine + // notifies the verifier's change event handler to exit. 
+ defer close(csr.changeEventBatchChan) + retryer := retry.New(retry.DefaultDurationLimit) retryer = retryer.WithErrorCodes(util.CursorKilled) @@ -397,9 +497,9 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { err := retryer.Run( ctx, - verifier.logger, + csr.logger, func(ctx context.Context, ri *retry.FuncInfo) error { - srcChangeStream, startTs, err := verifier.createChangeStream(ctx) + changeStream, startTs, err := csr.createChangeStream(ctx) if err != nil { if parentThreadWaiting { initialCreateResultChan <- mo.Err[primitive.Timestamp](err) @@ -409,7 +509,7 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { return err } - defer srcChangeStream.Close(ctx) + defer changeStream.Close(ctx) if parentThreadWaiting { initialCreateResultChan <- mo.Ok(startTs) @@ -417,15 +517,15 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { parentThreadWaiting = false } - return verifier.iterateChangeStream(ctx, ri, srcChangeStream) + return csr.iterateChangeStream(ctx, ri, changeStream) }, ) if err != nil { // NB: This failure always happens after the initial change stream // creation. 
- verifier.changeStreamErrChan <- err - close(verifier.changeStreamErrChan) + csr.errChan <- err + close(csr.errChan) } }() @@ -436,11 +536,9 @@ func (verifier *Verifier) StartChangeStream(ctx context.Context) error { return err } - verifier.srcStartAtTs = &startTs + csr.startAtTs = &startTs - verifier.mux.Lock() - verifier.changeStreamRunning = true - verifier.mux.Unlock() + csr.changeStreamRunning = true return nil } @@ -451,16 +549,16 @@ func addTimestampToLogEvent(ts primitive.Timestamp, event *zerolog.Event) *zerol Time("time", time.Unix(int64(ts.T), int64(0))) } -func (v *Verifier) getChangeStreamMetadataCollection() *mongo.Collection { - return v.metaClient.Database(v.metaDBName).Collection(metadataChangeStreamCollectionName) +func (csr *ChangeStreamReader) getChangeStreamMetadataCollection() *mongo.Collection { + return csr.metaDB.Collection(metadataChangeStreamCollectionName) } -func (verifier *Verifier) loadChangeStreamResumeToken(ctx context.Context) (bson.Raw, error) { - coll := verifier.getChangeStreamMetadataCollection() +func (csr *ChangeStreamReader) loadChangeStreamResumeToken(ctx context.Context) (bson.Raw, error) { + coll := csr.getChangeStreamMetadataCollection() token, err := coll.FindOne( ctx, - bson.D{{"_id", "resumeToken"}}, + bson.D{{"_id", csr.resumeTokenDocID()}}, ).Raw() if errors.Is(err, mongo.ErrNoDocuments) { @@ -470,13 +568,28 @@ func (verifier *Verifier) loadChangeStreamResumeToken(ctx context.Context) (bson return token, err } -func (verifier *Verifier) persistChangeStreamResumeToken(ctx context.Context, cs *mongo.ChangeStream) error { +func (csr *ChangeStreamReader) String() string { + return fmt.Sprintf("%s change stream reader", csr.readerType) +} + +func (csr *ChangeStreamReader) resumeTokenDocID() string { + switch csr.readerType { + case src: + return "srcResumeToken" + case dst: + return "dstResumeToken" + default: + panic("unknown readerType: " + csr.readerType) + } +} + +func (csr *ChangeStreamReader) 
persistChangeStreamResumeToken(ctx context.Context, cs *mongo.ChangeStream) error { token := cs.ResumeToken() - coll := verifier.getChangeStreamMetadataCollection() + coll := csr.getChangeStreamMetadataCollection() _, err := coll.ReplaceOne( ctx, - bson.D{{"_id", "resumeToken"}}, + bson.D{{"_id", csr.resumeTokenDocID()}}, token, options.Replace().SetUpsert(true), ) @@ -484,16 +597,16 @@ func (verifier *Verifier) persistChangeStreamResumeToken(ctx context.Context, cs if err == nil { ts, err := extractTimestampFromResumeToken(token) - logEvent := verifier.logger.Debug() + logEvent := csr.logger.Debug() if err == nil { logEvent = addTimestampToLogEvent(ts, logEvent) } else { - verifier.logger.Warn().Err(err). + csr.logger.Warn().Err(err). Msg("failed to extract resume token timestamp") } - logEvent.Msg("Persisted change stream resume token.") + logEvent.Msgf("Persisted %s's resume token.", csr) return nil } diff --git a/internal/verifier/change_stream_test.go b/internal/verifier/change_stream_test.go index 1b1aca0a..105046c1 100644 --- a/internal/verifier/change_stream_test.go +++ b/internal/verifier/change_stream_test.go @@ -3,7 +3,6 @@ package verifier import ( "context" "strings" - "testing" "time" "github.com/10gen/migration-verifier/internal/testutil" @@ -11,7 +10,6 @@ import ( "github.com/10gen/migration-verifier/mslices" "github.com/pkg/errors" "github.com/samber/lo" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -20,18 +18,17 @@ import ( "go.mongodb.org/mongo-driver/mongo/readconcern" ) -func TestChangeStreamFilter(t *testing.T) { - verifier := Verifier{} - verifier.SetMetaDBName("metadb") - assert.Contains(t, - verifier.GetChangeStreamFilter(), +func (suite *IntegrationTestSuite) TestChangeStreamFilter() { + verifier := suite.BuildVerifier() + suite.Assert().Contains( + verifier.srcChangeStreamReader.GetChangeStreamFilter(), bson.D{ - {"$match", 
bson.D{{"ns.db", bson.D{{"$ne", "metadb"}}}}}, + {"$match", bson.D{{"ns.db", bson.D{{"$ne", metaDBName}}}}}, }, ) - verifier.srcNamespaces = []string{"foo.bar", "foo.baz", "test.car", "test.chaz"} - assert.Contains(t, - verifier.GetChangeStreamFilter(), + verifier.srcChangeStreamReader.namespaces = []string{"foo.bar", "foo.baz", "test.car", "test.chaz"} + suite.Assert().Contains( + verifier.srcChangeStreamReader.GetChangeStreamFilter(), bson.D{{"$match", bson.D{ {"$or", []bson.D{ {{"ns", bson.D{{"db", "foo"}, {"coll", "bar"}}}}, @@ -43,6 +40,18 @@ func TestChangeStreamFilter(t *testing.T) { ) } +func (suite *IntegrationTestSuite) startSrcChangeStreamReaderAndHandler(ctx context.Context, verifier *Verifier) { + err := verifier.srcChangeStreamReader.StartChangeStream(ctx) + suite.Require().NoError(err) + go func() { + err := verifier.StartChangeEventHandler(ctx, verifier.srcChangeStreamReader) + if errors.Is(err, context.Canceled) { + return + } + suite.Require().NoError(err) + }() +} + // TestChangeStreamResumability creates a verifier, starts its change stream, // terminates that verifier, updates the source cluster, starts a new // verifier with change stream, and confirms that things look as they should. 
@@ -57,8 +66,7 @@ func (suite *IntegrationTestSuite) TestChangeStreamResumability() { verifier1 := suite.BuildVerifier() ctx, cancel := context.WithCancel(suite.Context()) defer cancel() - err := verifier1.StartChangeStream(ctx) - suite.Require().NoError(err) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier1) }() ctx, cancel := context.WithCancel(suite.Context()) @@ -82,13 +90,12 @@ func (suite *IntegrationTestSuite) TestChangeStreamResumability() { newTime := suite.getClusterTime(ctx, suite.srcMongoClient) - err = verifier2.StartChangeStream(ctx) - suite.Require().NoError(err) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier2) - suite.Require().NotNil(verifier2.srcStartAtTs) + suite.Require().NotNil(verifier2.srcChangeStreamReader.startAtTs) suite.Assert().False( - verifier2.srcStartAtTs.After(newTime), + verifier2.srcChangeStreamReader.startAtTs.After(newTime), "verifier2's change stream should be no later than this new session", ) @@ -156,12 +163,11 @@ func (suite *IntegrationTestSuite) TestStartAtTimeNoChanges() { suite.Require().NoError(err) origStartTs := sess.OperationTime() suite.Require().NotNil(origStartTs) - err = verifier.StartChangeStream(ctx) - suite.Require().NoError(err) - suite.Require().Equal(verifier.srcStartAtTs, origStartTs) - verifier.changeStreamWritesOffTsChan <- *origStartTs - <-verifier.changeStreamDoneChan - suite.Require().Equal(verifier.srcStartAtTs, origStartTs) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier) + suite.Require().Equal(verifier.srcChangeStreamReader.startAtTs, origStartTs) + verifier.srcChangeStreamReader.writesOffTsChan <- *origStartTs + <-verifier.srcChangeStreamReader.doneChan + suite.Require().Equal(verifier.srcChangeStreamReader.startAtTs, origStartTs) } func (suite *IntegrationTestSuite) TestStartAtTimeWithChanges() { @@ -176,13 +182,12 @@ func (suite *IntegrationTestSuite) TestStartAtTimeWithChanges() { origSessionTime := sess.OperationTime() suite.Require().NotNil(origSessionTime) - 
err = verifier.StartChangeStream(ctx) - suite.Require().NoError(err) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier) // srcStartAtTs derives from the change stream’s resume token, which can // postdate our session time but should not precede it. suite.Require().False( - verifier.srcStartAtTs.Before(*origSessionTime), + verifier.srcChangeStreamReader.startAtTs.Before(*origSessionTime), "srcStartAtTs should be >= the insert’s optime", ) @@ -206,12 +211,12 @@ func (suite *IntegrationTestSuite) TestStartAtTimeWithChanges() { "session time after events should exceed the original", ) - verifier.changeStreamWritesOffTsChan <- *postEventsSessionTime - <-verifier.changeStreamDoneChan + verifier.srcChangeStreamReader.writesOffTsChan <- *postEventsSessionTime + <-verifier.srcChangeStreamReader.doneChan suite.Assert().Equal( *postEventsSessionTime, - *verifier.srcStartAtTs, + *verifier.srcChangeStreamReader.startAtTs, "verifier.srcStartAtTs should now be our session timestamp", ) } @@ -227,10 +232,9 @@ func (suite *IntegrationTestSuite) TestNoStartAtTime() { suite.Require().NoError(err) origStartTs := sess.OperationTime() suite.Require().NotNil(origStartTs) - err = verifier.StartChangeStream(ctx) - suite.Require().NoError(err) - suite.Require().NotNil(verifier.srcStartAtTs) - suite.Require().LessOrEqual(origStartTs.Compare(*verifier.srcStartAtTs), 0) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier) + suite.Require().NotNil(verifier.srcChangeStreamReader.startAtTs) + suite.Require().LessOrEqual(origStartTs.Compare(*verifier.srcChangeStreamReader.startAtTs), 0) } func (suite *IntegrationTestSuite) TestWithChangeEventsBatching() { @@ -246,7 +250,7 @@ func (suite *IntegrationTestSuite) TestWithChangeEventsBatching() { verifier := suite.BuildVerifier() - suite.Require().NoError(verifier.StartChangeStream(ctx)) + suite.startSrcChangeStreamReaderAndHandler(ctx, verifier) _, err := coll1.InsertOne(ctx, bson.D{{"_id", 1}}) suite.Require().NoError(err) @@ -267,7 
+271,6 @@ func (suite *IntegrationTestSuite) TestWithChangeEventsBatching() { 500*time.Millisecond, "the verifier should flush a recheck doc after a batch", ) - } func (suite *IntegrationTestSuite) TestCursorKilledResilience() { @@ -451,6 +454,77 @@ func (suite *IntegrationTestSuite) TestCreateForbidden() { suite.Assert().Equal("create", eventErr.Event.OpType) } +func (suite *IntegrationTestSuite) TestRecheckDocsWithDstChangeEvents() { + ctx := suite.Context() + + srcDBName := suite.DBNameForTest("src") + dstDBName := suite.DBNameForTest("dst") + + db := suite.dstMongoClient.Database(dstDBName) + coll1 := db.Collection("dstColl1") + coll2 := db.Collection("dstColl2") + + for _, coll := range mslices.Of(coll1, coll2) { + suite.Require().NoError(db.CreateCollection(ctx, coll.Name())) + } + + verifier := suite.BuildVerifier() + verifier.SetSrcNamespaces([]string{srcDBName + ".srcColl1", srcDBName + ".srcColl2"}) + verifier.SetDstNamespaces([]string{dstDBName + ".dstColl1", dstDBName + ".dstColl2"}) + verifier.SetNamespaceMap() + + suite.Require().NoError(verifier.dstChangeStreamReader.StartChangeStream(ctx)) + go func() { + err := verifier.StartChangeEventHandler(ctx, verifier.dstChangeStreamReader) + if errors.Is(err, context.Canceled) { + return + } + suite.Require().NoError(err) + }() + + _, err := coll1.InsertOne(ctx, bson.D{{"_id", 1}}) + suite.Require().NoError(err) + _, err = coll1.InsertOne(ctx, bson.D{{"_id", 2}}) + suite.Require().NoError(err) + + _, err = coll2.InsertOne(ctx, bson.D{{"_id", 1}}) + suite.Require().NoError(err) + + var rechecks []RecheckDoc + require.Eventually( + suite.T(), + func() bool { + recheckColl := verifier.verificationDatabase().Collection(recheckQueue) + cursor, err := recheckColl.Find(ctx, bson.D{}) + if errors.Is(err, mongo.ErrNoDocuments) { + return false + } + + suite.Require().NoError(err) + suite.Require().NoError(cursor.All(ctx, &rechecks)) + return len(rechecks) == 3 + }, + time.Minute, + 500*time.Millisecond, + "the 
verifier should flush a recheck doc after a batch", + ) + + coll1RecheckCount, coll2RecheckCount := 0, 0 + for _, recheck := range rechecks { + suite.Require().Equal(srcDBName, recheck.PrimaryKey.SrcDatabaseName) + switch recheck.PrimaryKey.SrcCollectionName { + case "srcColl1": + coll1RecheckCount++ + case "srcColl2": + coll2RecheckCount++ + default: + suite.T().Fatalf("unknown collection name: %v", recheck.PrimaryKey.SrcCollectionName) + } + } + suite.Require().Equal(2, coll1RecheckCount) + suite.Require().Equal(1, coll2RecheckCount) +} + func (suite *IntegrationTestSuite) TestLargeEvents() { ctx := suite.Context() diff --git a/internal/verifier/check.go b/internal/verifier/check.go index d65fd955..f3538d10 100644 --- a/internal/verifier/check.go +++ b/internal/verifier/check.go @@ -40,17 +40,17 @@ func (verifier *Verifier) Check(ctx context.Context, filter map[string]any) { verifier.MaybeStartPeriodicHeapProfileCollection(ctx) } -func (verifier *Verifier) waitForChangeStream(ctx context.Context) error { +func (verifier *Verifier) waitForChangeStream(ctx context.Context, csr *ChangeStreamReader) error { select { case <-ctx.Done(): return ctx.Err() - case err := <-verifier.changeStreamErrChan: + case err := <-csr.errChan: verifier.logger.Warn().Err(err). - Msg("Received error from change stream.") + Msgf("Received error from %s.", csr) return err - case <-verifier.changeStreamDoneChan: + case <-csr.doneChan: verifier.logger.Debug(). - Msg("Received completion signal from change stream.") + Msgf("Received completion signal from %s.", csr) break } @@ -82,8 +82,10 @@ func (verifier *Verifier) CheckWorker(ctxIn context.Context) error { // If the change stream fails, everything should stop. 
eg.Go(func() error { select { - case err := <-verifier.changeStreamErrChan: - return errors.Wrap(err, "change stream failed") + case err := <-verifier.srcChangeStreamReader.errChan: + return errors.Wrapf(err, "%s failed", verifier.srcChangeStreamReader) + case err := <-verifier.dstChangeStreamReader.errChan: + return errors.Wrapf(err, "%s failed", verifier.dstChangeStreamReader) case <-ctx.Done(): return nil } @@ -168,6 +170,7 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any } verifier.running = true verifier.globalFilter = filter + verifier.initializeChangeStreamReaders() verifier.mux.Unlock() defer func() { verifier.mux.Lock() @@ -204,17 +207,20 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any verifier.phase = Idle }() - verifier.mux.RLock() - csRunning := verifier.changeStreamRunning - verifier.mux.RUnlock() - if csRunning { - verifier.logger.Debug().Msg("Check: Change stream already running.") - } else { - verifier.logger.Debug().Msg("Change stream not running; starting change stream") + ceHandlerGroup, groupCtx := errgroup.WithContext(ctx) + for _, csReader := range []*ChangeStreamReader{verifier.srcChangeStreamReader, verifier.dstChangeStreamReader} { + if csReader.changeStreamRunning { + verifier.logger.Debug().Msgf("Check: %s already running.", csReader) + } else { + verifier.logger.Debug().Msgf("%s not running; starting change stream", csReader) - err = verifier.StartChangeStream(ctx) - if err != nil { - return errors.Wrap(err, "failed to start change stream on source") + err = csReader.StartChangeStream(ctx) + if err != nil { + return errors.Wrapf(err, "failed to start %s", csReader) + } + ceHandlerGroup.Go(func() error { + return verifier.StartChangeEventHandler(groupCtx, csReader) + }) } } @@ -279,13 +285,18 @@ func (verifier *Verifier) CheckDriver(ctx context.Context, filter map[string]any // caught again on the next iteration. if verifier.writesOff { verifier.logger.Debug(). 
- Msg("Waiting for change stream to end.") + Msg("Waiting for change streams to end.") // It's necessary to wait for the change stream to finish before incrementing the // generation number, or the last changes will not be checked. verifier.mux.Unlock() - err := verifier.waitForChangeStream(ctx) - if err != nil { + if err = verifier.waitForChangeStream(ctx, verifier.srcChangeStreamReader); err != nil { + return err + } + if err = verifier.waitForChangeStream(ctx, verifier.dstChangeStreamReader); err != nil { + return err + } + if err = ceHandlerGroup.Wait(); err != nil { return err } verifier.mux.Lock() diff --git a/internal/verifier/compare.go b/internal/verifier/compare.go index e619255c..1add3b5b 100644 --- a/internal/verifier/compare.go +++ b/internal/verifier/compare.go @@ -299,7 +299,7 @@ func (verifier *Verifier) getFetcherChannelsAndCallbacks( ctx, verifier.srcClientCollection(task), verifier.srcClusterInfo, - verifier.srcStartAtTs, + verifier.srcChangeStreamReader.startAtTs, task, ) @@ -325,7 +325,7 @@ func (verifier *Verifier) getFetcherChannelsAndCallbacks( ctx, verifier.dstClientCollection(task), verifier.dstClusterInfo, - nil, //startAtTs + verifier.dstChangeStreamReader.startAtTs, task, ) diff --git a/internal/verifier/integration_test_suite.go b/internal/verifier/integration_test_suite.go index 805d84b8..b1d16cf7 100644 --- a/internal/verifier/integration_test_suite.go +++ b/internal/verifier/integration_test_suite.go @@ -123,7 +123,6 @@ func (suite *IntegrationTestSuite) TearDownTest() { suite.contextCanceller(errors.Errorf("tearing down test %#q", suite.T().Name())) suite.testContext, suite.contextCanceller = nil, nil - ctx := context.Background() for _, client := range []*mongo.Client{suite.srcMongoClient, suite.dstMongoClient} { dbNames, err := client.ListDatabaseNames(ctx, bson.D{}) @@ -184,6 +183,7 @@ func (suite *IntegrationTestSuite) BuildVerifier() *Verifier { "should set metadata connection string", ) verifier.SetMetaDBName(metaDBName) + 
verifier.initializeChangeStreamReaders() suite.Require().NoError(verifier.srcClientCollection(&task).Drop(ctx)) suite.Require().NoError(verifier.dstClientCollection(&task).Drop(ctx)) diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go index 1a83db6e..a018e87e 100644 --- a/internal/verifier/migration_verifier.go +++ b/internal/verifier/migration_verifier.go @@ -76,6 +76,13 @@ const ( notOkSymbol = "\u2757" // heavy exclamation mark symbol ) +type whichCluster string + +const ( + src whichCluster = "source" + dst whichCluster = "destination" +) + var timeFormat = time.RFC3339 // Verifier is the main state for the migration verifier @@ -121,17 +128,13 @@ type Verifier struct { srcNamespaces []string dstNamespaces []string - nsMap map[string]string + nsMap *NSMap metaDBName string - srcStartAtTs *primitive.Timestamp - mux sync.RWMutex - changeStreamRunning bool - changeStreamWritesOffTsChan chan primitive.Timestamp - changeStreamErrChan chan error - changeStreamDoneChan chan struct{} - lastChangeEventTime *primitive.Timestamp - writesOffTimestamp *primitive.Timestamp + mux sync.RWMutex + + srcChangeStreamReader *ChangeStreamReader + dstChangeStreamReader *ChangeStreamReader readConcernSetting ReadConcernSetting @@ -199,15 +202,13 @@ func NewVerifier(settings VerifierSettings, logPath string) *Verifier { logger: logger, writer: logWriter, - phase: Idle, - numWorkers: NumWorkers, - readPreference: readpref.Primary(), - partitionSizeInBytes: 400 * 1024 * 1024, - failureDisplaySize: DefaultFailureDisplaySize, - changeStreamWritesOffTsChan: make(chan primitive.Timestamp), - changeStreamErrChan: make(chan error), - changeStreamDoneChan: make(chan struct{}), - readConcernSetting: readConcern, + phase: Idle, + numWorkers: NumWorkers, + readPreference: readpref.Primary(), + partitionSizeInBytes: 400 * 1024 * 1024, + failureDisplaySize: DefaultFailureDisplaySize, + + readConcernSetting: readConcern, // This will get recreated once gen0 
starts, but we want it // here in case the change streams gets an event before then. @@ -216,6 +217,7 @@ func NewVerifier(settings VerifierSettings, logPath string) *Verifier { workerTracker: NewWorkerTracker(NumWorkers), verificationStatusCheckInterval: 15 * time.Second, + nsMap: NewNSMap(), } } @@ -248,35 +250,50 @@ func (verifier *Verifier) WritesOff(ctx context.Context) error { Msg("WritesOff called.") verifier.mux.Lock() + if verifier.writesOff { + verifier.mux.Unlock() + return errors.New("writesOff already set") + } verifier.writesOff = true - if verifier.writesOffTimestamp == nil { - verifier.logger.Debug().Msg("Change stream still running. Signalling that writes are done.") + verifier.logger.Debug().Msg("Signalling that writes are done.") - finalTs, err := GetNewClusterTime( - ctx, - verifier.logger, - verifier.srcClient, - ) + srcFinalTs, err := GetNewClusterTime( + ctx, + verifier.logger, + verifier.srcClient, + ) - if err != nil { - return errors.Wrapf(err, "failed to fetch source's cluster time") - } + if err != nil { + verifier.mux.Unlock() + return errors.Wrapf(err, "failed to fetch source's cluster time") + } - verifier.writesOffTimestamp = &finalTs + dstFinalTs, err := GetNewClusterTime( + ctx, + verifier.logger, + verifier.dstClient, + ) + if err != nil { verifier.mux.Unlock() + return errors.Wrapf(err, "failed to fetch destination's cluster time") + } + verifier.mux.Unlock() - // This has to happen outside the lock because the change stream - // might be inserting docs into the recheck queue, which happens - // under the lock. - select { - case verifier.changeStreamWritesOffTsChan <- finalTs: - case err := <-verifier.changeStreamErrChan: - return errors.Wrap(err, "tried to send writes-off timestamp to change stream, but change stream already failed") - } - } else { - verifier.mux.Unlock() + // This has to happen outside the lock because the change streams + // might be inserting docs into the recheck queue, which happens + // under the lock. 
+ select { + case verifier.srcChangeStreamReader.writesOffTsChan <- srcFinalTs: + case err := <-verifier.srcChangeStreamReader.errChan: + return errors.Wrapf(err, "tried to send writes-off timestamp to %s, but change stream already failed", verifier.srcChangeStreamReader) + } + + select { + case verifier.dstChangeStreamReader.writesOffTsChan <- dstFinalTs: + case err := <-verifier.dstChangeStreamReader.errChan: + return errors.Wrapf(err, "tried to send writes-off timestamp to %s, but change stream already failed", verifier.dstChangeStreamReader) } return nil @@ -346,13 +363,7 @@ func (verifier *Verifier) SetDstNamespaces(arg []string) { } func (verifier *Verifier) SetNamespaceMap() { - verifier.nsMap = make(map[string]string) - if len(verifier.dstNamespaces) == 0 { - return - } - for i, ns := range verifier.srcNamespaces { - verifier.nsMap[ns] = verifier.dstNamespaces[i] - } + verifier.nsMap.PopulateWithNamespaces(verifier.srcNamespaces, verifier.dstNamespaces) } func (verifier *Verifier) SetMetaDBName(arg string) { diff --git a/internal/verifier/migration_verifier_test.go b/internal/verifier/migration_verifier_test.go index 6528a523..97eb49b4 100644 --- a/internal/verifier/migration_verifier_test.go +++ b/internal/verifier/migration_verifier_test.go @@ -153,6 +153,7 @@ func (suite *IntegrationTestSuite) TestGetNamespaceStatistics_Recheck() { ID: "heyhey", }, }}, + src, ) suite.Require().NoError(err) @@ -168,6 +169,7 @@ func (suite *IntegrationTestSuite) TestGetNamespaceStatistics_Recheck() { ID: "hoohoo", }, }}, + src, ) suite.Require().NoError(err) @@ -411,19 +413,19 @@ func (suite *IntegrationTestSuite) TestFailedVerificationTaskInsertions() { }, } - err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}, src) suite.Require().NoError(err) event.OpType = "insert" - err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err = verifier.HandleChangeStreamEvents(ctx, 
[]ParsedEvent{event}, src) suite.Require().NoError(err) event.OpType = "replace" - err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}, src) suite.Require().NoError(err) event.OpType = "update" - err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}, src) suite.Require().NoError(err) event.OpType = "flibbity" - err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err = verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}, src) badEventErr := UnknownEventError{} suite.Require().ErrorAs(err, &badEventErr) suite.Assert().Equal("flibbity", badEventErr.Event.OpType) @@ -1333,7 +1335,8 @@ func (suite *IntegrationTestSuite) TestGenerationalRechecking() { dbname1 := suite.DBNameForTest("1") dbname2 := suite.DBNameForTest("2") - zerolog.SetGlobalLevel(zerolog.DebugLevel) + zerolog.SetGlobalLevel(zerolog.TraceLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) verifier := suite.BuildVerifier() verifier.SetSrcNamespaces([]string{dbname1 + ".testColl1"}) verifier.SetDstNamespaces([]string{dbname2 + ".testColl3"}) @@ -1385,7 +1388,13 @@ func (suite *IntegrationTestSuite) TestGenerationalRechecking() { // wait for generation to finish suite.Require().NoError(runner.AwaitGenerationEnd()) status = waitForTasks() - // there should be no failures now, since they are are equivalent at this point in time + // there should be no failures now, since they are equivalent at this point in time + suite.Require().Equal(VerificationStatus{TotalTasks: 1, CompletedTasks: 1}, *status) + + // The next generation should process the recheck task caused by inserting {_id: 2} on the destination. 
+ suite.Require().NoError(runner.StartNextGeneration()) + suite.Require().NoError(runner.AwaitGenerationEnd()) + status = waitForTasks() suite.Require().Equal(VerificationStatus{TotalTasks: 1, CompletedTasks: 1}, *status) // now insert in the source, this should come up next generation @@ -1413,7 +1422,7 @@ func (suite *IntegrationTestSuite) TestGenerationalRechecking() { suite.Require().NoError(runner.AwaitGenerationEnd()) status = waitForTasks() - // there should be no failures now, since they are are equivalent at this point in time + // there should be no failures now, since they are equivalent at this point in time suite.Assert().Equal(VerificationStatus{TotalTasks: 1, CompletedTasks: 1}, *status) // We could just abandon this verifier, but we might as well shut it down @@ -1545,6 +1554,81 @@ func (suite *IntegrationTestSuite) TestVerifierWithFilter() { <-checkDoneChan } +func (suite *IntegrationTestSuite) waitForRecheckDocs(verifier *Verifier) { + suite.Eventually(func() bool { + cursor, err := suite.metaMongoClient.Database(verifier.metaDBName).Collection(recheckQueue).Find(suite.Context(), bson.D{}) + var docs []bson.D + suite.Require().NoError(err) + suite.Require().NoError(cursor.All(suite.Context(), &docs)) + return len(docs) > 0 + }, 1*time.Minute, 100*time.Millisecond) +} + +func (suite *IntegrationTestSuite) TestChangesOnDstBeforeSrc() { + zerolog.SetGlobalLevel(zerolog.TraceLevel) + defer zerolog.SetGlobalLevel(zerolog.DebugLevel) + + ctx := suite.Context() + + collName := "mycoll" + + srcDB := suite.srcMongoClient.Database(suite.DBNameForTest()) + dstDB := suite.dstMongoClient.Database(suite.DBNameForTest()) + suite.Require().NoError(srcDB.CreateCollection(ctx, collName)) + suite.Require().NoError(dstDB.CreateCollection(ctx, collName)) + + verifier := suite.BuildVerifier() + runner := RunVerifierCheck(ctx, suite.T(), verifier) + + // Dry run generation 0 to make sure change stream reader is started. 
+ suite.Require().NoError(runner.AwaitGenerationEnd()) + + // Insert two documents in generation 1. They should be batched and become a verify task in generation 2. + suite.Require().NoError(runner.StartNextGeneration()) + _, err := dstDB.Collection(collName).InsertOne(ctx, bson.D{{"_id", 1}}) + suite.Require().NoError(err) + _, err = dstDB.Collection(collName).InsertOne(ctx, bson.D{{"_id", 2}}) + suite.Require().NoError(err) + suite.Require().NoError(runner.AwaitGenerationEnd()) + suite.waitForRecheckDocs(verifier) + + // Run generation 2 and get verification status. + suite.Require().NoError(runner.StartNextGeneration()) + suite.Require().NoError(runner.AwaitGenerationEnd()) + status, err := verifier.GetVerificationStatus(ctx) + suite.Require().NoError(err) + suite.Assert().Equal( + 1, + status.FailedTasks, + ) + + // Patch up only one of the two mismatched documents in generation 3. + suite.Require().NoError(runner.StartNextGeneration()) + _, err = srcDB.Collection(collName).InsertOne(ctx, bson.D{{"_id", 1}}) + suite.Require().NoError(err) + suite.Require().NoError(runner.AwaitGenerationEnd()) + suite.waitForRecheckDocs(verifier) + + status, err = verifier.GetVerificationStatus(ctx) + suite.Require().NoError(err) + suite.Assert().Equal( + 1, + status.FailedTasks, + ) + + // Patch up the other mismatched document in generation 4. + suite.Require().NoError(runner.StartNextGeneration()) + _, err = srcDB.Collection(collName).InsertOne(ctx, bson.D{{"_id", 2}}) + suite.Require().NoError(err) + suite.Require().NoError(runner.AwaitGenerationEnd()) + suite.waitForRecheckDocs(verifier) + + // Everything should match by the end of it. + status, err = verifier.GetVerificationStatus(ctx) + suite.Require().NoError(err) + suite.Assert().Zero(status.FailedTasks) +} + func (suite *IntegrationTestSuite) TestBackgroundInIndexSpec() { ctx := suite.Context() @@ -1611,6 +1695,7 @@ func (suite *IntegrationTestSuite) TestPartitionWithFilter() { // Set up the verifier for testing. 
verifier := suite.BuildVerifier() verifier.SetSrcNamespaces([]string{dbname + ".testColl1"}) + verifier.SetDstNamespaces([]string{dbname + ".testColl1"}) verifier.SetNamespaceMap() verifier.globalFilter = filter // Use a small partition size so that we can test creating multiple partitions. diff --git a/internal/verifier/nsmap.go b/internal/verifier/nsmap.go new file mode 100644 index 00000000..ebf52777 --- /dev/null +++ b/internal/verifier/nsmap.go @@ -0,0 +1,49 @@ +package verifier + +type NSMap struct { + srcDstNsMap map[string]string + dstSrcNsMap map[string]string +} + +func NewNSMap() *NSMap { + return &NSMap{ + srcDstNsMap: make(map[string]string), + dstSrcNsMap: make(map[string]string), + } +} + +func (nsmap *NSMap) PopulateWithNamespaces(srcNamespaces []string, dstNamespaces []string) { + if len(srcNamespaces) != len(dstNamespaces) { + panic("source and destination namespaces are not the same length") + } + + for i, srcNs := range srcNamespaces { + dstNs := dstNamespaces[i] + if _, exist := nsmap.srcDstNsMap[srcNs]; exist { + panic("another mapping already exists for source namespace " + srcNs) + } + if _, exist := nsmap.dstSrcNsMap[dstNs]; exist { + panic("another mapping already exists for destination namespace " + dstNs) + } + nsmap.srcDstNsMap[srcNs] = dstNs + nsmap.dstSrcNsMap[dstNs] = srcNs + } +} + +func (nsmap *NSMap) Len() int { + if len(nsmap.srcDstNsMap) != len(nsmap.dstSrcNsMap) { + panic("source and destination namespaces are not the same length") + } + + return len(nsmap.srcDstNsMap) +} + +func (nsmap *NSMap) GetDstNamespace(srcNamespace string) (string, bool) { + ns, ok := nsmap.srcDstNsMap[srcNamespace] + return ns, ok +} + +func (nsmap *NSMap) GetSrcNamespace(dstNamespace string) (string, bool) { + ns, ok := nsmap.dstSrcNsMap[dstNamespace] + return ns, ok +} diff --git a/internal/verifier/nsmap_test.go b/internal/verifier/nsmap_test.go new file mode 100644 index 00000000..766fada3 --- /dev/null +++ b/internal/verifier/nsmap_test.go @@ -0,0 
+1,39 @@ +package verifier + +import ( + "testing" + + "github.com/stretchr/testify/suite" +) + +type UnitTestSuite struct { + suite.Suite +} + +func TestUnitTestSuite(t *testing.T) { + ts := new(UnitTestSuite) + suite.Run(t, ts) +} + +func (s *UnitTestSuite) Test_EmptyNsMap() { + nsMap := NewNSMap() + srcNamespaces := []string{"srcDB.A", "srcDB.B"} + dstNamespaces := []string{"dstDB.B", "dstDB.A"} + nsMap.PopulateWithNamespaces(srcNamespaces, dstNamespaces) + s.Require().Equal(2, nsMap.Len()) + + _, ok := nsMap.GetDstNamespace("non-existent.coll") + s.Require().False(ok) + + for i, srcNs := range srcNamespaces { + gotNs, ok := nsMap.GetDstNamespace(srcNs) + s.Require().True(ok) + s.Require().Equal(dstNamespaces[i], gotNs) + } + + for i, dstNs := range dstNamespaces { + gotNs, ok := nsMap.GetSrcNamespace(dstNs) + s.Require().True(ok) + s.Require().Equal(srcNamespaces[i], gotNs) + } +} diff --git a/internal/verifier/recheck.go b/internal/verifier/recheck.go index 1ca45737..55104eea 100644 --- a/internal/verifier/recheck.go +++ b/internal/verifier/recheck.go @@ -24,16 +24,16 @@ const ( // RecheckPrimaryKey stores the implicit type of recheck to perform // Currently, we only handle document mismatches/change stream updates, -// so DatabaseName, CollectionName, and DocumentID must always be specified. +// so SrcDatabaseName, SrcCollectionName, and DocumentID must always be specified. // // NB: Order is important here so that, within a given generation, // sorting by _id will guarantee that all rechecks for a given // namespace appear consecutively. type RecheckPrimaryKey struct { - Generation int `bson:"generation"` - DatabaseName string `bson:"db"` - CollectionName string `bson:"coll"` - DocumentID interface{} `bson:"docID"` + Generation int `bson:"generation"` + SrcDatabaseName string `bson:"db"` + SrcCollectionName string `bson:"coll"` + DocumentID interface{} `bson:"docID"` } // RecheckDoc stores the necessary information to know which documents must be rechecked. 
@@ -108,10 +108,10 @@ func (verifier *Verifier) insertRecheckDocs( models := make([]mongo.WriteModel, len(curThreadIndexes)) for m, i := range curThreadIndexes { pk := RecheckPrimaryKey{ - Generation: generation, - DatabaseName: dbNames[i], - CollectionName: collNames[i], - DocumentID: documentIDs[i], + Generation: generation, + SrcDatabaseName: dbNames[i], + SrcCollectionName: collNames[i], + DocumentID: documentIDs[i], } // The filter must exclude DataSize; otherwise, if a failed comparison @@ -302,8 +302,8 @@ func (verifier *Verifier) GenerateRecheckTasks(ctx context.Context) error { // - the buffered document IDs’ size exceeds the per-task maximum // - the buffered documents exceed the partition size // - if doc.PrimaryKey.DatabaseName != prevDBName || - doc.PrimaryKey.CollectionName != prevCollName || + if doc.PrimaryKey.SrcDatabaseName != prevDBName || + doc.PrimaryKey.SrcCollectionName != prevCollName || int64(len(idAccum)) > maxDocsPerTask || idLenAccum >= maxIdsPerRecheckTask || dataSizeAccum >= verifier.partitionSizeInBytes { @@ -313,8 +313,8 @@ func (verifier *Verifier) GenerateRecheckTasks(ctx context.Context) error { return err } - prevDBName = doc.PrimaryKey.DatabaseName - prevCollName = doc.PrimaryKey.CollectionName + prevDBName = doc.PrimaryKey.SrcDatabaseName + prevCollName = doc.PrimaryKey.SrcCollectionName idLenAccum = 0 dataSizeAccum = 0 idAccum = idAccum[:0] diff --git a/internal/verifier/recheck_test.go b/internal/verifier/recheck_test.go index 00d4c6aa..c9ae8b5e 100644 --- a/internal/verifier/recheck_test.go +++ b/internal/verifier/recheck_test.go @@ -29,10 +29,10 @@ func (suite *IntegrationTestSuite) TestFailedCompareThenReplace() { []RecheckDoc{ { PrimaryKey: RecheckPrimaryKey{ - Generation: verifier.generation, - DatabaseName: "the", - CollectionName: "namespace", - DocumentID: "theDocID", + Generation: verifier.generation, + SrcDatabaseName: "the", + SrcCollectionName: "namespace", + DocumentID: "theDocID", }, DataSize: 1234, }, @@ -53,7 
+53,7 @@ func (suite *IntegrationTestSuite) TestFailedCompareThenReplace() { FullDocument: testutil.MustMarshal(bson.D{{"foo", 1}}), } - err := verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}) + err := verifier.HandleChangeStreamEvents(ctx, []ParsedEvent{event}, src) suite.Require().NoError(err) recheckDocs = suite.fetchRecheckDocs(ctx, verifier) @@ -61,10 +61,10 @@ func (suite *IntegrationTestSuite) TestFailedCompareThenReplace() { []RecheckDoc{ { PrimaryKey: RecheckPrimaryKey{ - Generation: verifier.generation, - DatabaseName: "the", - CollectionName: "namespace", - DocumentID: "theDocID", + Generation: verifier.generation, + SrcDatabaseName: "the", + SrcCollectionName: "namespace", + DocumentID: "theDocID", }, DataSize: len(event.FullDocument), }, @@ -102,10 +102,10 @@ func (suite *IntegrationTestSuite) TestLargeIDInsertions() { d1 := RecheckDoc{ PrimaryKey: RecheckPrimaryKey{ - Generation: 0, - DatabaseName: "testDB", - CollectionName: "testColl", - DocumentID: id1, + Generation: 0, + SrcDatabaseName: "testDB", + SrcCollectionName: "testColl", + DocumentID: id1, }, DataSize: overlyLarge} d2 := d1 @@ -162,10 +162,10 @@ func (suite *IntegrationTestSuite) TestLargeDataInsertions() { suite.Require().NoError(err) d1 := RecheckDoc{ PrimaryKey: RecheckPrimaryKey{ - Generation: 0, - DatabaseName: "testDB", - CollectionName: "testColl", - DocumentID: id1, + Generation: 0, + SrcDatabaseName: "testDB", + SrcCollectionName: "testColl", + DocumentID: id1, }, DataSize: dataSizes[0]} d2 := d1 @@ -284,10 +284,10 @@ func (suite *IntegrationTestSuite) TestGenerationalClear() { d1 := RecheckDoc{ PrimaryKey: RecheckPrimaryKey{ - Generation: 0, - DatabaseName: "testDB", - CollectionName: "testColl", - DocumentID: id1, + Generation: 0, + SrcDatabaseName: "testDB", + SrcCollectionName: "testColl", + DocumentID: id1, }, DataSize: dataSizes[0]} d2 := d1 diff --git a/internal/verifier/verification_task.go b/internal/verifier/verification_task.go index edfaaed5..cc8d6479 
100644 --- a/internal/verifier/verification_task.go +++ b/internal/verifier/verification_task.go @@ -95,9 +95,9 @@ func (verifier *Verifier) insertCollectionVerificationTask( generation int) (*VerificationTask, error) { dstNamespace := srcNamespace - if len(verifier.nsMap) != 0 { + if verifier.nsMap.Len() != 0 { var ok bool - dstNamespace, ok = verifier.nsMap[srcNamespace] + dstNamespace, ok = verifier.nsMap.GetDstNamespace(srcNamespace) if !ok { return nil, fmt.Errorf("Could not find Namespace %s", srcNamespace) } @@ -169,9 +169,9 @@ func (verifier *Verifier) InsertDocumentRecheckTask( srcNamespace string, ) error { dstNamespace := srcNamespace - if len(verifier.nsMap) != 0 { + if verifier.nsMap.Len() != 0 { var ok bool - dstNamespace, ok = verifier.nsMap[srcNamespace] + dstNamespace, ok = verifier.nsMap.GetDstNamespace(srcNamespace) if !ok { return fmt.Errorf("Could not find Namespace %s", srcNamespace) }