Skip to content
This repository was archived by the owner on May 23, 2023. It is now read-only.

Commit dd18d33

Browse files
committed
Added benchmarks for binary trie
1 parent 243162b commit dd18d33

File tree

1 file changed

+287
-0
lines changed

1 file changed

+287
-0
lines changed

ethereum/tests/bintrie.py

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
# All nodes are of the form [path1, child1, path2, child2]
2+
# or <value>
3+
4+
from ethereum import utils
5+
from ethereum.db import EphemDB, ListeningDB
6+
import rlp, sys
7+
import copy
8+
9+
# Hash function applied to RLP-encoded nodes/values to derive database keys
hashfunc = utils.sha3
10+
11+
# Length of a node reference: dbget treats anything whose length equals this
# as a database key (presumably the byte length of a hashfunc digest -- confirm)
HASHLEN = 32
12+
13+
14+
# 0100000101010111010000110100100101001001 -> ASCII
15+
def decode_bin(x):
    """Convert a string of '0'/'1' characters into a character string.

    The input is consumed eight characters at a time, most significant bit
    first, and every chunk becomes the character with that code point.  A
    trailing chunk shorter than eight characters is parsed as the (smaller)
    number it spells.
    """
    chars = []
    for start in range(0, len(x), 8):
        octet = x[start:start + 8]
        chars.append(chr(int(octet, 2)))
    return ''.join(chars)
17+
18+
19+
# ASCII -> 0100000101010111010000110100100101001001
20+
def encode_bin(x):
    """Convert a byte/character string into a string of '0'/'1' characters.

    Every element of ``x`` contributes exactly eight characters, most
    significant bit first.  Elements may be one-character strings (what
    iterating a ``str`` yields) or integers (what iterating ``bytes`` yields
    on Python 3).  Only the low eight bits of each code point are emitted,
    as in the original bit-by-bit loop.

    Unlike the original, this does not depend on ``/`` being integer
    division, so it gives the same answer under Python 3 or
    ``from __future__ import division``.
    """
    bits = []
    for c in x:
        code = c if isinstance(c, int) else ord(c)
        # Mask to one byte so each element yields exactly eight digits
        bits.append(format(code & 0xff, '08b'))
    return ''.join(bits)
30+
31+
32+
# Encodes a binary list [0,1,0,1,1,0] of any length into bytes
33+
def encode_bin_path(li):
    """Pack a list of bits (0/1 values) of any length into a string.

    The empty list (compared with ``== []``) encodes to ''.  Otherwise the
    bits are left-padded with zeros to a multiple of four and prefixed with
    a two-bit field holding ``len(li) % 4``.  A header is put in front so the
    total is a whole number of bytes: '00' when the padded bits occupy an
    odd number of nibbles, '100000' otherwise.  The resulting bit string is
    converted with decode_bin.
    """
    if li == []:
        return ''
    bits = ''.join(str(bit) for bit in li)
    padded = '0' * ((4 - len(bits)) % 4) + bits
    length_flag = ('00', '01', '10', '11')[len(bits) % 4]
    # Odd nibble count: header + flag complete the first byte.
    # Even nibble count: header + flag form a whole extra leading byte.
    header = '00' if len(padded) % 8 == 4 else '100000'
    return decode_bin(header + length_flag + padded)
43+
44+
45+
# Decodes bytes into a binary list
46+
def decode_bin_path(p):
    """Unpack a string produced by encode_bin_path into a list of 0/1 ints.

    '' decodes to [].  Otherwise the string is expanded to bits with
    encode_bin; if the first bit is '1' the leading four bits are skipped,
    the next two bits must be '00' (checked with ``assert``), the following
    two bits give the path length modulo four, and that many-complement
    padding bits are dropped before the remaining bits are returned.
    """
    if p == '':
        return []
    bits = encode_bin(p)
    if bits[0] == '1':
        # Full-byte header: discard its first nibble
        bits = bits[4:]
    assert bits[0:2] == '00'
    length_mod4 = int(bits[2:4], 2)
    padding = (4 - length_mod4) % 4
    return [int(ch == '1') for ch in bits[4 + padding:]]
