Skip to content

Commit bc91f4a

Browse files
authored
Merge pull request #49 from ipfs/fix/custom-hash
feat: allow custom hash functions
2 parents 0310ad2 + 4d3002f commit bc91f4a

File tree

4 files changed

+62
-52
lines changed

4 files changed

+62
-52
lines changed

go.sum

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ github.com/ipfs/go-cid v0.0.1/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUP
1111
github.com/ipfs/go-cid v0.0.2/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUPQvM=
1212
github.com/ipfs/go-cid v0.0.3 h1:UIAh32wymBpStoe83YCzwVQQ5Oy/H0FdxvUS6DJDzms=
1313
github.com/ipfs/go-cid v0.0.3/go.mod h1:GHWU/WuQdMPmIosc4Yn1bcCT7dSeX4lBafM7iqUPQvM=
14-
github.com/ipfs/go-cid v0.0.4 h1:UlfXKrZx1DjZoBhQHmNHLC1fK1dUJDN20Y28A7s+gJ8=
15-
github.com/ipfs/go-cid v0.0.4/go.mod h1:4LLaPOQwmk5z9LBgQnpkivrx8BJjUyGwTXCd5Xfj6+M=
1614
github.com/ipfs/go-cid v0.0.6-0.20200501230655-7c82f3b81c00 h1:QN88Q0kT2QiDaLxpR/SDsqOBtNIEF/F3n96gSDUimkA=
1715
github.com/ipfs/go-cid v0.0.6-0.20200501230655-7c82f3b81c00/go.mod h1:plgt+Y5MnOey4vO4UlUazGqdbEXuFYitED67FexhXog=
1816
github.com/ipfs/go-ipfs-util v0.0.1 h1:Wz9bL2wB2YBJqggkA4dD7oSmqB4cAnpNbGrlHJulv50=
@@ -61,10 +59,6 @@ github.com/warpfork/go-wish v0.0.0-20180510122957-5ad1f5abf436 h1:qOpVTI+BrstcjT
6159
github.com/warpfork/go-wish v0.0.0-20180510122957-5ad1f5abf436/go.mod h1:x6AKhvSSexNrVSrViXSHUEbICjmGXhtgABaHIySUSGw=
6260
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158 h1:WXhVOwj2USAXB5oMDwRl3piOux2XMV9TANaYxXHdkoE=
6361
github.com/whyrusleeping/cbor-gen v0.0.0-20200123233031-1cdf64d27158/go.mod h1:Xj/M2wWU+QdTdRbu/L/1dIZY8/Wb2K9pAhtroQuxJJI=
64-
github.com/whyrusleeping/cbor-gen v0.0.0-20200414195334-429a0b5e922e h1:JY8o/ebUUrCYetWmjRCNghxC59cOEaili83rxPRQCLw=
65-
github.com/whyrusleeping/cbor-gen v0.0.0-20200414195334-429a0b5e922e/go.mod h1:Xj/M2wWU+QdTdRbu/L/1dIZY8/Wb2K9pAhtroQuxJJI=
66-
github.com/whyrusleeping/cbor-gen v0.0.0-20200501014322-5f9941ef88e0 h1:dmdwCOVtJAm7qwONARangN4jgCisVFmSJ486JZ1LYaA=
67-
github.com/whyrusleeping/cbor-gen v0.0.0-20200501014322-5f9941ef88e0/go.mod h1:Xj/M2wWU+QdTdRbu/L/1dIZY8/Wb2K9pAhtroQuxJJI=
6862
github.com/whyrusleeping/cbor-gen v0.0.0-20200504204219-64967432584d h1:Y25auOnuZb/GuJvqMflRSDWBz8/HBRME8fiD+H8zLfs=
6963
github.com/whyrusleeping/cbor-gen v0.0.0-20200504204219-64967432584d/go.mod h1:W5MvapuoHRP8rz4vxjwCK1pDqF1aQcWsV5PZ+AHbqdg=
7064
golang.org/x/crypto v0.0.0-20190211182817-74369b46fc67 h1:ng3VDlRp5/DHpSWl02R4rM9I+8M2rhmsuLwAMmkLQWE=

hamt.go

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ type Node struct {
2020
Pointers []*Pointer `refmt:"p"`
2121

2222
bitWidth int
23+
hash func([]byte) []byte
2324

2425
// for fetching and storing children
2526
store cbor.IpldStore
@@ -38,13 +39,23 @@ func UseTreeBitWidth(bitWidth int) Option {
3839
}
3940
}
4041

