@@ -13,22 +13,28 @@ import (
13
13
)
14
14
15
15
const arrayWidth = 3
16
+ const defaultBitWidth = 8
16
17
17
18
type Node struct {
18
19
Bitfield * big.Int `refmt:"bf"`
19
20
Pointers []* Pointer `refmt:"p"`
20
21
21
22
// for fetching and storing children
22
- store * CborIpldStore
23
+ store * CborIpldStore
24
+ bitWidth int
23
25
}
24
26
25
27
// Option is a function that configures the node
26
28
type Option func (* Node )
27
29
28
30
// UseTreeBitWidth allows you to set the width of the HAMT tree
29
31
// in bits (from 1-8) via a customized hash function
30
- func UseTreeBitWidth (bitWidth uint8 ) Option {
31
- return func (* Node ) {}
32
+ func UseTreeBitWidth (bitWidth int ) Option {
33
+ return func (nd * Node ) {
34
+ if bitWidth > 0 && bitWidth <= 8 {
35
+ nd .bitWidth = bitWidth
36
+ }
37
+ }
32
38
}
33
39
34
40
// NewNode creates a new IPLD HAMT Node with the given store and given
@@ -38,6 +44,7 @@ func NewNode(cs *CborIpldStore, options ...Option) *Node {
38
44
Bitfield : big .NewInt (0 ),
39
45
Pointers : make ([]* Pointer , 0 ),
40
46
store : cs ,
47
+ bitWidth : defaultBitWidth ,
41
48
}
42
49
// apply functional options to node before using
43
50
for _ , option := range options {
@@ -60,7 +67,7 @@ type Pointer struct {
60
67
}
61
68
62
69
func (n * Node ) Find (ctx context.Context , k string , out interface {}) error {
63
- return n .getValue (ctx , hash (k ), 0 , k , func (kv * KV ) error {
70
+ return n .getValue (ctx , & hashBits { b : hash (k )} , k , func (kv * KV ) error {
64
71
// used to just see if the thing exists in the set
65
72
if out == nil {
66
73
return nil
@@ -79,32 +86,32 @@ func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
79
86
}
80
87
81
88
func (n * Node ) Delete (ctx context.Context , k string ) error {
82
- return n .modifyValue (ctx , hash (k ), 0 , k , nil )
89
+ return n .modifyValue (ctx , & hashBits { b : hash (k )} , k , nil )
83
90
}
84
91
85
92
var ErrNotFound = fmt .Errorf ("not found" )
86
93
var ErrMaxDepth = fmt .Errorf ("attempted to traverse hamt beyond max depth" )
87
94
88
- func (n * Node ) getValue (ctx context.Context , hv []byte , depth int , k string , cb func (* KV ) error ) error {
89
- if depth >= len (hv ) {
95
+ func (n * Node ) getValue (ctx context.Context , hv * hashBits , k string , cb func (* KV ) error ) error {
96
+ idx , err := hv .Next (n .bitWidth )
97
+ if err != nil {
90
98
return ErrMaxDepth
91
99
}
92
100
93
- idx := hv [depth ]
94
- if n .Bitfield .Bit (int (idx )) == 0 {
101
+ if n .Bitfield .Bit (idx ) == 0 {
95
102
return ErrNotFound
96
103
}
97
104
98
- cindex := byte (n .indexForBitPos (int ( idx ) ))
105
+ cindex := byte (n .indexForBitPos (idx ))
99
106
100
107
c := n .getChild (cindex )
101
108
if c .isShard () {
102
- chnd , err := c .loadChild (ctx , n .store )
109
+ chnd , err := c .loadChild (ctx , n .store , n . bitWidth )
103
110
if err != nil {
104
111
return err
105
112
}
106
113
107
- return chnd .getValue (ctx , hv , depth + 1 , k , cb )
114
+ return chnd .getValue (ctx , hv , k , cb )
108
115
}
109
116
110
117
for _ , kv := range c .KVs {
@@ -116,12 +123,13 @@ func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb
116
123
return ErrNotFound
117
124
}
118
125
119
- func (p * Pointer ) loadChild (ctx context.Context , ns * CborIpldStore ) (* Node , error ) {
126
+ func (p * Pointer ) loadChild (ctx context.Context , ns * CborIpldStore , bitWidth int ) (* Node , error ) {
120
127
if p .cache != nil {
121
128
return p .cache , nil
122
129
}
123
130
124
131
out , err := LoadNode (ctx , ns , p .Link )
132
+ out .bitWidth = bitWidth
125
133
if err != nil {
126
134
return nil , err
127
135
}
@@ -137,7 +145,7 @@ func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid, options ...Opti
137
145
}
138
146
139
147
out .store = cs
140
-
148
+ out . bitWidth = defaultBitWidth
141
149
// apply functional options to node before using
142
150
for _ , option := range options {
143
151
option (& out )
@@ -160,7 +168,7 @@ func (n *Node) checkSize(ctx context.Context) (uint64, error) {
160
168
totsize := uint64 (len (blk .RawData ()))
161
169
for _ , ch := range n .Pointers {
162
170
if ch .isShard () {
163
- chnd , err := ch .loadChild (ctx , n .store )
171
+ chnd , err := ch .loadChild (ctx , n .store , n . bitWidth )
164
172
if err != nil {
165
173
return 0 , err
166
174
}
@@ -212,7 +220,7 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
212
220
d = & cbg.Deferred {Raw : b }
213
221
}
214
222
215
- return n .modifyValue (ctx , hash (k ), 0 , k , d )
223
+ return n .modifyValue (ctx , & hashBits { b : hash (k )} , k , d )
216
224
}
217
225
218
226
func (n * Node ) cleanChild (chnd * Node , cindex byte ) error {
@@ -249,11 +257,11 @@ func (n *Node) cleanChild(chnd *Node, cindex byte) error {
249
257
}
250
258
}
251
259
252
- func (n * Node ) modifyValue (ctx context.Context , hv []byte , depth int , k string , v * cbg.Deferred ) error {
253
- if depth >= len (hv ) {
260
+ func (n * Node ) modifyValue (ctx context.Context , hv * hashBits , k string , v * cbg.Deferred ) error {
261
+ idx , err := hv .Next (n .bitWidth )
262
+ if err != nil {
254
263
return ErrMaxDepth
255
264
}
256
- idx := int (hv [depth ])
257
265
258
266
if n .Bitfield .Bit (idx ) != 1 {
259
267
return n .insertChild (idx , k , v )
@@ -263,12 +271,12 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
263
271
264
272
child := n .getChild (cindex )
265
273
if child .isShard () {
266
- chnd , err := child .loadChild (ctx , n .store )
274
+ chnd , err := child .loadChild (ctx , n .store , n . bitWidth )
267
275
if err != nil {
268
276
return err
269
277
}
270
278
271
- if err := chnd .modifyValue (ctx , hv , depth + 1 , k , v ); err != nil {
279
+ if err := chnd .modifyValue (ctx , hv , k , v ); err != nil {
272
280
return err
273
281
}
274
282
@@ -308,12 +316,15 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
308
316
// If the array is full, create a subshard and insert everything into it
309
317
if len (child .KVs ) >= arrayWidth {
310
318
sub := NewNode (n .store )
311
- if err := sub .modifyValue (ctx , hv , depth + 1 , k , v ); err != nil {
319
+ sub .bitWidth = n .bitWidth
320
+ hvcopy := & hashBits {b : hv .b , consumed : hv .consumed }
321
+ if err := sub .modifyValue (ctx , hvcopy , k , v ); err != nil {
312
322
return err
313
323
}
314
324
315
325
for _ , p := range child .KVs {
316
- if err := sub .modifyValue (ctx , hash (p .Key ), depth + 1 , p .Key , p .Value ); err != nil {
326
+ chhv := & hashBits {b : hash (p .Key ), consumed : hv .consumed }
327
+ if err := sub .modifyValue (ctx , chhv , p .Key , p .Value ); err != nil {
317
328
return err
318
329
}
319
330
}
@@ -375,6 +386,7 @@ func (n *Node) getChild(i byte) *Pointer {
375
386
376
387
func (n * Node ) Copy () * Node {
377
388
nn := NewNode (n .store )
389
+ nn .bitWidth = n .bitWidth
378
390
nn .Bitfield .Set (n .Bitfield )
379
391
nn .Pointers = make ([]* Pointer , len (n .Pointers ))
380
392
@@ -403,7 +415,7 @@ func (p *Pointer) isShard() bool {
403
415
func (n * Node ) ForEach (ctx context.Context , f func (k string , val interface {}) error ) error {
404
416
for _ , p := range n .Pointers {
405
417
if p .isShard () {
406
- chnd , err := p .loadChild (ctx , n .store )
418
+ chnd , err := p .loadChild (ctx , n .store , n . bitWidth )
407
419
if err != nil {
408
420
return err
409
421
}
0 commit comments