Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions internal/documentmap/documentmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import (
"fmt"

"github.com/10gen/migration-verifier/internal/logger"
"github.com/10gen/migration-verifier/internal/memorytracker"
"github.com/10gen/migration-verifier/internal/reportutils"
"github.com/10gen/migration-verifier/internal/types"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo"
Expand All @@ -56,6 +58,7 @@ type mapKeyToDocMap map[MapKey]bson.Raw
// Map is the main struct for this package.
type Map struct {
internalMap mapKeyToDocMap
bytesSize types.ByteCount
logger *logger.Logger

// This always includes idFieldName
Expand Down Expand Up @@ -92,14 +95,14 @@ func (m *Map) CloneEmpty() *Map {
// own goroutine.
//
// As a safeguard, this panics if called more than once.
func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error {
func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor, trackerWriter memorytracker.Writer) error {
if m.imported {
panic("Refuse duplicate call!")
}

m.imported = true

var bytesReturned int64
var bytesReturned types.ByteCount
bytesReturned, nDocumentsReturned := 0, 0

for cursor.Next(ctx) {
Expand All @@ -112,12 +115,23 @@ func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error
return err
}

docSize := (types.ByteCount)(len(cursor.Current))

// This will block if needs be to prevent OOMs.
trackerWriter <- memorytracker.Unit(docSize)

bytesReturned += docSize
nDocumentsReturned++
bytesReturned += (int64)(len(cursor.Current))

m.copyAndAddDocument(cursor.Current)
}
m.logger.Debug().Msgf("Find returned %d documents containing %d bytes", nDocumentsReturned, bytesReturned)

m.bytesSize = bytesReturned

m.logger.Info().
Int("documentedReturned", nDocumentsReturned).
Str("totalSize", reportutils.FmtBytes(bytesReturned)).
Msgf("Finished reading %#q query.", "find")

return nil
}
Expand Down Expand Up @@ -184,12 +198,7 @@ func (m *Map) Count() types.DocumentCount {

// TotalDocsBytes returns the combined byte size of the Map’s documents.
func (m *Map) TotalDocsBytes() types.ByteCount {
var size types.ByteCount
for _, doc := range m.internalMap {
size += types.ByteCount(len(doc))
}

return size
return m.bytesSize
}

// ----------------------------------------------------------------------
Expand Down
148 changes: 148 additions & 0 deletions internal/memorytracker/memorytracker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package memorytracker

import (
"context"
"reflect"
"slices"

Check failure on line 6 in internal/memorytracker/memorytracker.go

View workflow job for this annotation

GitHub Actions / basics (macos-latest, mongodb-macos-arm64, 1.20)

package slices is not in GOROOT (/Users/runner/hostedtoolcache/go/1.20.14/arm64/src/slices)

Check failure on line 6 in internal/memorytracker/memorytracker.go

View workflow job for this annotation

GitHub Actions / basics (ubuntu-latest, mongodb-linux-x86_64-ubuntu1804, 1.20)

package slices is not in GOROOT (/opt/hostedtoolcache/go/1.20.14/x64/src/slices)
"sync"

"github.com/10gen/migration-verifier/internal/logger"
"github.com/10gen/migration-verifier/internal/reportutils"
)

// Unit is the quantity in which tracked memory is counted. Writers send
// positive Units to record memory acquired and negative Units to record
// memory released (negative deltas trigger the "Reclaimed tracked
// memory" log in track).
type Unit = int64

// reader is the receive side of a tracking channel; only the Tracker's
// own loop consumes from these.
type reader = <-chan Unit

// Writer is the send side of a tracking channel, handed out to worker
// goroutines via AddWriter.
type Writer = chan<- Unit

// Tracker accumulates memory-usage reports from multiple writers and
// throttles readers once total tracked usage exceeds a soft limit.
type Tracker struct {
	logger *logger.Logger

	// softLimit is the threshold above which track() suspends
	// concurrent reads (see track).
	softLimit Unit

	// curUsage is the running total of all received Units. It is read
	// and written only by the track goroutine.
	curUsage Unit

	// selectCases holds one SelectRecv case per registered writer
	// channel; guarded by mux.
	selectCases []reflect.SelectCase
	mux         sync.RWMutex
}

// Start creates a Tracker whose soft memory limit is max and launches
// its tracking loop on a background goroutine. The loop runs until ctx
// is canceled.
func Start(ctx context.Context, logger *logger.Logger, max Unit) *Tracker {
	mt := &Tracker{
		logger:    logger,
		softLimit: max,
	}

	go mt.track(ctx)

	return mt
}

// AddWriter registers a new worker with the tracker and returns the
// send side of the channel the worker should use to report memory
// acquisitions (positive Units) and releases (negative Units). The
// worker signals that it is finished by closing the channel.
func (mt *Tracker) AddWriter() Writer {
	// Appending to selectCases mutates shared state, so we must take
	// the exclusive lock. (The original RLock here would let concurrent
	// AddWriter calls race on the append, and would also race against
	// removeSelectCase's writes.)
	mt.mux.Lock()
	defer mt.mux.Unlock()

	newChan := make(chan Unit)

	mt.selectCases = append(mt.selectCases, reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(reader(newChan)),
	})

	return newChan
}

// getSelectCases returns a snapshot of the tracker's select cases with
// a receive on ctx.Done() prepended at index 0, so that a caller's
// reflect.Select can observe context cancellation.
func (mt *Tracker) getSelectCases(ctx context.Context) []reflect.SelectCase {
	mt.mux.RLock()
	defer mt.mux.RUnlock()

	doneCase := reflect.SelectCase{
		Dir:  reflect.SelectRecv,
		Chan: reflect.ValueOf(ctx.Done()),
	}

	// Build a fresh slice so the caller's snapshot is unaffected by
	// later AddWriter/removeSelectCase calls.
	cases := make([]reflect.SelectCase, 0, 1+len(mt.selectCases))
	cases = append(cases, doneCase)
	cases = append(cases, mt.selectCases...)

	return cases
}

// removeSelectCase drops the writer entry corresponding to index i of
// the combined case list built by getSelectCases. Index 0 of that list
// is the context's Done channel, so the matching element of
// mt.selectCases sits at i-1; Delete(i-1, i) removes exactly that one
// element.
//
// NOTE(review): the "slices" package entered the standard library in Go
// 1.21, and the CI annotations on this PR show the build failing on Go
// 1.20.x ("package slices is not in GOROOT"). Either bump the toolchain
// in go.mod/CI, or replace this call with
// append(mt.selectCases[:i-1], mt.selectCases[i:]...) and drop the
// import.
func (mt *Tracker) removeSelectCase(i int) {
	mt.mux.Lock()
	defer mt.mux.Unlock()

	mt.selectCases = slices.Delete(mt.selectCases, i-1, i)
}

// track is the Tracker's event loop, run on its own goroutine by Start.
// It receives usage deltas from every registered writer and maintains
// curUsage. Once usage exceeds softLimit it stops servicing other
// writers and blocks on the writer it last heard from until usage falls
// back under the limit; while it does so, every other writer's send
// blocks, which is what throttles concurrent reads.
func (mt *Tracker) track(ctx context.Context) {
	for {
		// Invariant: each pass of the outer loop begins at or under the
		// soft limit; the inner loop below restores the invariant
		// before control comes back here.
		if mt.curUsage > mt.softLimit {
			mt.logger.Panic().
				Int64("usage", mt.curUsage).
				Int64("softLimit", mt.softLimit).
				Msg("track() loop should never be in memory excess!")
		}

		// Re-snapshot every iteration so writers registered via
		// AddWriter since the last pass are included.
		selectCases := mt.getSelectCases(ctx)

		chosen, gotVal, alive := reflect.Select(selectCases)

		// Index 0 is always ctx.Done() (see getSelectCases).
		if chosen == 0 {
			mt.logger.Debug().
				AnErr("contextErr", context.Cause(ctx)).
				Msg("Stopping memory tracker.")

			return
		}

		got := (gotVal.Interface()).(Unit)
		mt.curUsage += got

		// A negative delta means a writer released memory.
		if got < 0 {
			mt.logger.Info().
				Str("reclaimed", reportutils.FmtBytes(-got)).
				Str("tracked", reportutils.FmtBytes(mt.curUsage)).
				Msg("Reclaimed tracked memory.")
		}

		if !alive {
			// A closed channel yields the zero value; anything nonzero
			// here means the select machinery misbehaved.
			if got != 0 {
				mt.logger.Panic().
					Int64("receivedValue", got).
					Msg("Got nonzero track value but channel is closed.")
			}

			// Closure of a channel indicates that the worker thread is
			// finished.
			mt.removeSelectCase(chosen)

			continue
		}

		didSingleThread := false

		// Over the limit: drain only the chosen writer until usage
		// drops. Presumably that writer eventually sends negative
		// (release) deltas — TODO confirm against the writers' usage.
		//
		// NOTE(review): if the chosen channel closes while we are still
		// over the limit, <-reader returns (0, false) immediately and
		// never lowers curUsage, so this loop would spin and call
		// removeSelectCase(chosen) repeatedly — deleting unrelated
		// entries after the first. Consider breaking out when !alive.
		for mt.curUsage > mt.softLimit {
			reader := (selectCases[chosen].Chan.Interface()).(reader)

			if !didSingleThread {
				mt.logger.Warn().
					Str("usage", reportutils.FmtBytes(mt.curUsage)).
					Str("softLimit", reportutils.FmtBytes(mt.softLimit)).
					Msg("Tracked memory usage now exceeds soft limit. Suspending concurrent reads until tracked usage falls.")

				didSingleThread = true
			}

			got, alive := <-reader
			mt.curUsage += got

			if !alive {
				mt.removeSelectCase(chosen)
			}
		}

		if didSingleThread {
			mt.logger.Info().
				Str("usage", reportutils.FmtBytes(mt.curUsage)).
				Str("softLimit", reportutils.FmtBytes(mt.softLimit)).
				Msg("Tracked memory usage is now below soft limit. Resuming concurrent reads.")
		}
	}
}
12 changes: 9 additions & 3 deletions internal/reportutils/reportutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@ import (

const decimalPrecision = 2

// This could include signed ints, but we have no need for now.
// The bigger requirement is that it exclude uint8.
// This must exclude uint8.
type num16Plus interface {
constraints.Float | ~uint | ~uint16 | ~uint32 | ~uint64
constraints.Float | ~uint | ~uint16 | ~uint32 | ~uint64 | ~int64
}

type realNum interface {
Expand Down Expand Up @@ -68,6 +67,13 @@ func DurationToHMS(duration time.Duration) string {
return str
}

// FmtBytes formats a single count of bytes as a human-readable string:
// it picks the best-fitting unit via FindBestUnit, scales the count
// with BytesToUnit, and joins the two with a space.
func FmtBytes[T num16Plus](count T) string {
	best := FindBestUnit(count)
	scaled := BytesToUnit(count, best)

	return scaled + " " + string(best)
}

// BytesToUnit returns a stringified number that represents `count`
// in the given `unit`. For example, count=1024 and unit=KiB would
// return "1".
Expand Down
10 changes: 7 additions & 3 deletions internal/verifier/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"sync"
"time"

"github.com/10gen/migration-verifier/internal/memorytracker"
"github.com/10gen/migration-verifier/internal/retry"
mapset "github.com/deckarep/golang-set/v2"
"github.com/pkg/errors"
Expand Down Expand Up @@ -62,11 +63,14 @@ func (verifier *Verifier) waitForChangeStream() error {

func (verifier *Verifier) CheckWorker(ctx context.Context) error {
verifier.logger.Debug().Msgf("Starting %d verification workers", verifier.numWorkers)
memTracker := memorytracker.Start(ctx, verifier.logger, 40_000_000_000) // TODO
ctx, cancel := context.WithCancel(ctx)
wg := sync.WaitGroup{}
for i := 0; i < verifier.numWorkers; i++ {
wg.Add(1)
go verifier.Work(ctx, i, &wg)
trackerWriter := memTracker.AddWriter()
defer close(trackerWriter)
go verifier.Work(ctx, i, &wg, trackerWriter)
time.Sleep(10 * time.Millisecond)
}

Expand Down Expand Up @@ -345,7 +349,7 @@ func FetchFailedAndIncompleteTasks(ctx context.Context, coll *mongo.Collection,
return FailedTasks, IncompleteTasks, nil
}

func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup) {
func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup, trackerWriter memorytracker.Writer) {
defer wg.Done()
verifier.logger.Debug().Msgf("[Worker %d] Started", workerNum)
for {
Expand All @@ -371,7 +375,7 @@ func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.Wait
}
}
} else {
verifier.ProcessVerifyTask(workerNum, task)
verifier.ProcessVerifyTask(workerNum, task, trackerWriter)
}
}
}
Expand Down
Loading
Loading