56+
57+
58+
# Get a node from a database if needed
59+
def dbget(node, db):
    """Resolve ``node`` to its in-memory form.

    A node whose length equals HASHLEN is treated as a database key: the
    stored bytes are fetched with ``db.get`` and RLP-decoded.  Any other
    node is returned unchanged.
    """
    is_reference = len(node) == HASHLEN
    if not is_reference:
        return node
    return rlp.decode(db.get(node))
63+
64+
65+
# Place a node into a database if needed
66+
def dbput(node, db):
    """Store ``node`` in ``db`` when its encoding calls for it.

    The node is RLP-encoded.  If the encoding is exactly HASHLEN bytes long
    or longer than twice HASHLEN, it is written to ``db`` under its hash and
    the hash is returned; otherwise nothing is written and ``node`` itself
    is returned so it can be embedded in its parent.
    """
    encoded = rlp.encode(node)
    size = len(encoded)
    if size != HASHLEN and size <= HASHLEN * 2:
        return node
    digest = hashfunc(encoded)
    db.put(digest, encoded)
    return digest
73+
74+
75+
# Get a value from a tree
76+
def get(node, db, key):
    """Look up ``key`` (a list of 0/1 ints) in the trie rooted at ``node``.

    Returns the stored value, or '' when the walk reaches a node with fewer
    than two entries before the key is used up, or when the key disagrees
    with the path stored on the selected child.
    """
    while True:
        node = dbget(node, db)
        if key == []:
            # Key exhausted: the current node holds the value
            return node[0]
        if len(node) < 2:
            return ''
        child = dbget(node[key[0]], db)
        if len(child) == 2:
            raw_path, next_node = child
        else:
            # Contracted child [v]: empty path
            raw_path, next_node = '', child[0]
        path = decode_bin_path(raw_path)
        consumed = len(path) + 1
        if key[1:consumed] != path:
            return ''
        # Descend: drop the branch bit and the matched path bits
        node, key = next_node, key[consumed:]
92+
93+
94+
# Get length of shared prefix of inputs
95+
def get_shared_length(l1, l2):
    """Return how many leading elements ``l1`` and ``l2`` have in common."""
    shared = 0
    for left, right in zip(l1, l2):
        if left != right:
            break
        shared += 1
    return shared
100+
101+
102+
# Replace ['', v] with [v] and compact nodes into hashes
103+
# if needed
104+
def contract_node(n, db):
    """Normalise both children of the two-entry node ``n`` and store it.

    For each child: a two-element child whose first element is '' (an empty
    path) is replaced by the one-element form ``[v]``; then, unless its
    length already equals HASHLEN (i.e. it looks like a hash reference), it
    is passed through dbput, which may replace it by a hash.  ``n`` is
    modified in place and finally passed through dbput itself; that result
    is returned.

    Uses HASHLEN rather than a literal 32 so the reference-length test agrees
    with dbget/dbput.
    """
    for side in (0, 1):
        child = n[side]
        if len(child) == 2 and child[0] == '':
            child = [child[1]]
        if len(child) != HASHLEN:
            child = dbput(child, db)
        n[side] = child
    return dbput(n, db)
