Skip to content

Commit c51764d

Browse files
committed
feat(hamt): support alternate bitwidth
Allow the use of alternate bit widths by consuming the key hash through the hashBits reader instead of indexing it one byte per level.
1 parent 50955c2 commit c51764d

File tree

3 files changed

+58
-33
lines changed

3 files changed

+58
-33
lines changed

hamt.go

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,28 @@ import (
1313
)
1414

1515
const arrayWidth = 3
16+
const defaultBitWidth = 8
1617

1718
type Node struct {
1819
Bitfield *big.Int `refmt:"bf"`
1920
Pointers []*Pointer `refmt:"p"`
2021

2122
// for fetching and storing children
22-
store *CborIpldStore
23+
store *CborIpldStore
24+
bitWidth int
2325
}
2426

2527
// Option is a function that configures the node
2628
type Option func(*Node)
2729

2830
// UseTreeBitWidth allows you to set the width of the HAMT tree
2931
// in bits (from 1 to 8) via a customized hash function; widths outside that range are ignored
30-
func UseTreeBitWidth(bitWidth uint8) Option {
31-
return func(*Node) {}
32+
func UseTreeBitWidth(bitWidth int) Option {
33+
return func(nd *Node) {
34+
if bitWidth > 0 && bitWidth <= 8 {
35+
nd.bitWidth = bitWidth
36+
}
37+
}
3238
}
3339

3440
// NewNode creates a new IPLD HAMT Node with the given store and given
@@ -38,6 +44,7 @@ func NewNode(cs *CborIpldStore, options ...Option) *Node {
3844
Bitfield: big.NewInt(0),
3945
Pointers: make([]*Pointer, 0),
4046
store: cs,
47+
bitWidth: defaultBitWidth,
4148
}
4249
// apply functional options to node before using
4350
for _, option := range options {
@@ -60,7 +67,7 @@ type Pointer struct {
6067
}
6168

6269
func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
63-
return n.getValue(ctx, hash(k), 0, k, func(kv *KV) error {
70+
return n.getValue(ctx, &hashBits{b: hash(k)}, k, func(kv *KV) error {
6471
// used to just see if the thing exists in the set
6572
if out == nil {
6673
return nil
@@ -79,32 +86,32 @@ func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
7986
}
8087

8188
func (n *Node) Delete(ctx context.Context, k string) error {
82-
return n.modifyValue(ctx, hash(k), 0, k, nil)
89+
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, nil)
8390
}
8491

8592
var ErrNotFound = fmt.Errorf("not found")
8693
var ErrMaxDepth = fmt.Errorf("attempted to traverse hamt beyond max depth")
8794

88-
func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb func(*KV) error) error {
89-
if depth >= len(hv) {
95+
func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV) error) error {
96+
idx, err := hv.Next(n.bitWidth)
97+
if err != nil {
9098
return ErrMaxDepth
9199
}
92100

93-
idx := hv[depth]
94-
if n.Bitfield.Bit(int(idx)) == 0 {
101+
if n.Bitfield.Bit(idx) == 0 {
95102
return ErrNotFound
96103
}
97104

98-
cindex := byte(n.indexForBitPos(int(idx)))
105+
cindex := byte(n.indexForBitPos(idx))
99106

100107
c := n.getChild(cindex)
101108
if c.isShard() {
102-
chnd, err := c.loadChild(ctx, n.store)
109+
chnd, err := c.loadChild(ctx, n.store, n.bitWidth)
103110
if err != nil {
104111
return err
105112
}
106113

107-
return chnd.getValue(ctx, hv, depth+1, k, cb)
114+
return chnd.getValue(ctx, hv, k, cb)
108115
}
109116

110117
for _, kv := range c.KVs {
@@ -116,12 +123,13 @@ func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb
116123
return ErrNotFound
117124
}
118125

119-
func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore) (*Node, error) {
126+
func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore, bitWidth int) (*Node, error) {
120127
if p.cache != nil {
121128
return p.cache, nil
122129
}
123130

124131
out, err := LoadNode(ctx, ns, p.Link)
132+
out.bitWidth = bitWidth
125133
if err != nil {
126134
return nil, err
127135
}
@@ -137,7 +145,7 @@ func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid, options ...Opti
137145
}
138146

139147
out.store = cs
140-
148+
out.bitWidth = defaultBitWidth
141149
// apply functional options to node before using
142150
for _, option := range options {
143151
option(&out)
@@ -160,7 +168,7 @@ func (n *Node) checkSize(ctx context.Context) (uint64, error) {
160168
totsize := uint64(len(blk.RawData()))
161169
for _, ch := range n.Pointers {
162170
if ch.isShard() {
163-
chnd, err := ch.loadChild(ctx, n.store)
171+
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth)
164172
if err != nil {
165173
return 0, err
166174
}
@@ -212,7 +220,7 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
212220
d = &cbg.Deferred{Raw: b}
213221
}
214222

215-
return n.modifyValue(ctx, hash(k), 0, k, d)
223+
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, d)
216224
}
217225

218226
func (n *Node) cleanChild(chnd *Node, cindex byte) error {
@@ -249,11 +257,11 @@ func (n *Node) cleanChild(chnd *Node, cindex byte) error {
249257
}
250258
}
251259

252-
func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string, v *cbg.Deferred) error {
253-
if depth >= len(hv) {
260+
func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k string, v *cbg.Deferred) error {
261+
idx, err := hv.Next(n.bitWidth)
262+
if err != nil {
254263
return ErrMaxDepth
255264
}
256-
idx := int(hv[depth])
257265

258266
if n.Bitfield.Bit(idx) != 1 {
259267
return n.insertChild(idx, k, v)
@@ -263,12 +271,12 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
263271

264272
child := n.getChild(cindex)
265273
if child.isShard() {
266-
chnd, err := child.loadChild(ctx, n.store)
274+
chnd, err := child.loadChild(ctx, n.store, n.bitWidth)
267275
if err != nil {
268276
return err
269277
}
270278

271-
if err := chnd.modifyValue(ctx, hv, depth+1, k, v); err != nil {
279+
if err := chnd.modifyValue(ctx, hv, k, v); err != nil {
272280
return err
273281
}
274282

@@ -308,12 +316,15 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
308316
// If the array is full, create a subshard and insert everything into it
309317
if len(child.KVs) >= arrayWidth {
310318
sub := NewNode(n.store)
311-
if err := sub.modifyValue(ctx, hv, depth+1, k, v); err != nil {
319+
sub.bitWidth = n.bitWidth
320+
hvcopy := &hashBits{b: hv.b, consumed: hv.consumed}
321+
if err := sub.modifyValue(ctx, hvcopy, k, v); err != nil {
312322
return err
313323
}
314324

315325
for _, p := range child.KVs {
316-
if err := sub.modifyValue(ctx, hash(p.Key), depth+1, p.Key, p.Value); err != nil {
326+
chhv := &hashBits{b: hash(p.Key), consumed: hv.consumed}
327+
if err := sub.modifyValue(ctx, chhv, p.Key, p.Value); err != nil {
317328
return err
318329
}
319330
}
@@ -375,6 +386,7 @@ func (n *Node) getChild(i byte) *Pointer {
375386

376387
func (n *Node) Copy() *Node {
377388
nn := NewNode(n.store)
389+
nn.bitWidth = n.bitWidth
378390
nn.Bitfield.Set(n.Bitfield)
379391
nn.Pointers = make([]*Pointer, len(n.Pointers))
380392

@@ -403,7 +415,7 @@ func (p *Pointer) isShard() bool {
403415
func (n *Node) ForEach(ctx context.Context, f func(k string, val interface{}) error) error {
404416
for _, p := range n.Pointers {
405417
if p.isShard() {
406-
chnd, err := p.loadChild(ctx, n.store)
418+
chnd, err := p.loadChild(ctx, n.store, n.bitWidth)
407419
if err != nil {
408420
return err
409421
}

hamt_bench_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ func BenchmarkFind(b *testing.B) {
5757

5858
}
5959

60-
func doBenchmarkEntriesCount(num int, bitWidth uint8) func(b *testing.B) {
60+
func doBenchmarkEntriesCount(num int, bitWidth int) func(b *testing.B) {
6161
r := rander{rand.New(rand.NewSource(int64(num)))}
6262
return func(b *testing.B) {
6363
cs := NewCborStore()

hamt_test.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,22 @@ func TestCanonicalStructure(t *testing.T) {
5151
defer func() {
5252
hash = murmurHash
5353
}()
54-
addAndRemoveKeys(t, []string{"K"}, []string{"B"})
55-
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
54+
addAndRemoveKeys(t, defaultBitWidth, []string{"K"}, []string{"B"})
55+
addAndRemoveKeys(t, defaultBitWidth, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
5656
}
5757

58+
func TestCanonicalStructureAlternateBitWidth(t *testing.T) {
59+
hash = identityHash
60+
defer func() {
61+
hash = murmurHash
62+
}()
63+
addAndRemoveKeys(t, 7, []string{"K"}, []string{"B"})
64+
addAndRemoveKeys(t, 7, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
65+
addAndRemoveKeys(t, 6, []string{"K"}, []string{"B"})
66+
addAndRemoveKeys(t, 6, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
67+
addAndRemoveKeys(t, 5, []string{"K"}, []string{"B"})
68+
addAndRemoveKeys(t, 5, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
69+
}
5870
func TestOverflow(t *testing.T) {
5971
hash = identityHash
6072
defer func() {
@@ -90,7 +102,7 @@ func TestOverflow(t *testing.T) {
90102
}
91103
}
92104

93-
func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
105+
func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []string) {
94106
ctx := context.Background()
95107
vals := make(map[string][]byte)
96108
for i := 0; i < len(keys); i++ {
@@ -99,7 +111,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
99111
}
100112

101113
cs := NewCborStore()
102-
begn := NewNode(cs)
114+
begn := NewNode(cs, UseTreeBitWidth(bitWidth))
103115
for _, k := range keys {
104116
if err := begn.Set(ctx, k, vals[k]); err != nil {
105117
t.Fatal(err)
@@ -122,7 +134,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
122134
t.Fatal(err)
123135
}
124136
n.store = cs
125-
137+
n.bitWidth = bitWidth
126138
for k, v := range vals {
127139
var out []byte
128140
err := n.Find(ctx, k, &out)
@@ -157,6 +169,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
157169
t.Fatal(err)
158170
}
159171
n2.store = cs
172+
n2.bitWidth = bitWidth
160173
if !nodesEqual(t, cs, &n, &n2) {
161174
t.Fatal("nodes should be equal")
162175
}
@@ -168,7 +181,7 @@ func dotGraphRec(n *Node, name *int) {
168181
if p.isShard() {
169182
*name++
170183
fmt.Printf("\tn%d -> n%d;\n", cur, *name)
171-
nd, err := p.loadChild(context.Background(), n.store)
184+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
172185
if err != nil {
173186
panic(err)
174187
}
@@ -200,7 +213,7 @@ func statsrec(n *Node, st *hamtStats) {
200213
st.totalNodes++
201214
for _, p := range n.Pointers {
202215
if p.isShard() {
203-
nd, err := p.loadChild(context.Background(), n.store)
216+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
204217
if err != nil {
205218
panic(err)
206219
}
@@ -305,7 +318,7 @@ func TestSetGet(t *testing.T) {
305318
t.Fatal(err)
306319
}
307320
n.store = cs
308-
321+
n.bitWidth = defaultBitWidth
309322
bef = time.Now()
310323
//for k, v := range vals {
311324
for _, k := range keys {

0 commit comments

Comments
 (0)