@@ -3,6 +3,7 @@ package hamt
3
3
import (
4
4
"bytes"
5
5
"context"
6
+ "crypto/sha256"
6
7
"encoding/hex"
7
8
"fmt"
8
9
"math/rand"
@@ -68,41 +69,27 @@ var shortIdentityHash = func(k []byte) []byte {
68
69
return res
69
70
}
70
71
71
- var murmurHash = hash
72
-
73
72
func TestCanonicalStructure (t * testing.T ) {
74
- hash = identityHash
75
- defer func () {
76
- hash = murmurHash
77
- }()
78
- addAndRemoveKeys (t , defaultBitWidth , []string {"K" }, []string {"B" })
79
- addAndRemoveKeys (t , defaultBitWidth , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" })
73
+ addAndRemoveKeys (t , []string {"K" }, []string {"B" }, UseHashFunction (identityHash ))
74
+ addAndRemoveKeys (t , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" })
80
75
}
81
76
82
77
func TestCanonicalStructureAlternateBitWidth (t * testing.T ) {
83
- hash = identityHash
84
- defer func () {
85
- hash = murmurHash
86
- }()
87
- addAndRemoveKeys (t , 7 , []string {"K" }, []string {"B" })
88
- addAndRemoveKeys (t , 7 , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" })
89
- addAndRemoveKeys (t , 6 , []string {"K" }, []string {"B" })
90
- addAndRemoveKeys (t , 6 , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" })
91
- addAndRemoveKeys (t , 5 , []string {"K" }, []string {"B" })
92
- addAndRemoveKeys (t , 5 , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" })
78
+ addAndRemoveKeys (t , []string {"K" }, []string {"B" }, UseTreeBitWidth (7 ), UseHashFunction (identityHash ))
79
+ addAndRemoveKeys (t , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" }, UseTreeBitWidth (7 ), UseHashFunction (identityHash ))
80
+ addAndRemoveKeys (t , []string {"K" }, []string {"B" }, UseTreeBitWidth (6 ), UseHashFunction (identityHash ))
81
+ addAndRemoveKeys (t , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" }, UseTreeBitWidth (6 ), UseHashFunction (identityHash ))
82
+ addAndRemoveKeys (t , []string {"K" }, []string {"B" }, UseTreeBitWidth (5 ), UseHashFunction (identityHash ))
83
+ addAndRemoveKeys (t , []string {"K0" , "K1" , "KAA1" , "KAA2" , "KAA3" }, []string {"KAA4" }, UseTreeBitWidth (5 ), UseHashFunction (identityHash ))
93
84
}
94
85
func TestOverflow (t * testing.T ) {
95
- hash = identityHash
96
- defer func () {
97
- hash = murmurHash
98
- }()
99
86
keys := make ([]string , 4 )
100
87
for i := range keys {
101
88
keys [i ] = strings .Repeat ("A" , 32 ) + fmt .Sprintf ("%d" , i )
102
89
}
103
90
104
91
cs := cbor .NewCborStore (newMockBlocks ())
105
- n := NewNode (cs )
92
+ n := NewNode (cs , UseHashFunction ( identityHash ) )
106
93
for _ , k := range keys [:3 ] {
107
94
if err := n .Set (context .Background (), k , "foobar" ); err != nil {
108
95
t .Error (err )
@@ -120,13 +107,13 @@ func TestOverflow(t *testing.T) {
120
107
}
121
108
122
109
// Now, try fetching with a shorter hash function.
123
- hash = shortIdentityHash
110
+ n . hash = shortIdentityHash
124
111
if err := n .Find (context .Background (), keys [0 ], nil ); err != ErrMaxDepth {
125
112
t .Errorf ("expected error %q, got %q" , ErrMaxDepth , err )
126
113
}
127
114
}
128
115
129
- func addAndRemoveKeys (t * testing.T , bitWidth int , keys []string , extraKeys []string ) {
116
+ func addAndRemoveKeys (t * testing.T , keys []string , extraKeys []string , options ... Option ) {
130
117
ctx := context .Background ()
131
118
vals := make (map [string ][]byte )
132
119
for i := 0 ; i < len (keys ); i ++ {
@@ -135,7 +122,7 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
135
122
}
136
123
137
124
cs := cbor .NewCborStore (newMockBlocks ())
138
- begn := NewNode (cs , UseTreeBitWidth ( bitWidth ) )
125
+ begn := NewNode (cs , options ... )
139
126
for _ , k := range keys {
140
127
if err := begn .Set (ctx , k , vals [k ]); err != nil {
141
128
t .Fatal (err )
@@ -158,7 +145,8 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
158
145
t .Fatal (err )
159
146
}
160
147
n .store = cs
161
- n .bitWidth = bitWidth
148
+ n .hash = begn .hash
149
+ n .bitWidth = begn .bitWidth
162
150
for k , v := range vals {
163
151
var out []byte
164
152
err := n .Find (ctx , k , & out )
@@ -193,7 +181,8 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
193
181
t .Fatal (err )
194
182
}
195
183
n2 .store = cs
196
- n2 .bitWidth = bitWidth
184
+ n2 .hash = begn .hash
185
+ n2 .bitWidth = begn .bitWidth
197
186
if ! nodesEqual (t , cs , & n , & n2 ) {
198
187
t .Fatal ("nodes should be equal" )
199
188
}
@@ -205,7 +194,7 @@ func dotGraphRec(n *Node, name *int) {
205
194
if p .isShard () {
206
195
* name ++
207
196
fmt .Printf ("\t n%d -> n%d;\n " , cur , * name )
208
- nd , err := p .loadChild (context .Background (), n .store , n .bitWidth )
197
+ nd , err := p .loadChild (context .Background (), n .store , n .bitWidth , n . hash )
209
198
if err != nil {
210
199
panic (err )
211
200
}
@@ -237,7 +226,7 @@ func statsrec(n *Node, st *hamtStats) {
237
226
st .totalNodes ++
238
227
for _ , p := range n .Pointers {
239
228
if p .isShard () {
240
- nd , err := p .loadChild (context .Background (), n .store , n .bitWidth )
229
+ nd , err := p .loadChild (context .Background (), n .store , n .bitWidth , n . hash )
241
230
if err != nil {
242
231
panic (err )
243
232
}
@@ -251,17 +240,28 @@ func statsrec(n *Node, st *hamtStats) {
251
240
}
252
241
253
242
func TestHash (t * testing.T ) {
254
- h1 := hash ([]byte ("abcd" ))
255
- h2 := hash ([]byte ("abce" ))
243
+ h1 := defaultHashFunction ([]byte ("abcd" ))
244
+ h2 := defaultHashFunction ([]byte ("abce" ))
256
245
if h1 [0 ] == h2 [0 ] && h1 [1 ] == h2 [1 ] && h1 [3 ] == h2 [3 ] {
257
246
t .Fatal ("Hash should give different strings different hash prefixes" )
258
247
}
259
248
}
260
249
261
250
func TestBasic (t * testing.T ) {
251
+ testBasic (t )
252
+ }
253
+
254
+ func TestSha256 (t * testing.T ) {
255
+ testBasic (t , UseHashFunction (func (in []byte ) []byte {
256
+ out := sha256 .Sum256 (in )
257
+ return out [:]
258
+ }))
259
+ }
260
+
261
+ func testBasic (t * testing.T , options ... Option ) {
262
262
ctx := context .Background ()
263
263
cs := cbor .NewCborStore (newMockBlocks ())
264
- begn := NewNode (cs )
264
+ begn := NewNode (cs , options ... )
265
265
266
266
val := []byte ("cat dog bear" )
267
267
if err := begn .Set (ctx , "foo" , val ); err != nil {
@@ -282,7 +282,7 @@ func TestBasic(t *testing.T) {
282
282
t .Fatal (err )
283
283
}
284
284
285
- n , err := LoadNode (ctx , cs , c )
285
+ n , err := LoadNode (ctx , cs , c , options ... )
286
286
if err != nil {
287
287
t .Fatal (err )
288
288
}
@@ -381,6 +381,7 @@ func TestSetGet(t *testing.T) {
381
381
t .Fatal (err )
382
382
}
383
383
n .store = cs
384
+ n .hash = defaultHashFunction
384
385
n .bitWidth = defaultBitWidth
385
386
bef = time .Now ()
386
387
//for k, v := range vals {
0 commit comments