Skip to content

Commit f65a943

Browse files
committed
parallel ForEach unit test
1 parent 15c1451 commit f65a943

File tree

1 file changed

+195
-3
lines changed

1 file changed

+195
-3
lines changed

amt_test.go

Lines changed: 195 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,18 @@ func init() {
3030
}
3131

3232
var (
33+
bitWidths3 = []uint{8}
3334
bitWidths2to18 = []uint{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18}
3435
bitWidths2to3 = []uint{2, 3}
3536
)
3637

38+
func runTestWithBitWidthsOnly(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
39+
t.Helper()
40+
for _, bw := range bitwidths {
41+
t.Run(fmt.Sprintf("bitwidth=%d", bw), func(t *testing.T) { fn(t, UseTreeBitWidth(bw)) })
42+
}
43+
}
44+
3745
func runTestWithBitWidths(t *testing.T, bitwidths []uint, fn func(*testing.T, ...Option)) {
3846
t.Helper()
3947
if testing.Short() {
@@ -62,7 +70,7 @@ func newMockBlocks() *mockBlocks {
6270
return &mockBlocks{make(map[cid.Cid]block.Block), sync.Mutex{}, 0, 0}
6371
}
6472

65-
func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
73+
func (mb *mockBlocks) Get(ctx context.Context, c cid.Cid) (block.Block, error) {
6674
mb.dataMu.Lock()
6775
defer mb.dataMu.Unlock()
6876
d, ok := mb.data[c]
@@ -73,14 +81,41 @@ func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
7381
return nil, fmt.Errorf("Not Found")
7482
}
7583

76-
func (mb *mockBlocks) Put(b block.Block) error {
84+
func (mb *mockBlocks) GetMany(ctx context.Context, cs []cid.Cid) ([]block.Block, []cid.Cid, error) {
85+
mb.dataMu.Lock()
86+
defer mb.dataMu.Unlock()
87+
blocks := make([]block.Block, 0, len(cs))
88+
missingCIDs := make([]cid.Cid, 0, len(cs))
89+
for _, c := range cs {
90+
mb.getCount++
91+
d, ok := mb.data[c]
92+
if !ok {
93+
missingCIDs = append(missingCIDs, c)
94+
} else {
95+
blocks = append(blocks, d)
96+
}
97+
}
98+
return blocks, missingCIDs, nil
99+
}
100+
101+
func (mb *mockBlocks) Put(ctx context.Context, b block.Block) error {
77102
mb.dataMu.Lock()
78103
defer mb.dataMu.Unlock()
79104
mb.putCount++
80105
mb.data[b.Cid()] = b
81106
return nil
82107
}
83108

109+
func (mb *mockBlocks) PutMany(ctx context.Context, bs []block.Block) error {
110+
mb.dataMu.Lock()
111+
defer mb.dataMu.Unlock()
112+
for _, b := range bs {
113+
mb.putCount++
114+
mb.data[b.Cid()] = b
115+
}
116+
return nil
117+
}
118+
84119
func (mb *mockBlocks) report(b *testing.B) {
85120
mb.dataMu.Lock()
86121
defer mb.dataMu.Unlock()
@@ -342,6 +377,7 @@ func TestForEachWithoutFlush(t *testing.T) {
342377
require.NoError(t, err)
343378
set1 := make(map[uint64]struct{})
344379
set2 := make(map[uint64]struct{})
380+
set3 := make(map[uint64]struct{})
345381
for _, val := range vals {
346382
err := amt.Set(ctx, val, cborstr(""))
347383
require.NoError(t, err)
@@ -357,14 +393,23 @@ func TestForEachWithoutFlush(t *testing.T) {
357393
assert.Equal(t, make(map[uint64]struct{}), set1)
358394

359395
// ensure it still works after flush
360-
_, err = amt.Flush(ctx)
396+
c, err := amt.Flush(ctx)
361397
require.NoError(t, err)
362398

363399
amt.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
364400
delete(set2, u)
365401
return nil
366402
})
367403
assert.Equal(t, make(map[uint64]struct{}), set2)
404+
405+
// ensure that it works with a loaded AMT
406+
loadedAMT, err := LoadAMT(ctx, bs, c, opts...)
407+
err = loadedAMT.ForEach(ctx, func(u uint64, deferred *cbg.Deferred) error {
408+
delete(set3, u)
409+
return nil
410+
})
411+
require.NoError(t, err)
412+
assert.Equal(t, make(map[uint64]struct{}), set3)
368413
}
369414
})
370415
}
@@ -794,6 +839,94 @@ func TestForEach(t *testing.T) {
794839
})
795840
}
796841

