diff --git a/internal/documentmap/documentmap.go b/internal/documentmap/documentmap.go
index 499f214b..38e10664 100644
--- a/internal/documentmap/documentmap.go
+++ b/internal/documentmap/documentmap.go
@@ -34,6 +34,8 @@ import (
 	"fmt"
 
 	"github.com/10gen/migration-verifier/internal/logger"
+	"github.com/10gen/migration-verifier/internal/memorytracker"
+	"github.com/10gen/migration-verifier/internal/reportutils"
 	"github.com/10gen/migration-verifier/internal/types"
 	"go.mongodb.org/mongo-driver/bson"
 	"go.mongodb.org/mongo-driver/mongo"
@@ -56,6 +58,7 @@ type mapKeyToDocMap map[MapKey]bson.Raw
 // Map is the main struct for this package.
 type Map struct {
 	internalMap mapKeyToDocMap
+	bytesSize   types.ByteCount
 	logger      *logger.Logger
 
 	// This always includes idFieldName
@@ -92,14 +95,14 @@ func (m *Map) CloneEmpty() *Map {
 // own goroutine.
 //
 // As a safeguard, this panics if called more than once.
-func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error {
+func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor, trackerWriter memorytracker.Writer) error {
 	if m.imported {
 		panic("Refuse duplicate call!")
 	}
 	m.imported = true
 
-	var bytesReturned int64
+	var bytesReturned types.ByteCount
 	bytesReturned, nDocumentsReturned := 0, 0
 
 	for cursor.Next(ctx) {
@@ -112,12 +115,23 @@ func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error
 			return err
 		}
 
+		docSize := (types.ByteCount)(len(cursor.Current))
+
+		// This will block if need be to prevent OOMs.
+		trackerWriter <- memorytracker.Unit(docSize)
+
+		bytesReturned += docSize
 		nDocumentsReturned++
-		bytesReturned += (int64)(len(cursor.Current))
 
 		m.copyAndAddDocument(cursor.Current)
 	}
-	m.logger.Debug().Msgf("Find returned %d documents containing %d bytes", nDocumentsReturned, bytesReturned)
+
+	m.bytesSize = bytesReturned
+
+	m.logger.Info().
+		Int("documentsReturned", nDocumentsReturned).
+		Str("totalSize", reportutils.FmtBytes(bytesReturned)).
+		Msgf("Finished reading %#q query.", "find")
 
 	return nil
 }
@@ -184,12 +198,7 @@ func (m *Map) Count() types.DocumentCount {
 
 // TotalDocsBytes returns the combined byte size of the Map’s documents.
 func (m *Map) TotalDocsBytes() types.ByteCount {
-	var size types.ByteCount
-	for _, doc := range m.internalMap {
-		size += types.ByteCount(len(doc))
-	}
-
-	return size
+	return m.bytesSize
 }
 
 // ----------------------------------------------------------------------
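
Note on the documentmap change: `Map` now caches its total byte size at import time, so `TotalDocsBytes()` is an O(1) field read instead of an O(n) scan, and each document's size is reported to the memory tracker before the document is cached. A minimal call-site sketch, with hypothetical names (`log`, `tracker`, `cursor` are assumed, not part of this diff):

    // Sketch only: wiring a cursor into a Map with memory tracking.
    docMap := documentmap.New(log)        // plus any shard field names
    writer := tracker.AddWriter()         // a memorytracker.Writer
    if err := docMap.ImportFromCursor(ctx, cursor, writer); err != nil {
        return err
    }
    total := docMap.TotalDocsBytes()      // O(1): reads the cached bytesSize
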
diff --git a/internal/memorytracker/memorytracker.go b/internal/memorytracker/memorytracker.go
new file mode 100644
index 00000000..1d99627a
--- /dev/null
+++ b/internal/memorytracker/memorytracker.go
@@ -0,0 +1,148 @@
+package memorytracker
+
+import (
+	"context"
+	"reflect"
+	"slices"
+	"sync"
+
+	"github.com/10gen/migration-verifier/internal/logger"
+	"github.com/10gen/migration-verifier/internal/reportutils"
+)
+
+type Unit = int64
+type reader = <-chan Unit
+type Writer = chan<- Unit
+
+type Tracker struct {
+	logger      *logger.Logger
+	softLimit   Unit
+	curUsage    Unit
+	selectCases []reflect.SelectCase
+	mux         sync.RWMutex
+}
+
+func Start(ctx context.Context, logger *logger.Logger, max Unit) *Tracker {
+	tracker := Tracker{
+		softLimit: max,
+		logger:    logger,
+	}
+
+	go tracker.track(ctx)
+
+	return &tracker
+}
+
+// AddWriter registers a new writer channel with the tracker. It takes
+// the write lock because it mutates selectCases.
+func (mt *Tracker) AddWriter() Writer {
+	mt.mux.Lock()
+	defer mt.mux.Unlock()
+
+	newChan := make(chan Unit)
+
+	mt.selectCases = append(mt.selectCases, reflect.SelectCase{
+		Dir:  reflect.SelectRecv,
+		Chan: reflect.ValueOf(reader(newChan)),
+	})
+
+	return newChan
+}
+
+func (mt *Tracker) getSelectCases(ctx context.Context) []reflect.SelectCase {
+	mt.mux.RLock()
+	defer mt.mux.RUnlock()
+
+	cases := make([]reflect.SelectCase, 1+len(mt.selectCases))
+	cases[0] = reflect.SelectCase{
+		Dir:  reflect.SelectRecv,
+		Chan: reflect.ValueOf(ctx.Done()),
+	}
+
+	for i := range mt.selectCases {
+		cases[1+i] = mt.selectCases[i]
+	}
+
+	return cases
+}
+
+// removeSelectCase removes the case at index i, where i indexes the
+// slice that getSelectCases returns. Index 0 there is the context's
+// Done channel, so case i lives at mt.selectCases[i-1].
+func (mt *Tracker) removeSelectCase(i int) {
+	mt.mux.Lock()
+	defer mt.mux.Unlock()
+
+	mt.selectCases = slices.Delete(mt.selectCases, i-1, i)
+}
+
+func (mt *Tracker) track(ctx context.Context) {
+	for {
+		if mt.curUsage > mt.softLimit {
+			mt.logger.Panic().
+				Int64("usage", mt.curUsage).
+				Int64("softLimit", mt.softLimit).
+				Msg("track() loop should never be in memory excess!")
+		}
+
+		selectCases := mt.getSelectCases(ctx)
+
+		chosen, gotVal, alive := reflect.Select(selectCases)
+
+		if chosen == 0 {
+			mt.logger.Debug().
+				AnErr("contextErr", context.Cause(ctx)).
+				Msg("Stopping memory tracker.")
+
+			return
+		}
+
+		got := (gotVal.Interface()).(Unit)
+		mt.curUsage += got
+
+		if got < 0 {
+			mt.logger.Info().
+				Str("reclaimed", reportutils.FmtBytes(-got)).
+				Str("tracked", reportutils.FmtBytes(mt.curUsage)).
+				Msg("Reclaimed tracked memory.")
+		}
+
+		if !alive {
+			if got != 0 {
+				mt.logger.Panic().
+					Int64("receivedValue", got).
+					Msg("Got nonzero track value, but channel is closed.")
+			}
+
+			// Closure of a channel indicates that the worker thread is
+			// finished.
+			mt.removeSelectCase(chosen)
+
+			continue
+		}
+
+		didSingleThread := false
+
+		for mt.curUsage > mt.softLimit {
+			reader := (selectCases[chosen].Chan.Interface()).(reader)
+
+			if !didSingleThread {
+				mt.logger.Warn().
+					Str("usage", reportutils.FmtBytes(mt.curUsage)).
+					Str("softLimit", reportutils.FmtBytes(mt.softLimit)).
+					Msg("Tracked memory usage now exceeds soft limit. Suspending concurrent reads until tracked usage falls.")
+
+				didSingleThread = true
+			}
+
+			got, alive := <-reader
+			mt.curUsage += got
+
+			if !alive {
+				// Don't keep receiving from a closed channel: drop its
+				// select case and return to the main select loop.
+				mt.removeSelectCase(chosen)
+				break
+			}
+		}
+
+		if didSingleThread {
+			mt.logger.Info().
+				Str("usage", reportutils.FmtBytes(mt.curUsage)).
+				Str("softLimit", reportutils.FmtBytes(mt.softLimit)).
+				Msg("Tracked memory usage is now below soft limit. Resuming concurrent reads.")
+		}
+	}
+}
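
The tracker's contract, in brief: every worker gets its own unbuffered channel from `AddWriter`; positive sends claim memory, negative sends release it, and closing the channel deregisters the worker. Because `track()` stops receiving from all but one channel once `curUsage` crosses `softLimit`, the claim send is also the backpressure point. A usage sketch with illustrative limit and sizes (not taken from this diff):

    // Sketch of the claim/release protocol described above.
    tracker := memorytracker.Start(ctx, log, 1<<30) // 1 GiB soft limit
    writer := tracker.AddWriter()

    writer <- 64 * 1024  // claim 64 KiB; blocks while the tracker sheds load
    // ... hold the memory while working ...
    writer <- -64 * 1024 // release the same amount when done
    close(writer)        // worker finished; tracker drops this channel
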
+ Msg("Tracked memory usage is now below soft limit. Resuming concurrent reads.") + } + } +} diff --git a/internal/reportutils/reportutils.go b/internal/reportutils/reportutils.go index acb603b9..4991f0a0 100644 --- a/internal/reportutils/reportutils.go +++ b/internal/reportutils/reportutils.go @@ -15,10 +15,9 @@ import ( const decimalPrecision = 2 -// This could include signed ints, but we have no need for now. -// The bigger requirement is that it exclude uint8. +// This must exclude uint8. type num16Plus interface { - constraints.Float | ~uint | ~uint16 | ~uint32 | ~uint64 + constraints.Float | ~uint | ~uint16 | ~uint32 | ~uint64 | ~int64 } type realNum interface { @@ -68,6 +67,13 @@ func DurationToHMS(duration time.Duration) string { return str } +// FmtBytes is a convenience that combines BytesToUnit with FindBestUnit. +// Use it to format a single count of bytes. +func FmtBytes[T num16Plus](count T) string { + unit := FindBestUnit(count) + return BytesToUnit(count, unit) + " " + string(unit) +} + // BytesToUnit returns a stringified number that represents `count` // in the given `unit`. For example, count=1024 and unit=KiB would // return "1". diff --git a/internal/verifier/check.go b/internal/verifier/check.go index b7a1338d..be584348 100644 --- a/internal/verifier/check.go +++ b/internal/verifier/check.go @@ -6,6 +6,7 @@ import ( "sync" "time" + "github.com/10gen/migration-verifier/internal/memorytracker" "github.com/10gen/migration-verifier/internal/retry" mapset "github.com/deckarep/golang-set/v2" "github.com/pkg/errors" @@ -62,11 +63,14 @@ func (verifier *Verifier) waitForChangeStream() error { func (verifier *Verifier) CheckWorker(ctx context.Context) error { verifier.logger.Debug().Msgf("Starting %d verification workers", verifier.numWorkers) + memTracker := memorytracker.Start(ctx, verifier.logger, 40_000_000_000) // TODO ctx, cancel := context.WithCancel(ctx) wg := sync.WaitGroup{} for i := 0; i < verifier.numWorkers; i++ { wg.Add(1) - go verifier.Work(ctx, i, &wg) + trackerWriter := memTracker.AddWriter() + defer close(trackerWriter) + go verifier.Work(ctx, i, &wg, trackerWriter) time.Sleep(10 * time.Millisecond) } @@ -345,7 +349,7 @@ func FetchFailedAndIncompleteTasks(ctx context.Context, coll *mongo.Collection, return FailedTasks, IncompleteTasks, nil } -func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup) { +func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup, trackerWriter memorytracker.Writer) { defer wg.Done() verifier.logger.Debug().Msgf("[Worker %d] Started", workerNum) for { @@ -371,7 +375,7 @@ func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.Wait } } } else { - verifier.ProcessVerifyTask(workerNum, task) + verifier.ProcessVerifyTask(workerNum, task, trackerWriter) } } } diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go index 47259eb4..445330eb 100644 --- a/internal/verifier/migration_verifier.go +++ b/internal/verifier/migration_verifier.go @@ -18,6 +18,7 @@ import ( "github.com/10gen/migration-verifier/internal/documentmap" "github.com/10gen/migration-verifier/internal/logger" + "github.com/10gen/migration-verifier/internal/memorytracker" "github.com/10gen/migration-verifier/internal/partitions" "github.com/10gen/migration-verifier/internal/reportutils" "github.com/10gen/migration-verifier/internal/retry" @@ -137,6 +138,8 @@ type Verifier struct { // The verifier only checks documents within the filter. 
diff --git a/internal/verifier/check.go b/internal/verifier/check.go
index b7a1338d..be584348 100644
--- a/internal/verifier/check.go
+++ b/internal/verifier/check.go
@@ -6,6 +6,7 @@ import (
 	"sync"
 	"time"
 
+	"github.com/10gen/migration-verifier/internal/memorytracker"
 	"github.com/10gen/migration-verifier/internal/retry"
 	mapset "github.com/deckarep/golang-set/v2"
 	"github.com/pkg/errors"
@@ -62,11 +63,14 @@ func (verifier *Verifier) waitForChangeStream() error {
 func (verifier *Verifier) CheckWorker(ctx context.Context) error {
 	verifier.logger.Debug().Msgf("Starting %d verification workers", verifier.numWorkers)
 
+	memTracker := memorytracker.Start(ctx, verifier.logger, 40_000_000_000) // TODO
 	ctx, cancel := context.WithCancel(ctx)
 	wg := sync.WaitGroup{}
 	for i := 0; i < verifier.numWorkers; i++ {
 		wg.Add(1)
-		go verifier.Work(ctx, i, &wg)
+		trackerWriter := memTracker.AddWriter()
+		defer close(trackerWriter)
+		go verifier.Work(ctx, i, &wg, trackerWriter)
 		time.Sleep(10 * time.Millisecond)
 	}
@@ -345,7 +349,7 @@ func FetchFailedAndIncompleteTasks(ctx context.Context, coll *mongo.Collection,
 	return FailedTasks, IncompleteTasks, nil
 }
 
-func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup) {
+func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup, trackerWriter memorytracker.Writer) {
 	defer wg.Done()
 	verifier.logger.Debug().Msgf("[Worker %d] Started", workerNum)
 	for {
@@ -371,7 +375,7 @@
 				}
 			}
 		} else {
-			verifier.ProcessVerifyTask(workerNum, task)
+			verifier.ProcessVerifyTask(workerNum, task, trackerWriter)
 		}
 	}
 }
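
Note the writer lifecycle here: each worker's channel is created before its goroutine starts and closed via `defer` when `CheckWorker` returns (presumably after the workers have been joined), which is what drives the tracker's `!alive`/`removeSelectCase` path. Condensed sketch of that pattern:

    w := memTracker.AddWriter()
    defer close(w) // on CheckWorker return: track() sees the closed channel
                   // and deregisters it; the worker must stop sending first
    go verifier.Work(ctx, i, &wg, w)
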
diff --git a/internal/verifier/migration_verifier.go b/internal/verifier/migration_verifier.go
index 47259eb4..445330eb 100644
--- a/internal/verifier/migration_verifier.go
+++ b/internal/verifier/migration_verifier.go
@@ -18,6 +18,7 @@ import (
 
 	"github.com/10gen/migration-verifier/internal/documentmap"
 	"github.com/10gen/migration-verifier/internal/logger"
+	"github.com/10gen/migration-verifier/internal/memorytracker"
 	"github.com/10gen/migration-verifier/internal/partitions"
 	"github.com/10gen/migration-verifier/internal/reportutils"
 	"github.com/10gen/migration-verifier/internal/retry"
@@ -137,6 +138,8 @@ type Verifier struct {
 	// The verifier only checks documents within the filter.
 	globalFilter map[string]any
 
+	memoryTracker *memorytracker.Tracker
+
 	pprofInterval time.Duration
 }
 
@@ -423,7 +426,7 @@ func (verifier *Verifier) maybeAppendGlobalFilterToPredicates(predicates bson.A)
 }
 
 func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mongo.Collection, buildInfo *bson.M,
-	startAtTs *primitive.Timestamp, task *VerificationTask) (*mongo.Cursor, error) {
+	startAtTs *primitive.Timestamp, task *VerificationTask) (bson.D, *mongo.Cursor, error) {
 	var findOptions bson.D
 	runCommandOptions := options.RunCmd()
 	var andPredicates bson.A
@@ -454,11 +457,19 @@ func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mo
 
 	findCmd := append(bson.D{{"find", collection.Name()}}, findOptions...)
 	verifier.logger.Debug().Msgf("getDocuments findCmd: %s opts: %v", findCmd, *runCommandOptions)
 
-	return collection.Database().RunCommandCursor(ctx, findCmd, runCommandOptions)
+	cursor, err := collection.Database().RunCommandCursor(ctx, findCmd, runCommandOptions)
+
+	return findCmd, cursor, err
 }
 
-func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask) ([]VerificationResult, types.DocumentCount, types.ByteCount, error) {
-	srcClientMap, dstClientMap, err := verifier.fetchDocuments(task)
+func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask, trackerWriter memorytracker.Writer) ([]VerificationResult, types.DocumentCount, types.ByteCount, error) {
+	srcClientMap, dstClientMap, err := verifier.fetchDocuments(task, trackerWriter)
+
+	defer func() {
+		// Release the maps' tracked memory once comparison finishes.
+		for _, clientMap := range []*documentmap.Map{srcClientMap, dstClientMap} {
+			if clientMap != nil {
+				trackerWriter <- -memorytracker.Unit(clientMap.TotalDocsBytes())
+			}
+		}
+	}()
+
 	if err != nil {
 		return nil, 0, 0, err
 	}
@@ -472,9 +483,7 @@ func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask) ([]Ve
 }
 
 // This is split out to allow unit testing of fetching separate from comparison.
-func (verifier *Verifier) fetchDocuments(task *VerificationTask) (*documentmap.Map, *documentmap.Map, error) {
-
-	var srcErr, dstErr error
+func (verifier *Verifier) fetchDocuments(task *VerificationTask, trackerWriter memorytracker.Writer) (*documentmap.Map, *documentmap.Map, error) {
 
 	errGroup, ctx := errgroup.WithContext(context.Background())
 
@@ -483,28 +492,36 @@ func (verifier *Verifier) fetchDocuments(task *VerificationTask) (*documentmap.M
 
 	srcClientMap := documentmap.New(verifier.GetLogger(), shardFieldNames...)
 	dstClientMap := srcClientMap.CloneEmpty()
 
+	warnThreshold := 10 * verifier.partitionSizeInBytes
+
 	errGroup.Go(func() error {
-		var cursor *mongo.Cursor
-		cursor, srcErr = verifier.getDocumentsCursor(ctx, verifier.srcClientCollection(task), verifier.srcBuildInfo,
+		findCmd, cursor, err := verifier.getDocumentsCursor(ctx, verifier.srcClientCollection(task), verifier.srcBuildInfo,
 			verifier.srcStartAtTs, task)
 
-		if srcErr == nil {
-			srcErr = srcClientMap.ImportFromCursor(ctx, cursor)
+		if err == nil {
+			err = srcClientMap.ImportFromCursor(ctx, cursor, trackerWriter)
+		}
+
+		if err == nil && int64(srcClientMap.TotalDocsBytes()) > warnThreshold {
+			verifier.logger.Warn().
+				Str("totalSize", reportutils.FmtBytes(srcClientMap.TotalDocsBytes())).
+				Str("intendedPartitionSize", reportutils.FmtBytes(verifier.partitionSizeInBytes)).
+				Str("filter", fmt.Sprintf("%v", findCmd)).
+				Msg("Partition greatly exceeds desired size. This may cause excess memory usage.")
 		}
 
-		return srcErr
+		return err
 	})
 
 	errGroup.Go(func() error {
-		var cursor *mongo.Cursor
-		cursor, dstErr = verifier.getDocumentsCursor(ctx, verifier.dstClientCollection(task), verifier.dstBuildInfo,
+		_, cursor, err := verifier.getDocumentsCursor(ctx, verifier.dstClientCollection(task), verifier.dstBuildInfo,
 			nil /*startAtTs*/, task)
 
-		if dstErr == nil {
-			dstErr = dstClientMap.ImportFromCursor(ctx, cursor)
+		if err == nil {
+			err = dstClientMap.ImportFromCursor(ctx, cursor, trackerWriter)
 		}
 
-		return dstErr
+		return err
 	})
 
 	err := errGroup.Wait()
@@ -632,10 +649,10 @@ func (verifier *Verifier) compareOneDocument(srcClientDoc, dstClientDoc bson.Raw
 	}}, nil
 }
 
-func (verifier *Verifier) ProcessVerifyTask(workerNum int, task *VerificationTask) {
+func (verifier *Verifier) ProcessVerifyTask(workerNum int, task *VerificationTask, trackerWriter memorytracker.Writer) {
 	verifier.logger.Debug().Msgf("[Worker %d] Processing verify task", workerNum)
 
-	problems, docsCount, bytesCount, err := verifier.FetchAndCompareDocuments(task)
+	problems, docsCount, bytesCount, err := verifier.FetchAndCompareDocuments(task, trackerWriter)
 
 	if err != nil {
 		task.Status = verificationTaskFailed
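
The end-to-end invariant these changes rely on: every byte claimed through a writer during fetch is released exactly once by the deferred negative sends in `FetchAndCompareDocuments`, so a task's net tracked usage returns to zero. Schematically, with hypothetical sizes:

    writer <- 4096  // ImportFromCursor: a source document is cached
    writer <- 4096  // ImportFromCursor: a destination document is cached
    // ... comparison runs ...
    writer <- -4096 // defer: source map released via TotalDocsBytes()
    writer <- -4096 // defer: destination map released
    // net tracked usage for this task: 0
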