Skip to content

Commit 85b7f1d

Browse files
authored
Merge pull request #32 from ipfs/feat/configure-hamt-tree-width-31
Configurable HAMT tree bitwidth
2 parents ee6e898 + c51764d commit 85b7f1d

File tree

5 files changed

+209
-45
lines changed

5 files changed

+209
-45
lines changed

hamt.go

Lines changed: 58 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -8,27 +8,49 @@ import (
88

99
cid "github.com/ipfs/go-cid"
1010
cbor "github.com/ipfs/go-ipld-cbor"
11-
murmur3 "github.com/spaolacci/murmur3"
1211
cbg "github.com/whyrusleeping/cbor-gen"
1312
xerrors "golang.org/x/xerrors"
1413
)
1514

1615
const arrayWidth = 3
16+
const defaultBitWidth = 8
1717

1818
type Node struct {
1919
Bitfield *big.Int `refmt:"bf"`
2020
Pointers []*Pointer `refmt:"p"`
2121

2222
// for fetching and storing children
23-
store *CborIpldStore
23+
store *CborIpldStore
24+
bitWidth int
2425
}
2526

26-
func NewNode(cs *CborIpldStore) *Node {
27-
return &Node{
27+
// Option is a function that configures the node
28+
type Option func(*Node)
29+
30+
// UseTreeBitWidth allows you to set the width of the HAMT tree
31+
// in bits (from 1-8) via a customized hash function; values outside that range are ignored and the node keeps its current bit width
32+
func UseTreeBitWidth(bitWidth int) Option {
33+
return func(nd *Node) {
34+
if bitWidth > 0 && bitWidth <= 8 {
35+
nd.bitWidth = bitWidth
36+
}
37+
}
38+
}
39+
40+
// NewNode creates a new IPLD HAMT Node with the given store and given
41+
// options
42+
func NewNode(cs *CborIpldStore, options ...Option) *Node {
43+
nd := &Node{
2844
Bitfield: big.NewInt(0),
2945
Pointers: make([]*Pointer, 0),
3046
store: cs,
47+
bitWidth: defaultBitWidth,
48+
}
49+
// apply functional options to node before using
50+
for _, option := range options {
51+
option(nd)
3152
}
53+
return nd
3254
}
3355

3456
type KV struct {
@@ -44,14 +66,8 @@ type Pointer struct {
4466
cache *Node
4567
}
4668

47-
var hash = func(k string) []byte {
48-
h := murmur3.New128()
49-
h.Write([]byte(k))
50-
return h.Sum(nil)
51-
}
52-
5369
func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
54-
return n.getValue(ctx, hash(k), 0, k, func(kv *KV) error {
70+
return n.getValue(ctx, &hashBits{b: hash(k)}, k, func(kv *KV) error {
5571
// used to just see if the thing exists in the set
5672
if out == nil {
5773
return nil
@@ -70,32 +86,32 @@ func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
7086
}
7187

7288
func (n *Node) Delete(ctx context.Context, k string) error {
73-
return n.modifyValue(ctx, hash(k), 0, k, nil)
89+
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, nil)
7490
}
7591

7692
var ErrNotFound = fmt.Errorf("not found")
7793
var ErrMaxDepth = fmt.Errorf("attempted to traverse hamt beyond max depth")
7894

79-
func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb func(*KV) error) error {
80-
if depth >= len(hv) {
95+
func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV) error) error {
96+
idx, err := hv.Next(n.bitWidth)
97+
if err != nil {
8198
return ErrMaxDepth
8299
}
83100

84-
idx := hv[depth]
85-
if n.Bitfield.Bit(int(idx)) == 0 {
101+
if n.Bitfield.Bit(idx) == 0 {
86102
return ErrNotFound
87103
}
88104

89-
cindex := byte(n.indexForBitPos(int(idx)))
105+
cindex := byte(n.indexForBitPos(idx))
90106

91107
c := n.getChild(cindex)
92108
if c.isShard() {
93-
chnd, err := c.loadChild(ctx, n.store)
109+
chnd, err := c.loadChild(ctx, n.store, n.bitWidth)
94110
if err != nil {
95111
return err
96112
}
97113

98-
return chnd.getValue(ctx, hv, depth+1, k, cb)
114+
return chnd.getValue(ctx, hv, k, cb)
99115
}
100116

101117
for _, kv := range c.KVs {
@@ -107,12 +123,13 @@ func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb
107123
return ErrNotFound
108124
}
109125

110-
func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore) (*Node, error) {
126+
func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore, bitWidth int) (*Node, error) {
111127
if p.cache != nil {
112128
return p.cache, nil
113129
}
114130

115131
out, err := LoadNode(ctx, ns, p.Link)
116132
if err != nil {
117133
return nil, err
118134
}
135+
out.bitWidth = bitWidth // set only after the error check: LoadNode returns a nil node on failure
@@ -121,13 +138,19 @@ func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore) (*Node, erro
121138
return out, nil
122139
}
123140

124-
func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid) (*Node, error) {
141+
func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid, options ...Option) (*Node, error) {
125142
var out Node
126143
if err := cs.Get(ctx, c, &out); err != nil {
127144
return nil, err
128145
}
129146

130147
out.store = cs
148+
out.bitWidth = defaultBitWidth
149+
// apply functional options to node before using
150+
for _, option := range options {
151+
option(&out)
152+
}
153+
131154
return &out, nil
132155
}
133156

@@ -145,7 +168,7 @@ func (n *Node) checkSize(ctx context.Context) (uint64, error) {
145168
totsize := uint64(len(blk.RawData()))
146169
for _, ch := range n.Pointers {
147170
if ch.isShard() {
148-
chnd, err := ch.loadChild(ctx, n.store)
171+
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth)
149172
if err != nil {
150173
return 0, err
151174
}
@@ -197,7 +220,7 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
197220
d = &cbg.Deferred{Raw: b}
198221
}
199222

200-
return n.modifyValue(ctx, hash(k), 0, k, d)
223+
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, d)
201224
}
202225

203226
func (n *Node) cleanChild(chnd *Node, cindex byte) error {
@@ -234,11 +257,11 @@ func (n *Node) cleanChild(chnd *Node, cindex byte) error {
234257
}
235258
}
236259

237-
func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string, v *cbg.Deferred) error {
238-
if depth >= len(hv) {
260+
func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k string, v *cbg.Deferred) error {
261+
idx, err := hv.Next(n.bitWidth)
262+
if err != nil {
239263
return ErrMaxDepth
240264
}
241-
idx := int(hv[depth])
242265

243266
if n.Bitfield.Bit(idx) != 1 {
244267
return n.insertChild(idx, k, v)
@@ -248,12 +271,12 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
248271

249272
child := n.getChild(cindex)
250273
if child.isShard() {
251-
chnd, err := child.loadChild(ctx, n.store)
274+
chnd, err := child.loadChild(ctx, n.store, n.bitWidth)
252275
if err != nil {
253276
return err
254277
}
255278

256-
if err := chnd.modifyValue(ctx, hv, depth+1, k, v); err != nil {
279+
if err := chnd.modifyValue(ctx, hv, k, v); err != nil {
257280
return err
258281
}
259282

@@ -293,12 +316,15 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
293316
// If the array is full, create a subshard and insert everything into it
294317
if len(child.KVs) >= arrayWidth {
295318
sub := NewNode(n.store)
296-
if err := sub.modifyValue(ctx, hv, depth+1, k, v); err != nil {
319+
sub.bitWidth = n.bitWidth
320+
hvcopy := &hashBits{b: hv.b, consumed: hv.consumed}
321+
if err := sub.modifyValue(ctx, hvcopy, k, v); err != nil {
297322
return err
298323
}
299324

300325
for _, p := range child.KVs {
301-
if err := sub.modifyValue(ctx, hash(p.Key), depth+1, p.Key, p.Value); err != nil {
326+
chhv := &hashBits{b: hash(p.Key), consumed: hv.consumed}
327+
if err := sub.modifyValue(ctx, chhv, p.Key, p.Value); err != nil {
302328
return err
303329
}
304330
}
@@ -360,6 +386,7 @@ func (n *Node) getChild(i byte) *Pointer {
360386

361387
func (n *Node) Copy() *Node {
362388
nn := NewNode(n.store)
389+
nn.bitWidth = n.bitWidth
363390
nn.Bitfield.Set(n.Bitfield)
364391
nn.Pointers = make([]*Pointer, len(n.Pointers))
365392

@@ -388,7 +415,7 @@ func (p *Pointer) isShard() bool {
388415
func (n *Node) ForEach(ctx context.Context, f func(k string, val interface{}) error) error {
389416
for _, p := range n.Pointers {
390417
if p.isShard() {
391-
chnd, err := p.loadChild(ctx, n.store)
418+
chnd, err := p.loadChild(ctx, n.store, n.bitWidth)
392419
if err != nil {
393420
return err
394421
}

hamt_bench_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,20 @@ func BenchmarkSerializeNode(b *testing.B) {
4848
}
4949

5050
func BenchmarkFind(b *testing.B) {
51-
b.Run("find-10k", doBenchmarkEntriesCount(10000))
52-
b.Run("find-100k", doBenchmarkEntriesCount(100000))
53-
b.Run("find-1m", doBenchmarkEntriesCount(1000000))
51+
b.Run("find-10k", doBenchmarkEntriesCount(10000, 8))
52+
b.Run("find-100k", doBenchmarkEntriesCount(100000, 8))
53+
b.Run("find-1m", doBenchmarkEntriesCount(1000000, 8))
54+
b.Run("find-10k-bitwidth-5", doBenchmarkEntriesCount(10000, 5))
55+
b.Run("find-100k-bitwidth-5", doBenchmarkEntriesCount(100000, 5))
56+
b.Run("find-1m-bitwidth-5", doBenchmarkEntriesCount(1000000, 5))
57+
5458
}
5559

56-
func doBenchmarkEntriesCount(num int) func(b *testing.B) {
60+
func doBenchmarkEntriesCount(num int, bitWidth int) func(b *testing.B) {
5761
r := rander{rand.New(rand.NewSource(int64(num)))}
5862
return func(b *testing.B) {
5963
cs := NewCborStore()
60-
n := NewNode(cs)
64+
n := NewNode(cs, UseTreeBitWidth(bitWidth))
6165

6266
var keys []string
6367
for i := 0; i < num; i++ {
@@ -82,7 +86,7 @@ func doBenchmarkEntriesCount(num int) func(b *testing.B) {
8286
b.ReportAllocs()
8387

8488
for i := 0; i < b.N; i++ {
85-
nd, err := LoadNode(context.TODO(), cs, c)
89+
nd, err := LoadNode(context.TODO(), cs, c, UseTreeBitWidth(bitWidth))
8690
if err != nil {
8791
b.Fatal(err)
8892
}

hamt_test.go

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,22 @@ func TestCanonicalStructure(t *testing.T) {
5151
defer func() {
5252
hash = murmurHash
5353
}()
54-
addAndRemoveKeys(t, []string{"K"}, []string{"B"})
55-
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
54+
addAndRemoveKeys(t, defaultBitWidth, []string{"K"}, []string{"B"})
55+
addAndRemoveKeys(t, defaultBitWidth, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
5656
}
5757

58+
func TestCanonicalStructureAlternateBitWidth(t *testing.T) {
59+
hash = identityHash
60+
defer func() {
61+
hash = murmurHash
62+
}()
63+
addAndRemoveKeys(t, 7, []string{"K"}, []string{"B"})
64+
addAndRemoveKeys(t, 7, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
65+
addAndRemoveKeys(t, 6, []string{"K"}, []string{"B"})
66+
addAndRemoveKeys(t, 6, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
67+
addAndRemoveKeys(t, 5, []string{"K"}, []string{"B"})
68+
addAndRemoveKeys(t, 5, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
69+
}
5870
func TestOverflow(t *testing.T) {
5971
hash = identityHash
6072
defer func() {
@@ -90,7 +102,7 @@ func TestOverflow(t *testing.T) {
90102
}
91103
}
92104

93-
func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
105+
func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []string) {
94106
ctx := context.Background()
95107
vals := make(map[string][]byte)
96108
for i := 0; i < len(keys); i++ {
@@ -99,7 +111,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
99111
}
100112

101113
cs := NewCborStore()
102-
begn := NewNode(cs)
114+
begn := NewNode(cs, UseTreeBitWidth(bitWidth))
103115
for _, k := range keys {
104116
if err := begn.Set(ctx, k, vals[k]); err != nil {
105117
t.Fatal(err)
@@ -122,7 +134,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
122134
t.Fatal(err)
123135
}
124136
n.store = cs
125-
137+
n.bitWidth = bitWidth
126138
for k, v := range vals {
127139
var out []byte
128140
err := n.Find(ctx, k, &out)
@@ -157,6 +169,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
157169
t.Fatal(err)
158170
}
159171
n2.store = cs
172+
n2.bitWidth = bitWidth
160173
if !nodesEqual(t, cs, &n, &n2) {
161174
t.Fatal("nodes should be equal")
162175
}
@@ -168,7 +181,7 @@ func dotGraphRec(n *Node, name *int) {
168181
if p.isShard() {
169182
*name++
170183
fmt.Printf("\tn%d -> n%d;\n", cur, *name)
171-
nd, err := p.loadChild(context.Background(), n.store)
184+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
172185
if err != nil {
173186
panic(err)
174187
}
@@ -200,7 +213,7 @@ func statsrec(n *Node, st *hamtStats) {
200213
st.totalNodes++
201214
for _, p := range n.Pointers {
202215
if p.isShard() {
203-
nd, err := p.loadChild(context.Background(), n.store)
216+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
204217
if err != nil {
205218
panic(err)
206219
}
@@ -305,7 +318,7 @@ func TestSetGet(t *testing.T) {
305318
t.Fatal(err)
306319
}
307320
n.store = cs
308-
321+
n.bitWidth = defaultBitWidth
309322
bef = time.Now()
310323
//for k, v := range vals {
311324
for _, k := range keys {

0 commit comments

Comments
 (0)