842+
func TestForEachParallel(t *testing.T) {
843+
bs := cbor.NewGetManyCborStore(newMockBlocks())
844+
ctx := context.Background()
845+
a, err := NewAMT(bs)
846+
require.NoError(t, err)
847+
848+
r := rand.New(rand.NewSource(101))
849+
850+
indexes := make(map[uint64]struct{})
851+
for i := 0; i < 10000; i++ {
852+
if r.Intn(2) == 0 {
853+
indexes[uint64(i)] = struct{}{}
854+
}
855+
}
856+
857+
for i := range indexes {
858+
if err := a.Set(ctx, i, cborstr("value")); err != nil {
859+
t.Fatal(err)
860+
}
861+
}
862+
863+
for i := range indexes {
864+
assertGet(ctx, t, a, i, "value")
865+
}
866+
867+
assertCount(t, a, uint64(len(indexes)))
868+
869+
// test before flush
870+
m := sync.Mutex{}
871+
foundVals := make(map[uint64]struct{})
872+
err = a.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
873+
m.Lock()
874+
foundVals[i] = struct{}{}
875+
m.Unlock()
876+
return nil
877+
})
878+
if err != nil {
879+
t.Fatal(err)
880+
}
881+
if len(foundVals) != len(indexes) {
882+
t.Fatal("didnt see enough values")
883+
}
884+
885+
c, err := a.Flush(ctx)
886+
if err != nil {
887+
t.Fatal(err)
888+
}
889+
890+
assertCount(t, a, uint64(len(indexes)))
891+
892+
// test after flush
893+
foundVals = make(map[uint64]struct{})
894+
err = a.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
895+
m.Lock()
896+
foundVals[i] = struct{}{}
897+
m.Unlock()
898+
return nil
899+
})
900+
if err != nil {
901+
t.Fatal(err)
902+
}
903+
if len(foundVals) != len(indexes) {
904+
t.Fatal("didnt see enough values")
905+
}
906+
907+
na, err := LoadAMT(ctx, bs, c)
908+
if err != nil {
909+
t.Fatal(err)
910+
}
911+
912+
assertCount(t, na, uint64(len(indexes)))
913+
914+
// test from loaded AMT
915+
foundVals = make(map[uint64]struct{})
916+
err = na.ForEachParallel(ctx, 16, func(i uint64, v *cbg.Deferred) error {
917+
m.Lock()
918+
foundVals[i] = struct{}{}
919+
m.Unlock()
920+
return nil
921+
})
922+
if err != nil {
923+
t.Fatal(err)
924+
}
925+
if len(foundVals) != len(indexes) {
926+
t.Fatal("didnt see enough values")
927+
}
928+
}
929+
797930
func TestForEachAt(t *testing.T) {
798931
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
799932
bs := cbor.NewCborStore(newMockBlocks())
@@ -858,6 +991,65 @@ func TestForEachAt(t *testing.T) {
858991
})
859992
}
860993

994+
func TestForEachAtParallel(t *testing.T) {
995+
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
996+
bs := cbor.NewGetManyCborStore(newMockBlocks())
997+
ctx := context.Background()
998+
a, err := NewAMT(bs, opts...)
999+
require.NoError(t, err)
1000+
1001+
r := rand.New(rand.NewSource(101))
1002+
1003+
var indexes []uint64
1004+
for i := 0; i < cbg.MaxLength; i++ { // above bitwidth 13, inserting more than cbg.MaxLength causes node.Values to exceed the cbg.MaxLength
1005+
indexes = append(indexes, uint64(i))
1006+
if err := a.Set(ctx, uint64(i), cborstr(fmt.Sprint(i))); err != nil {
1007+
t.Fatal(err)
1008+
}
1009+
}
1010+
1011+
for _, i := range indexes {
1012+
assertGet(ctx, t, a, i, fmt.Sprint(i))
1013+
}
1014+
1015+
assertCount(t, a, uint64(len(indexes)))
1016+
1017+
c, err := a.Flush(ctx)
1018+
if err != nil {
1019+
t.Fatal(err)
1020+
}
1021+
1022+
na, err := LoadAMT(ctx, bs, c, opts...)
1023+
if err != nil {
1024+
t.Fatal(err)
1025+
}
1026+
1027+
assertCount(t, na, uint64(len(indexes)))
1028+
m := sync.Mutex{}
1029+
for try := 0; try < 10; try++ {
1030+
start := uint64(r.Intn(cbg.MaxLength))
1031+
1032+
expectedIndexes := make(map[uint64]struct{})
1033+
for i := start; i < cbg.MaxLength; i++ {
1034+
expectedIndexes[i] = struct{}{}
1035+
}
1036+
1037+
err = na.ForEachAtParallel(ctx, 16, start, func(i uint64, v *cbg.Deferred) error {
1038+
m.Lock()
1039+
delete(expectedIndexes, i)
1040+
m.Unlock()
1041+
return nil
1042+
})
1043+
if err != nil {
1044+
t.Fatal(err)
1045+
}
1046+
if len(expectedIndexes) != 0 {
1047+
t.Fatal("didnt see enough values")
1048+
}
1049+
}
1050+
})
1051+
}
1052+
8611053
func TestFirstSetIndex(t *testing.T) {
8621054
runTestWithBitWidths(t, bitWidths2to18, func(t *testing.T, opts ...Option) {
8631055
bs := cbor.NewCborStore(newMockBlocks())

0 commit comments

Comments
 (0)