114+
115+
116+
# Update a trie
117+
def update(node, db, key, val):
    """Set ``key`` (a list of 0/1 ints) to ``val`` in the trie at ``node``.

    Returns the new root: either a hash stored in ``db`` or an embedded
    node, as decided by dbput/contract_node.  A list passed as ``node`` is
    modified in place.

    Raises Exception("DB must be prefix-free") when the walk hits a leaf
    while key bits remain.

    Fix relative to the original: when the key is exhausted the new leaf
    ``[val]`` is returned directly.  The original passed that one-element
    list to contract_node, which indexes ``n[1]`` and therefore raised
    IndexError, so overwriting an existing key always crashed.  Returning
    the bare leaf matches the ``[val]`` the split branch below embeds for a
    fresh insert.
    """
    node = dbget(node, db)
    if key == []:
        return [val]
    # Unfortunately this particular design does not allow
    # a node to have one child, so at the root for empty
    # tries we need to add two dummy children
    if node == '':
        node = [dbput([encode_bin_path([]), ''], db),
                dbput([encode_bin_path([1]), ''], db)]
    if len(node) == 1:
        raise Exception("DB must be prefix-free")
    assert len(node) == 2, node
    # key[0] selects the child; the child carries a path and a subnode
    sub = dbget(node[key[0]], db)
    if len(sub) == 2:
        _subpath, subnode = sub
    else:
        # Contracted child [v]: empty path
        _subpath, subnode = '', sub[0]
    subpath = decode_bin_path(_subpath)
    sl = get_shared_length(subpath, key[1:])
    if sl == len(subpath):
        # The whole stored path matches: recurse below it
        node[key[0]] = [_subpath, update(subnode, db, key[sl+1:], val)]
    else:
        # Paths diverge after sl bits: insert a new branch node there.
        # The existing subtree goes on the side given by its next bit and
        # the new leaf goes on the opposite side.
        subpath_next = subpath[sl]
        n = [0, 0]
        n[subpath_next] = [encode_bin_path(subpath[sl+1:]), subnode]
        n[1 - subpath_next] = [encode_bin_path(key[sl+2:]), [val]]
        n = contract_node(n, db)
        node[key[0]] = dbput([encode_bin_path(subpath[:sl]), n], db)
    return contract_node(node, db)
