Skip to content

Commit 0e6a2fd

Browse files
committed
memory tracker
1 parent 0863b5d commit 0e6a2fd

File tree

4 files changed

+147
-14
lines changed

4 files changed

+147
-14
lines changed

internal/documentmap/documentmap.go

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ import (
3434
"fmt"
3535

3636
"github.com/10gen/migration-verifier/internal/logger"
37+
"github.com/10gen/migration-verifier/internal/memorytracker"
38+
"github.com/10gen/migration-verifier/internal/reportutils"
3739
"github.com/10gen/migration-verifier/internal/types"
3840
"go.mongodb.org/mongo-driver/bson"
3941
"go.mongodb.org/mongo-driver/mongo"
@@ -92,14 +94,14 @@ func (m *Map) CloneEmpty() *Map {
9294
// own goroutine.
9395
//
9496
// As a safeguard, this panics if called more than once.
95-
func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error {
97+
func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor, trackerWriter memorytracker.Writer) error {
9698
if m.imported {
9799
panic("Refuse duplicate call!")
98100
}
99101

100102
m.imported = true
101103

102-
var bytesReturned int64
104+
var bytesReturned uint64
103105
bytesReturned, nDocumentsReturned := 0, 0
104106

105107
for cursor.Next(ctx) {
@@ -113,11 +115,17 @@ func (m *Map) ImportFromCursor(ctx context.Context, cursor *mongo.Cursor) error
113115
}
114116

115117
nDocumentsReturned++
116-
bytesReturned += (int64)(len(cursor.Current))
118+
bytesReturned += (uint64)(len(cursor.Current))
119+
120+
// This will block if needs be to prevent OOMs.
121+
trackerWriter <- memorytracker.Unit(bytesReturned)
117122

118123
m.copyAndAddDocument(cursor.Current)
119124
}
120-
m.logger.Debug().Msgf("Find returned %d documents containing %d bytes", nDocumentsReturned, bytesReturned)
125+
m.logger.Debug().
126+
Int("documentedReturned", nDocumentsReturned).
127+
Str("totalSize", reportutils.BytesToUnit(bytesReturned, reportutils.FindBestUnit(bytesReturned))).
128+
Msgf("Finished reading %#q query.", "find")
121129

122130
return nil
123131
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package memorytracker
2+
3+
import (
4+
"reflect"
5+
"slices"
6+
"sync"
7+
8+
"github.com/10gen/migration-verifier/internal/logger"
9+
)
10+
11+
type Unit = int64
12+
type reader = <-chan Unit
13+
type writer = chan<- Unit
14+
15+
type Tracker struct {
16+
logger *logger.Logger
17+
max Unit
18+
cur Unit
19+
selectCases []reflect.SelectCase
20+
mux sync.RWMutex
21+
}
22+
23+
func Start(logger *logger.Logger, max Unit) *Tracker {
24+
tracker := Tracker{max: max}
25+
26+
go tracker.track()
27+
28+
return &tracker
29+
}
30+
31+
func (mt *Tracker) AddWriter() writer {
32+
mt.mux.RLock()
33+
defer mt.mux.RUnlock()
34+
35+
newChan := make(chan Unit)
36+
37+
mt.selectCases = append(mt.selectCases, reflect.SelectCase{
38+
Dir: reflect.SelectRecv,
39+
Chan: reflect.ValueOf(newChan),
40+
})
41+
42+
return newChan
43+
}
44+
45+
func (mt *Tracker) getSelectCases() []reflect.SelectCase {
46+
mt.mux.RLock()
47+
defer mt.mux.RUnlock()
48+
49+
return slices.Clone(mt.selectCases)
50+
}
51+
52+
func (mt *Tracker) removeSelectCase(i int) {
53+
mt.mux.Lock()
54+
defer mt.mux.Unlock()
55+
56+
mt.selectCases = slices.Delete(mt.selectCases, i, 1+i)
57+
}
58+
59+
func (mt *Tracker) track() {
60+
for {
61+
if mt.cur <= mt.max {
62+
mt.logger.Panic().
63+
Int64("usage", mt.cur).
64+
Int64("softLimit", mt.max).
65+
Msg("track() loop should never be in memory excess!")
66+
}
67+
68+
selectCases := mt.getSelectCases()
69+
70+
chosen, gotVal, alive := reflect.Select(selectCases)
71+
72+
got := (gotVal.Interface()).(Unit)
73+
mt.cur += got
74+
75+
if alive {
76+
if got == 0 {
77+
mt.logger.Panic().Msg("Got zero track value but channel is not closed.")
78+
}
79+
} else {
80+
if got != 0 {
81+
mt.logger.Panic().
82+
Int64("receivedValue", got).
83+
Msg("Got nonzero track value but channel is closed.")
84+
}
85+
86+
mt.removeSelectCase(chosen)
87+
continue
88+
}
89+
90+
didSingleThread := false
91+
92+
for mt.cur > mt.max {
93+
reader := (selectCases[chosen].Chan.Interface()).(reader)
94+
95+
if !didSingleThread {
96+
mt.logger.Warn().
97+
Int64("usage", mt.cur).
98+
Int64("softLimit", mt.max).
99+
Msg("Tracked memory usage now exceeds soft limit. Suspending concurrent reads until tracked usage falls.")
100+
101+
didSingleThread = true
102+
}
103+
104+
got, alive := <-reader
105+
mt.cur += got
106+
107+
if !alive {
108+
mt.removeSelectCase(chosen)
109+
}
110+
}
111+
112+
if didSingleThread {
113+
mt.logger.Info().
114+
Int64("usage", mt.cur).
115+
Int64("softLimit", mt.max).
116+
Msg("Tracked memory usage is now below soft limit. Resuming concurrent reads.")
117+
}
118+
}
119+
}

internal/verifier/check.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sync"
77
"time"
88

9+
"github.com/10gen/migration-verifier/internal/memorytracker"
910
"github.com/10gen/migration-verifier/internal/retry"
1011
mapset "github.com/deckarep/golang-set/v2"
1112
"github.com/pkg/errors"
@@ -62,11 +63,13 @@ func (verifier *Verifier) waitForChangeStream() error {
6263

6364
func (verifier *Verifier) CheckWorker(ctx context.Context) error {
6465
verifier.logger.Debug().Msgf("Starting %d verification workers", verifier.numWorkers)
66+
memTracker := memorytracker.Start(verifier.logger, 40_000_000) // TODO
6567
ctx, cancel := context.WithCancel(ctx)
6668
wg := sync.WaitGroup{}
6769
for i := 0; i < verifier.numWorkers; i++ {
6870
wg.Add(1)
69-
go verifier.Work(ctx, i, &wg)
71+
trackerWriter := memTracker.AddWriter()
72+
go verifier.Work(ctx, i, &wg, trackerWriter)
7073
time.Sleep(10 * time.Millisecond)
7174
}
7275

@@ -345,7 +348,7 @@ func FetchFailedAndIncompleteTasks(ctx context.Context, coll *mongo.Collection,
345348
return FailedTasks, IncompleteTasks, nil
346349
}
347350

348-
func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup) {
351+
func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.WaitGroup, trackerWriter memorytracker.Writer) {
349352
defer wg.Done()
350353
verifier.logger.Debug().Msgf("[Worker %d] Started", workerNum)
351354
for {
@@ -371,7 +374,7 @@ func (verifier *Verifier) Work(ctx context.Context, workerNum int, wg *sync.Wait
371374
}
372375
}
373376
} else {
374-
verifier.ProcessVerifyTask(workerNum, task)
377+
verifier.ProcessVerifyTask(workerNum, task, trackerWriter)
375378
}
376379
}
377380
}

internal/verifier/migration_verifier.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818

1919
"github.com/10gen/migration-verifier/internal/documentmap"
2020
"github.com/10gen/migration-verifier/internal/logger"
21+
"github.com/10gen/migration-verifier/internal/memorytracker"
2122
"github.com/10gen/migration-verifier/internal/partitions"
2223
"github.com/10gen/migration-verifier/internal/reportutils"
2324
"github.com/10gen/migration-verifier/internal/retry"
@@ -137,6 +138,8 @@ type Verifier struct {
137138
// The verifier only checks documents within the filter.
138139
globalFilter map[string]any
139140

141+
memoryTracker *memorytracker.Tracker
142+
140143
pprofInterval time.Duration
141144
}
142145

@@ -457,8 +460,8 @@ func (verifier *Verifier) getDocumentsCursor(ctx context.Context, collection *mo
457460
return collection.Database().RunCommandCursor(ctx, findCmd, runCommandOptions)
458461
}
459462

460-
func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask) ([]VerificationResult, types.DocumentCount, types.ByteCount, error) {
461-
srcClientMap, dstClientMap, err := verifier.fetchDocuments(task)
463+
func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask, trackerWriter memorytracker.Writer) ([]VerificationResult, types.DocumentCount, types.ByteCount, error) {
464+
srcClientMap, dstClientMap, err := verifier.fetchDocuments(task, trackerWriter)
462465
if err != nil {
463466
return nil, 0, 0, err
464467
}
@@ -472,7 +475,7 @@ func (verifier *Verifier) FetchAndCompareDocuments(task *VerificationTask) ([]Ve
472475
}
473476

474477
// This is split out to allow unit testing of fetching separate from comparison.
475-
func (verifier *Verifier) fetchDocuments(task *VerificationTask) (*documentmap.Map, *documentmap.Map, error) {
478+
func (verifier *Verifier) fetchDocuments(task *VerificationTask, trackerWriter memorytracker.Writer) (*documentmap.Map, *documentmap.Map, error) {
476479

477480
var srcErr, dstErr error
478481

@@ -489,7 +492,7 @@ func (verifier *Verifier) fetchDocuments(task *VerificationTask) (*documentmap.M
489492
verifier.srcStartAtTs, task)
490493

491494
if srcErr == nil {
492-
srcErr = srcClientMap.ImportFromCursor(ctx, cursor)
495+
srcErr = srcClientMap.ImportFromCursor(ctx, cursor, trackerWriter)
493496
}
494497

495498
return srcErr
@@ -501,7 +504,7 @@ func (verifier *Verifier) fetchDocuments(task *VerificationTask) (*documentmap.M
501504
nil /*startAtTs*/, task)
502505

503506
if dstErr == nil {
504-
dstErr = dstClientMap.ImportFromCursor(ctx, cursor)
507+
dstErr = dstClientMap.ImportFromCursor(ctx, cursor, trackerWriter)
505508
}
506509

507510
return dstErr
@@ -632,10 +635,10 @@ func (verifier *Verifier) compareOneDocument(srcClientDoc, dstClientDoc bson.Raw
632635
}}, nil
633636
}
634637

635-
func (verifier *Verifier) ProcessVerifyTask(workerNum int, task *VerificationTask) {
638+
func (verifier *Verifier) ProcessVerifyTask(workerNum int, task *VerificationTask, trackerWriter memorytracker.Writer) {
636639
verifier.logger.Debug().Msgf("[Worker %d] Processing verify task", workerNum)
637640

638-
problems, docsCount, bytesCount, err := verifier.FetchAndCompareDocuments(task)
641+
problems, docsCount, bytesCount, err := verifier.FetchAndCompareDocuments(task, trackerWriter)
639642

640643
if err != nil {
641644
task.Status = verificationTaskFailed

0 commit comments

Comments
 (0)