42+
// UseHashFunction allows you to set the hash function used by the HAMT. It
43+
// defaults to murmur3 but you should use sha256 when an attacker can pick the
44+
// keys.
45+
func UseHashFunction(hash func([]byte) []byte) Option {
46+
return func(nd *Node) {
47+
nd.hash = hash
48+
}
49+
}
50+
4151
// NewNode creates a new IPLD HAMT Node with the given store and given
4252
// options
4353
func NewNode(cs cbor.IpldStore, options ...Option) *Node {
4454
nd := &Node{
4555
Bitfield: big.NewInt(0),
4656
Pointers: make([]*Pointer, 0),
4757
store: cs,
58+
hash: defaultHashFunction,
4859
bitWidth: defaultBitWidth,
4960
}
5061
// apply functional options to node before using
@@ -68,7 +79,7 @@ type Pointer struct {
6879
}
6980

7081
func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
71-
return n.getValue(ctx, &hashBits{b: hash([]byte(k))}, k, func(kv *KV) error {
82+
return n.getValue(ctx, &hashBits{b: n.hash([]byte(k))}, k, func(kv *KV) error {
7283
// used to just see if the thing exists in the set
7384
if out == nil {
7485
return nil
@@ -88,7 +99,7 @@ func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
8899

89100
func (n *Node) FindRaw(ctx context.Context, k string) ([]byte, error) {
90101
var ret []byte
91-
err := n.getValue(ctx, &hashBits{b: hash([]byte(k))}, k, func(kv *KV) error {
102+
err := n.getValue(ctx, &hashBits{b: n.hash([]byte(k))}, k, func(kv *KV) error {
92103
ret = kv.Value.Raw
93104
return nil
94105
})
@@ -97,7 +108,7 @@ func (n *Node) FindRaw(ctx context.Context, k string) ([]byte, error) {
97108

98109
func (n *Node) Delete(ctx context.Context, k string) error {
99110
kb := []byte(k)
100-
return n.modifyValue(ctx, &hashBits{b: hash(kb)}, kb, nil)
111+
return n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, nil)
101112
}
102113

103114
var ErrNotFound = fmt.Errorf("not found")
@@ -117,7 +128,7 @@ func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV
117128

118129
c := n.getChild(cindex)
119130
if c.isShard() {
120-
chnd, err := c.loadChild(ctx, n.store, n.bitWidth)
131+
chnd, err := c.loadChild(ctx, n.store, n.bitWidth, n.hash)
121132
if err != nil {
122133
return err
123134
}
@@ -134,7 +145,7 @@ func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV
134145
return ErrNotFound
135146
}
136147

137-
func (p *Pointer) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth int) (*Node, error) {
148+
func (p *Pointer) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth int, hash func([]byte) []byte) (*Node, error) {
138149
if p.cache != nil {
139150
return p.cache, nil
140151
}
@@ -144,6 +155,7 @@ func (p *Pointer) loadChild(ctx context.Context, ns cbor.IpldStore, bitWidth int
144155
return nil, err
145156
}
146157
out.bitWidth = bitWidth
158+
out.hash = hash
147159

148160
p.cache = out
149161
return out, nil
@@ -157,6 +169,7 @@ func LoadNode(ctx context.Context, cs cbor.IpldStore, c cid.Cid, options ...Opti
157169

158170
out.store = cs
159171
out.bitWidth = defaultBitWidth
172+
out.hash = defaultHashFunction
160173
// apply functional options to node before using
161174
for _, option := range options {
162175
option(&out)
@@ -179,7 +192,7 @@ func (n *Node) checkSize(ctx context.Context) (uint64, error) {
179192
totsize := uint64(len(def.Raw))
180193
for _, ch := range n.Pointers {
181194
if ch.isShard() {
182-
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth)
195+
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth, n.hash)
183196
if err != nil {
184197
return 0, err
185198
}
@@ -217,7 +230,7 @@ func (n *Node) Flush(ctx context.Context) error {
217230
func (n *Node) SetRaw(ctx context.Context, k string, raw []byte) error {
218231
d := &cbg.Deferred{Raw: raw}
219232
kb := []byte(k)
220-
return n.modifyValue(ctx, &hashBits{b: hash(kb)}, kb, d)
233+
return n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, d)
221234
}
222235

223236
func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
@@ -240,7 +253,7 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
240253
d = &cbg.Deferred{Raw: b}
241254
}
242255

243-
return n.modifyValue(ctx, &hashBits{b: hash(kb)}, kb, d)
256+
return n.modifyValue(ctx, &hashBits{b: n.hash(kb)}, kb, d)
244257
}
245258

246259
func (n *Node) cleanChild(chnd *Node, cindex byte) error {
@@ -291,7 +304,7 @@ func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k []byte, v *cbg.D
291304

292305
child := n.getChild(cindex)
293306
if child.isShard() {
294-
chnd, err := child.loadChild(ctx, n.store, n.bitWidth)
307+
chnd, err := child.loadChild(ctx, n.store, n.bitWidth, n.hash)
295308
if err != nil {
296309
return err
297310
}
@@ -337,13 +350,14 @@ func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k []byte, v *cbg.D
337350
if len(child.KVs) >= arrayWidth {
338351
sub := NewNode(n.store)
339352
sub.bitWidth = n.bitWidth
353+
sub.hash = n.hash
340354
hvcopy := &hashBits{b: hv.b, consumed: hv.consumed}
341355
if err := sub.modifyValue(ctx, hvcopy, k, v); err != nil {
342356
return err
343357
}
344358

345359
for _, p := range child.KVs {
346-
chhv := &hashBits{b: hash(p.Key), consumed: hv.consumed}
360+
chhv := &hashBits{b: n.hash([]byte(p.Key)), consumed: hv.consumed}
347361
if err := sub.modifyValue(ctx, chhv, p.Key, p.Value); err != nil {
348362
return err
349363
}
@@ -407,6 +421,7 @@ func (n *Node) getChild(i byte) *Pointer {
407421
func (n *Node) Copy() *Node {
408422
nn := NewNode(n.store)
409423
nn.bitWidth = n.bitWidth
424+
nn.hash = n.hash
410425
nn.Bitfield.Set(n.Bitfield)
411426
nn.Pointers = make([]*Pointer, len(n.Pointers))
412427

@@ -435,7 +450,7 @@ func (p *Pointer) isShard() bool {
435450
func (n *Node) ForEach(ctx context.Context, f func(k string, val interface{}) error) error {
436451
for _, p := range n.Pointers {
437452
if p.isShard() {
438-
chnd, err := p.loadChild(ctx, n.store, n.bitWidth)
453+
chnd, err := p.loadChild(ctx, n.store, n.bitWidth, n.hash)
439454
if err != nil {
440455
return err
441456
}

hamt_test.go

Lines changed: 35 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package hamt
33
import (
44
"bytes"
55
"context"
6+
"crypto/sha256"
67
"encoding/hex"
78
"fmt"
89
"math/rand"
@@ -68,41 +69,27 @@ var shortIdentityHash = func(k []byte) []byte {
6869
return res
6970
}
7071

71-
var murmurHash = hash
72-
7372
func TestCanonicalStructure(t *testing.T) {
74-
hash = identityHash
75-
defer func() {
76-
hash = murmurHash
77-
}()
78-
addAndRemoveKeys(t, defaultBitWidth, []string{"K"}, []string{"B"})
79-
addAndRemoveKeys(t, defaultBitWidth, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
73+
addAndRemoveKeys(t, []string{"K"}, []string{"B"}, UseHashFunction(identityHash))
74+
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
8075
}
8176

8277
func TestCanonicalStructureAlternateBitWidth(t *testing.T) {
83-
hash = identityHash
84-
defer func() {
85-
hash = murmurHash
86-
}()
87-
addAndRemoveKeys(t, 7, []string{"K"}, []string{"B"})
88-
addAndRemoveKeys(t, 7, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
89-
addAndRemoveKeys(t, 6, []string{"K"}, []string{"B"})
90-
addAndRemoveKeys(t, 6, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
91-
addAndRemoveKeys(t, 5, []string{"K"}, []string{"B"})
92-
addAndRemoveKeys(t, 5, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
78+
addAndRemoveKeys(t, []string{"K"}, []string{"B"}, UseTreeBitWidth(7), UseHashFunction(identityHash))
79+
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"}, UseTreeBitWidth(7), UseHashFunction(identityHash))
80+
addAndRemoveKeys(t, []string{"K"}, []string{"B"}, UseTreeBitWidth(6), UseHashFunction(identityHash))
81+
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"}, UseTreeBitWidth(6), UseHashFunction(identityHash))
82+
addAndRemoveKeys(t, []string{"K"}, []string{"B"}, UseTreeBitWidth(5), UseHashFunction(identityHash))
83+
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"}, UseTreeBitWidth(5), UseHashFunction(identityHash))
9384
}
9485
func TestOverflow(t *testing.T) {
95-
hash = identityHash
96-
defer func() {
97-
hash = murmurHash
98-
}()
9986
keys := make([]string, 4)
10087
for i := range keys {
10188
keys[i] = strings.Repeat("A", 32) + fmt.Sprintf("%d", i)
10289
}
10390

10491
cs := cbor.NewCborStore(newMockBlocks())
105-
n := NewNode(cs)
92+
n := NewNode(cs, UseHashFunction(identityHash))
10693
for _, k := range keys[:3] {
10794
if err := n.Set(context.Background(), k, "foobar"); err != nil {
10895
t.Error(err)
@@ -120,13 +107,13 @@ func TestOverflow(t *testing.T) {
120107
}
121108

122109
// Now, try fetching with a shorter hash function.
123-
hash = shortIdentityHash
110+
n.hash = shortIdentityHash
124111
if err := n.Find(context.Background(), keys[0], nil); err != ErrMaxDepth {
125112
t.Errorf("expected error %q, got %q", ErrMaxDepth, err)
126113
}
127114
}
128115

129-
func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []string) {
116+
func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string, options ...Option) {
130117
ctx := context.Background()
131118
vals := make(map[string][]byte)
132119
for i := 0; i < len(keys); i++ {
@@ -135,7 +122,7 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
135122
}
136123

137124
cs := cbor.NewCborStore(newMockBlocks())
138-
begn := NewNode(cs, UseTreeBitWidth(bitWidth))
125+
begn := NewNode(cs, options...)
139126
for _, k := range keys {
140127
if err := begn.Set(ctx, k, vals[k]); err != nil {
141128
t.Fatal(err)
@@ -158,7 +145,8 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
158145
t.Fatal(err)
159146
}
160147
n.store = cs
161-
n.bitWidth = bitWidth
148+
n.hash = begn.hash
149+
n.bitWidth = begn.bitWidth
162150
for k, v := range vals {
163151
var out []byte
164152
err := n.Find(ctx, k, &out)
@@ -193,7 +181,8 @@ func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []str
193181
t.Fatal(err)
194182
}
195183
n2.store = cs
196-
n2.bitWidth = bitWidth
184+
n2.hash = begn.hash
185+
n2.bitWidth = begn.bitWidth
197186
if !nodesEqual(t, cs, &n, &n2) {
198187
t.Fatal("nodes should be equal")
199188
}
@@ -205,7 +194,7 @@ func dotGraphRec(n *Node, name *int) {
205194
if p.isShard() {
206195
*name++
207196
fmt.Printf("\tn%d -> n%d;\n", cur, *name)
208-
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
197+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth, n.hash)
209198
if err != nil {
210199
panic(err)
211200
}
@@ -237,7 +226,7 @@ func statsrec(n *Node, st *hamtStats) {
237226
st.totalNodes++
238227
for _, p := range n.Pointers {
239228
if p.isShard() {
240-
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
229+
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth, n.hash)
241230
if err != nil {
242231
panic(err)
243232
}
@@ -251,17 +240,28 @@ func statsrec(n *Node, st *hamtStats) {
251240
}
252241

253242
func TestHash(t *testing.T) {
254-
h1 := hash([]byte("abcd"))
255-
h2 := hash([]byte("abce"))
243+
h1 := defaultHashFunction([]byte("abcd"))
244+
h2 := defaultHashFunction([]byte("abce"))
256245
if h1[0] == h2[0] && h1[1] == h2[1] && h1[3] == h2[3] {
257246
t.Fatal("Hash should give different strings different hash prefixes")
258247
}
259248
}
260249

261250
func TestBasic(t *testing.T) {
251+
testBasic(t)
252+
}
253+
254+
func TestSha256(t *testing.T) {
255+
testBasic(t, UseHashFunction(func(in []byte) []byte {
256+
out := sha256.Sum256(in)
257+
return out[:]
258+
}))
259+
}
260+
261+
func testBasic(t *testing.T, options ...Option) {
262262
ctx := context.Background()
263263
cs := cbor.NewCborStore(newMockBlocks())
264-
begn := NewNode(cs)
264+
begn := NewNode(cs, options...)
265265

266266
val := []byte("cat dog bear")
267267
if err := begn.Set(ctx, "foo", val); err != nil {
@@ -282,7 +282,7 @@ func TestBasic(t *testing.T) {
282282
t.Fatal(err)
283283
}
284284

285-
n, err := LoadNode(ctx, cs, c)
285+
n, err := LoadNode(ctx, cs, c, options...)
286286
if err != nil {
287287
t.Fatal(err)
288288
}
@@ -381,6 +381,7 @@ func TestSetGet(t *testing.T) {
381381
t.Fatal(err)
382382
}
383383
n.store = cs
384+
n.hash = defaultHashFunction
384385
n.bitWidth = defaultBitWidth
385386
bef = time.Now()
386387
//for k, v := range vals {

hash.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (hb *hashBits) next(i int) int {
4949
}
5050
}
5151

52-
var hash = func(val []byte) []byte {
52+
func defaultHashFunction(val []byte) []byte {
5353
h := murmur3.New64()
5454
h.Write(val)
5555
return h.Sum(nil)

0 commit comments

Comments
 (0)