@@ -30,10 +30,18 @@ func init() {
30
30
}
31
31
32
32
var (
33
+ bitWidths3 = []uint {8 }
33
34
bitWidths2to18 = []uint {2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 , 15 , 16 , 17 , 18 }
34
35
bitWidths2to3 = []uint {2 , 3 }
35
36
)
36
37
38
+ func runTestWithBitWidthsOnly (t * testing.T , bitwidths []uint , fn func (* testing.T , ... Option )) {
39
+ t .Helper ()
40
+ for _ , bw := range bitwidths {
41
+ t .Run (fmt .Sprintf ("bitwidth=%d" , bw ), func (t * testing.T ) { fn (t , UseTreeBitWidth (bw )) })
42
+ }
43
+ }
44
+
37
45
func runTestWithBitWidths (t * testing.T , bitwidths []uint , fn func (* testing.T , ... Option )) {
38
46
t .Helper ()
39
47
if testing .Short () {
@@ -62,7 +70,7 @@ func newMockBlocks() *mockBlocks {
62
70
return & mockBlocks {make (map [cid.Cid ]block.Block ), sync.Mutex {}, 0 , 0 }
63
71
}
64
72
65
- func (mb * mockBlocks ) Get (c cid.Cid ) (block.Block , error ) {
73
+ func (mb * mockBlocks ) Get (ctx context. Context , c cid.Cid ) (block.Block , error ) {
66
74
mb .dataMu .Lock ()
67
75
defer mb .dataMu .Unlock ()
68
76
d , ok := mb .data [c ]
@@ -73,14 +81,41 @@ func (mb *mockBlocks) Get(c cid.Cid) (block.Block, error) {
73
81
return nil , fmt .Errorf ("Not Found" )
74
82
}
75
83
76
- func (mb * mockBlocks ) Put (b block.Block ) error {
84
+ func (mb * mockBlocks ) GetMany (ctx context.Context , cs []cid.Cid ) ([]block.Block , []cid.Cid , error ) {
85
+ mb .dataMu .Lock ()
86
+ defer mb .dataMu .Unlock ()
87
+ blocks := make ([]block.Block , 0 , len (cs ))
88
+ missingCIDs := make ([]cid.Cid , 0 , len (cs ))
89
+ for _ , c := range cs {
90
+ mb .getCount ++
91
+ d , ok := mb .data [c ]
92
+ if ! ok {
93
+ missingCIDs = append (missingCIDs , c )
94
+ } else {
95
+ blocks = append (blocks , d )
96
+ }
97
+ }
98
+ return blocks , missingCIDs , nil
99
+ }
100
+
101
+ func (mb * mockBlocks ) Put (ctx context.Context , b block.Block ) error {
77
102
mb .dataMu .Lock ()
78
103
defer mb .dataMu .Unlock ()
79
104
mb .putCount ++
80
105
mb .data [b .Cid ()] = b
81
106
return nil
82
107
}
83
108
109
+ func (mb * mockBlocks ) PutMany (ctx context.Context , bs []block.Block ) error {
110
+ mb .dataMu .Lock ()
111
+ defer mb .dataMu .Unlock ()
112
+ for _ , b := range bs {
113
+ mb .putCount ++
114
+ mb .data [b .Cid ()] = b
115
+ }
116
+ return nil
117
+ }
118
+
84
119
func (mb * mockBlocks ) report (b * testing.B ) {
85
120
mb .dataMu .Lock ()
86
121
defer mb .dataMu .Unlock ()
@@ -342,6 +377,7 @@ func TestForEachWithoutFlush(t *testing.T) {
342
377
require .NoError (t , err )
343
378
set1 := make (map [uint64 ]struct {})
344
379
set2 := make (map [uint64 ]struct {})
380
+ set3 := make (map [uint64 ]struct {})
345
381
for _ , val := range vals {
346
382
err := amt .Set (ctx , val , cborstr ("" ))
347
383
require .NoError (t , err )
@@ -357,14 +393,23 @@ func TestForEachWithoutFlush(t *testing.T) {
357
393
assert .Equal (t , make (map [uint64 ]struct {}), set1 )
358
394
359
395
// ensure it still works after flush
360
- _ , err = amt .Flush (ctx )
396
+ c , err : = amt .Flush (ctx )
361
397
require .NoError (t , err )
362
398
363
399
amt .ForEach (ctx , func (u uint64 , deferred * cbg.Deferred ) error {
364
400
delete (set2 , u )
365
401
return nil
366
402
})
367
403
assert .Equal (t , make (map [uint64 ]struct {}), set2 )
404
+
405
+ // ensure that it works with a loaded AMT
406
+ loadedAMT , err := LoadAMT (ctx , bs , c , opts ... )
407
+ err = loadedAMT .ForEach (ctx , func (u uint64 , deferred * cbg.Deferred ) error {
408
+ delete (set3 , u )
409
+ return nil
410
+ })
411
+ require .NoError (t , err )
412
+ assert .Equal (t , make (map [uint64 ]struct {}), set3 )
368
413
}
369
414
})
370
415
}
@@ -794,6 +839,94 @@ func TestForEach(t *testing.T) {
794
839
})
795
840
}
796
841
842
+ func TestForEachParallel (t * testing.T ) {
843
+ bs := cbor .NewGetManyCborStore (newMockBlocks ())
844
+ ctx := context .Background ()
845
+ a , err := NewAMT (bs )
846
+ require .NoError (t , err )
847
+
848
+ r := rand .New (rand .NewSource (101 ))
849
+
850
+ indexes := make (map [uint64 ]struct {})
851
+ for i := 0 ; i < 10000 ; i ++ {
852
+ if r .Intn (2 ) == 0 {
853
+ indexes [uint64 (i )] = struct {}{}
854
+ }
855
+ }
856
+
857
+ for i := range indexes {
858
+ if err := a .Set (ctx , i , cborstr ("value" )); err != nil {
859
+ t .Fatal (err )
860
+ }
861
+ }
862
+
863
+ for i := range indexes {
864
+ assertGet (ctx , t , a , i , "value" )
865
+ }
866
+
867
+ assertCount (t , a , uint64 (len (indexes )))
868
+
869
+ // test before flush
870
+ m := sync.Mutex {}
871
+ foundVals := make (map [uint64 ]struct {})
872
+ err = a .ForEachParallel (ctx , 16 , func (i uint64 , v * cbg.Deferred ) error {
873
+ m .Lock ()
874
+ foundVals [i ] = struct {}{}
875
+ m .Unlock ()
876
+ return nil
877
+ })
878
+ if err != nil {
879
+ t .Fatal (err )
880
+ }
881
+ if len (foundVals ) != len (indexes ) {
882
+ t .Fatal ("didnt see enough values" )
883
+ }
884
+
885
+ c , err := a .Flush (ctx )
886
+ if err != nil {
887
+ t .Fatal (err )
888
+ }
889
+
890
+ assertCount (t , a , uint64 (len (indexes )))
891
+
892
+ // test after flush
893
+ foundVals = make (map [uint64 ]struct {})
894
+ err = a .ForEachParallel (ctx , 16 , func (i uint64 , v * cbg.Deferred ) error {
895
+ m .Lock ()
896
+ foundVals [i ] = struct {}{}
897
+ m .Unlock ()
898
+ return nil
899
+ })
900
+ if err != nil {
901
+ t .Fatal (err )
902
+ }
903
+ if len (foundVals ) != len (indexes ) {
904
+ t .Fatal ("didnt see enough values" )
905
+ }
906
+
907
+ na , err := LoadAMT (ctx , bs , c )
908
+ if err != nil {
909
+ t .Fatal (err )
910
+ }
911
+
912
+ assertCount (t , na , uint64 (len (indexes )))
913
+
914
+ // test from loaded AMT
915
+ foundVals = make (map [uint64 ]struct {})
916
+ err = na .ForEachParallel (ctx , 16 , func (i uint64 , v * cbg.Deferred ) error {
917
+ m .Lock ()
918
+ foundVals [i ] = struct {}{}
919
+ m .Unlock ()
920
+ return nil
921
+ })
922
+ if err != nil {
923
+ t .Fatal (err )
924
+ }
925
+ if len (foundVals ) != len (indexes ) {
926
+ t .Fatal ("didnt see enough values" )
927
+ }
928
+ }
929
+
797
930
func TestForEachAt (t * testing.T ) {
798
931
runTestWithBitWidths (t , bitWidths2to18 , func (t * testing.T , opts ... Option ) {
799
932
bs := cbor .NewCborStore (newMockBlocks ())
@@ -858,6 +991,65 @@ func TestForEachAt(t *testing.T) {
858
991
})
859
992
}
860
993
994
+ func TestForEachAtParallel (t * testing.T ) {
995
+ runTestWithBitWidths (t , bitWidths2to18 , func (t * testing.T , opts ... Option ) {
996
+ bs := cbor .NewGetManyCborStore (newMockBlocks ())
997
+ ctx := context .Background ()
998
+ a , err := NewAMT (bs , opts ... )
999
+ require .NoError (t , err )
1000
+
1001
+ r := rand .New (rand .NewSource (101 ))
1002
+
1003
+ var indexes []uint64
1004
+ for i := 0 ; i < cbg .MaxLength ; i ++ { // above bitwidth 13, inserting more than cbg.MaxLength causes node.Values to exceed the cbg.MaxLength
1005
+ indexes = append (indexes , uint64 (i ))
1006
+ if err := a .Set (ctx , uint64 (i ), cborstr (fmt .Sprint (i ))); err != nil {
1007
+ t .Fatal (err )
1008
+ }
1009
+ }
1010
+
1011
+ for _ , i := range indexes {
1012
+ assertGet (ctx , t , a , i , fmt .Sprint (i ))
1013
+ }
1014
+
1015
+ assertCount (t , a , uint64 (len (indexes )))
1016
+
1017
+ c , err := a .Flush (ctx )
1018
+ if err != nil {
1019
+ t .Fatal (err )
1020
+ }
1021
+
1022
+ na , err := LoadAMT (ctx , bs , c , opts ... )
1023
+ if err != nil {
1024
+ t .Fatal (err )
1025
+ }
1026
+
1027
+ assertCount (t , na , uint64 (len (indexes )))
1028
+ m := sync.Mutex {}
1029
+ for try := 0 ; try < 10 ; try ++ {
1030
+ start := uint64 (r .Intn (cbg .MaxLength ))
1031
+
1032
+ expectedIndexes := make (map [uint64 ]struct {})
1033
+ for i := start ; i < cbg .MaxLength ; i ++ {
1034
+ expectedIndexes [i ] = struct {}{}
1035
+ }
1036
+
1037
+ err = na .ForEachAtParallel (ctx , 16 , start , func (i uint64 , v * cbg.Deferred ) error {
1038
+ m .Lock ()
1039
+ delete (expectedIndexes , i )
1040
+ m .Unlock ()
1041
+ return nil
1042
+ })
1043
+ if err != nil {
1044
+ t .Fatal (err )
1045
+ }
1046
+ if len (expectedIndexes ) != 0 {
1047
+ t .Fatal ("didnt see enough values" )
1048
+ }
1049
+ }
1050
+ })
1051
+ }
1052
+
861
1053
func TestFirstSetIndex (t * testing.T ) {
862
1054
runTestWithBitWidths (t , bitWidths2to18 , func (t * testing.T , opts ... Option ) {
863
1055
bs := cbor .NewCborStore (newMockBlocks ())
0 commit comments