148+
149+
150+
# Compression algorithm specialized for merkle proof databases
151+
# The idea is similar to standard compression algorithms, where
152+
# you replace an instance of a repeat with a pointer to the repeat,
153+
# except that here you replace an instance of a hash of a value
154+
# with the pointer of a value. This is useful since merkle branches
155+
# usually include nodes which contain hashes of each other
156+
# Two-character escape marker used by compress_db/decompress_db: it is followed
# either by itself (a literal occurrence of the marker in the data) or by a
# two-character big-endian index of the value whose hash was replaced
magic = '\xff\x39'
157+
158+
159+
def compress_db(db):
    """Serialise the values of ``db.kv`` with embedded hashes replaced.

    Each value is scanned left to right.  A literal occurrence of ``magic``
    is written as ``magic + magic``.  Wherever the hash of one of the
    database's values starts, it is written as ``magic`` followed by the
    two-character big-endian index of that value.  Everything else is
    copied through.  The list of rewritten values is RLP-encoded and
    returned.

    Asserts that there are fewer than 65300 values.
    """
    values = db.kv.values()
    keys = [hashfunc(x) for x in values]
    assert len(keys) < 65300
    compressed = []
    for v in values:
        pieces = []
        pos = 0
        while pos < len(v):
            emitted = False
            if v[pos:pos+2] == magic:
                # Escape a literal marker
                pieces.append(magic + magic)
                emitted = True
                pos += 2
            # Checked even right after an escape, as the position has moved
            for i, k in enumerate(keys):
                if v.startswith(k, pos):
                    pieces.append(magic + chr(i // 256) + chr(i % 256))
                    emitted = True
                    pos += len(k)
                    break
            if not emitted:
                pieces.append(v[pos])
                pos += 1
        compressed.append(''.join(pieces))
    return rlp.encode(compressed)
184+
185+
186+
def decompress_db(ins):
    """Rebuild a database from the output of compress_db.

    The input is RLP-decoded into a list of packed values.  In each one,
    ``magic + magic`` becomes a literal ``magic`` and ``magic`` followed by
    a two-character index becomes the hash of the (recursively unpacked)
    value at that index.  Every unpacked value is then stored in a fresh
    EphemDB under its hash, and that database is returned.
    """
    entries = rlp.decode(ins)
    restored = [None] * len(entries)

    def decipher(i):
        # Memoised: each entry is unpacked at most once
        if restored[i] is not None:
            return restored[i]
        packed = entries[i]
        pieces = []
        pos = 0
        while pos < len(packed):
            if packed[pos:pos+2] != magic:
                pieces.append(packed[pos])
                pos += 1
                continue
            if packed[pos+2:pos+4] == magic:
                pieces.append(magic)
            else:
                ind = ord(packed[pos+2]) * 256 + ord(packed[pos+3])
                pieces.append(hashfunc(decipher(ind)))
            pos += 4
        restored[i] = ''.join(pieces)
        return restored[i]

    for i in range(len(entries)):
        decipher(i)

    out = EphemDB()
    for value in restored:
        out.put(hashfunc(value), value)
    return out
216+
217+
218+
# Convert a merkle branch directly into RLP (ie. remove
219+
# the hashing indirection). As it turns out, this is a
220+
# really compact way to represent a branch
221+
def compress_branch(db, root):
    """Return the RLP encoding of the trie at ``root`` with hashes inlined.

    Starting from ``root`` (resolved through dbget), every element that has
    length HASHLEN and is a key of ``db.kv`` is replaced by the decoded node
    it refers to, recursively.  Elements not present in ``db.kv`` are left
    as they are.

    Changes relative to the original:
    * ``root`` is deep-copied, so nested child lists of a list-valued root
      are no longer expanded in place in the caller's object.
    * Recursion is decided by the type of each child rather than of its
      parent, so strings are no longer walked character by character (which
      did nothing on ``str`` and fails on Python 3 ``bytes``).
    """
    expanded_root = dbget(copy.deepcopy(root), db)

    def evaluate_node(x):
        if not isinstance(x, list):
            return x
        for i, child in enumerate(x):
            if isinstance(child, list):
                x[i] = evaluate_node(child)
            elif len(child) == HASHLEN and child in db.kv:
                # Hash reference available in this db: inline its node
                x[i] = evaluate_node(dbget(child, db))
        return x

    return rlp.encode(evaluate_node(expanded_root))
234+
235+
236+
def decompress_branch(branch):
    """Rebuild a node database from the output of compress_branch.

    The branch is RLP-decoded and walked bottom-up: every list has its
    elements processed first and is then passed through dbput, which stores
    it in a fresh EphemDB (and substitutes its hash) when its encoding
    calls for that.  The populated database is returned.
    """
    decoded = rlp.decode(branch)
    db = EphemDB()

    def store(x):
        if not isinstance(x, list):
            return x
        return dbput([store(child) for child in x], db)

    store(decoded)
    return db
247+
248+
249+
# Test with n nodes
250+
def test(n):
    """Build a trie with ``n`` entries and print size statistics.

    Entry ``i`` maps the bits of ``rlp.encode(hashfunc(str(i)))`` to
    ``hashfunc('v' + str(i))``.  For the first ``min(n, 100)`` entries the
    value is read back through a ListeningDB, and the touched nodes are
    measured raw, after compress_db, and after compress_branch; both
    compressed forms are decompressed and checked with ``assert``.

    Returns ``(db, x)`` where ``x`` is the last working copy of the root.
    Divides by ``min(n, 100)``, so ``n`` must be at least 1.
    """
    def key_bits(i):
        return [int(a) for a in encode_bin(rlp.encode(hashfunc(str(i))))]

    db = EphemDB()
    x = ''
    for i in range(n):
        x = update(x, db, key_bits(i), hashfunc('v'+str(i)))
    print(x)
    print(sum(len(val) for val in db.db.values()))

    l1 = ListeningDB(db)
    samples = min(n, 100)
    proof_bytes = 0
    compressed_bytes = 0
    branch_bytes = 0
    root = x
    for i in range(samples):
        # Work on a copy of the root for each sample
        x = copy.deepcopy(root)
        bits = key_bits(i)
        v = hashfunc('v'+str(i))
        l2 = ListeningDB(l1)
        v2 = get(x, l2, bits)
        assert v == v2
        proof_bytes += sum(len(val) for val in l2.kv.values())
        cdb = compress_db(l2)
        compressed_bytes += len(cdb)
        assert decompress_db(cdb).kv == l2.kv
        cbr = compress_branch(l2, x)
        print(rlp.decode(cbr))
        branch_bytes += len(cbr)
        dbranch = decompress_branch(cbr)
        assert v == get(x, dbranch, bits)
    print('Total db size: %d' % sum(len(val) for val in l1.kv.values()))
    print('Avg proof size: %d' % (proof_bytes / samples))
    print('Avg compressed proof size: %d' % (compressed_bytes / samples))
    print('Avg branch size: %d' % (branch_bytes / samples))
    print('Compressed db size: %d' % len(compress_db(l1)))
    return db, x

0 commit comments

Comments
 (0)