|
| 1 | +# All nodes are of the form [path1, child1, path2, child2] |
| 2 | +# or <value> |
| 3 | + |
| 4 | +from ethereum import utils |
| 5 | +from ethereum.db import EphemDB, ListeningDB |
| 6 | +import rlp, sys |
| 7 | +import copy |
| 8 | + |
# Hash function applied to an RLP-encoded node to derive its database key
hashfunc = utils.sha3

# Byte length at which dbget treats a node as a database key to look up
HASHLEN = 32
| 12 | + |
| 13 | + |
| 14 | +# 0100000101010111010000110100100101001001 -> ASCII |
def decode_bin(x):
    """Pack a string of '0'/'1' characters into a character string.

    The input is consumed in chunks of 8 characters, most significant bit
    first; each chunk becomes one character.  A trailing chunk shorter than
    8 characters is parsed as the (smaller) number it spells.
    """
    chars = []
    for start in range(0, len(x), 8):
        octet = x[start:start + 8]
        chars.append(chr(int(octet, 2)))
    return ''.join(chars)
| 17 | + |
| 18 | + |
| 19 | +# ASCII -> 0100000101010111010000110100100101001001 |
def encode_bin(x):
    """Expand every character of *x* into its 8-character binary form.

    Each character contributes exactly eight '0'/'1' characters, most
    significant bit first, so the result has length ``8 * len(x)``.

    The previous implementation halved the code point with ``/=``, which is
    only an integer operation on Python 2; formatting the integer directly
    gives the same digits regardless of the division semantics in effect.
    """
    # ``% 256`` keeps only the low 8 bits, as the shift-and-mask loop did.
    return ''.join([format(ord(c) % 256, '08b') for c in x])
| 30 | + |
| 31 | + |
| 32 | +# Encodes a binary list [0,1,0,1,1,0] of any length into bytes |
def encode_bin_path(li):
    """Serialize a list of bits (0/1 values) of any length into a string.

    Layout of the produced bit stream, before it is packed by decode_bin:

    * the bits are left-padded with zeros to a multiple of 4;
    * a two-bit field holding ``len(li) % 4`` records how much padding
      was added;
    * if the padded bits fill an odd number of nibbles the header is the
      single nibble ``00pp``; otherwise it is the full byte ``100000pp``,
      so the total is always a whole number of bytes.

    The empty list is encoded as the empty string.
    """
    if li == []:
        return ''
    bits = ''.join(str(bit) for bit in li)
    remainder = len(bits) % 4
    padded = '0' * (-len(bits) % 4) + bits
    length_field = format(remainder, '02b')
    # An odd nibble count leaves room for only a half-byte header.
    header = '00' if len(padded) % 8 == 4 else '100000'
    return decode_bin(header + length_field + padded)
| 43 | + |
| 44 | + |
| 45 | +# Decodes bytes into a binary list |
def decode_bin_path(p):
    """Inverse of encode_bin_path: turn a serialized path into a bit list.

    Returns a list of ints (0 or 1).  The empty string decodes to the
    empty list.  An AssertionError is raised when the header flag bits are
    not the expected ``00``.
    """
    if p == '':
        return []
    stream = encode_bin(p)
    if stream[0] == '1':
        # Full-byte header: discard its leading nibble so both header
        # variants look alike from here on.
        stream = stream[4:]
    assert stream[0:2] == '00'
    remainder = int(stream[2:4], 2)
    padding = (4 - remainder) % 4
    payload = stream[4 + padding:]
    return [int(ch == '1') for ch in payload]
| 56 | + |
| 57 | + |
| 58 | +# Get a node from a database if needed |
def dbget(node, db):
    """Resolve *node* to an in-memory node.

    Anything whose length differs from HASHLEN is returned unchanged.
    A node of exactly HASHLEN items is used as a key: the stored entry is
    fetched with ``db.get`` and RLP-decoded.
    """
    if len(node) != HASHLEN:
        return node
    return rlp.decode(db.get(node))
| 63 | + |
| 64 | + |
| 65 | +# Place a node into a database if needed |
def dbput(node, db):
    """Store *node* in *db* when its encoding calls for it.

    The node is RLP-encoded.  If the encoding is exactly HASHLEN bytes or
    longer than twice HASHLEN, it is written to *db* under its hash and the
    hash is returned.  Otherwise nothing is written and *node* itself is
    returned so the caller can embed it inline.
    """
    encoded = rlp.encode(node)
    size = len(encoded)
    if size != HASHLEN and size <= HASHLEN * 2:
        return node
    digest = hashfunc(encoded)
    db.put(digest, encoded)
    return digest
| 73 | + |
| 74 | + |
| 75 | +# Get a value from a tree |
def get(node, db, key):
    """Look up the bit list *key* starting from *node*.

    Returns the stored value, or '' when the key is not present.  *key*
    must be a list; the walk stops once it has been consumed entirely.
    """
    while True:
        node = dbget(node, db)
        if key == []:
            # Key fully consumed: the value sits in the first slot.
            return node[0]
        if len(node) <= 1:
            return ''
        child = dbget(node[key[0]], db)
        if len(child) == 2:
            encoded_path, next_node = child
        else:
            # One-element child: the empty path was contracted away.
            encoded_path, next_node = '', child[0]
        path_bits = decode_bin_path(encoded_path)
        consumed = len(path_bits) + 1
        if key[1:consumed] != path_bits:
            return ''
        node, key = next_node, key[consumed:]
| 92 | + |
| 93 | + |
| 94 | +# Get length of shared prefix of inputs |
def get_shared_length(l1, l2):
    """Return how many leading items *l1* and *l2* have in common."""
    shared = 0
    for left, right in zip(l1, l2):
        if left != right:
            break
        shared += 1
    return shared
| 100 | + |
| 101 | + |
| 102 | +# Replace ['', v] with [v] and compact nodes into hashes |
| 103 | +# if needed |
def contract_node(n, db):
    """Normalize the two children of *n* in place and store the result.

    A child of the form ``['', v]`` is shortened to ``[v]``.  Every child
    whose length is not HASHLEN is then passed through dbput, which may
    replace it with a hash.  Finally *n* itself goes through dbput and that
    result is returned.
    """
    sides = (0, 1)
    for side in sides:
        child = n[side]
        if len(child) == 2 and child[0] == '':
            n[side] = [child[1]]
    for side in sides:
        if len(n[side]) != HASHLEN:
            n[side] = dbput(n[side], db)
    return dbput(n, db)
| 114 | + |
| 115 | + |
| 116 | +# Update a trie |
def update(node, db, key, val):
    """Store *val* under the bit list *key* below *node*; return the new node.

    *node* may be an inline node, a hash resolvable through *db*, or '' for
    an empty trie.  The return value is whatever should replace *node* in
    its parent: an inline node or a hash produced by dbput.

    Raises Exception("DB must be prefix-free") when the walk reaches a
    one-element node while key bits remain.
    """
    node = dbget(node, db)
    if key == []:
        # Key fully consumed: the replacement is the bare leaf.  It must not
        # be handed to contract_node, which indexes both children of a
        # two-child node and raises IndexError on a one-element leaf.
        return [val]
    # Unfortunately this particular design does not allow
    # a node to have one child, so at the root for empty
    # tries we need to add two dummy children
    if node == '':
        node = [dbput([encode_bin_path([]), ''], db),
                dbput([encode_bin_path([1]), ''], db)]
    if len(node) == 1:
        raise Exception("DB must be prefix-free")
    assert len(node) == 2, node

    sub = dbget(node[key[0]], db)
    if len(sub) == 2:
        _subpath, subnode = sub
    else:
        # One-element child: the empty path was contracted away.
        _subpath, subnode = '', sub[0]
    subpath = decode_bin_path(_subpath)
    sl = get_shared_length(subpath, key[1:])
    if sl == len(subpath):
        # The child's whole path matches: descend past the branch bit and
        # the shared path.
        node[key[0]] = [_subpath, update(subnode, db, key[sl+1:], val)]
    else:
        # The key leaves the child's path after `sl` bits: split the path
        # there.  The old subtree hangs off the bit the path continues
        # with, the new leaf off the opposite bit.
        subpath_next = subpath[sl]
        n = [0, 0]
        n[subpath_next] = [encode_bin_path(subpath[sl+1:]), subnode]
        n[(1 - subpath_next)] = [encode_bin_path(key[sl+2:]), [val]]
        n = contract_node(n, db)
        node[key[0]] = dbput([encode_bin_path(subpath[:sl]), n], db)
    return contract_node(node, db)
| 148 | + |
| 149 | + |
| 150 | +# Compression algorithm specialized for merkle proof databases |
| 151 | +# The idea is similar to standard compression algorithms, where |
| 152 | +# you replace an instance of a repeat with a pointer to the repeat, |
| 153 | +# except that here you replace an instance of a hash of a value |
| 154 | +# with the pointer of a value. This is useful since merkle branches |
| 155 | +# usually include nodes which contain hashes of each other |
# Two-byte escape marker of the compressed format: compress_db writes
# magic + magic for a literal occurrence of the marker, and magic followed by
# a two-byte big-endian index for a reference to another entry's hash.
magic = '\xff\x39'
| 157 | + |
| 158 | + |
def compress_db(db):
    """Serialize the values of ``db.kv`` with hash references compressed.

    Each value is rewritten so that any embedded hash of another value in
    the same database is replaced by the escape marker plus the two-byte
    index of that value; literal occurrences of the marker are doubled.
    The rewritten values are returned as one RLP-encoded list.

    Asserts that the database holds fewer than 65300 entries.
    NOTE(review): presumably this keeps indices below 0xff39 so that a
    reference can never look like an escaped marker -- confirm.
    """
    values = db.kv.values()
    keys = [hashfunc(value) for value in values]
    assert len(keys) < 65300
    out = []
    for value in values:
        pieces = []
        pos = 0
        end = len(value)
        while pos < end:
            handled = False
            if value[pos:pos + 2] == magic:
                pieces.append(magic + magic)
                pos += 2
                handled = True
            # A reference is still looked for right after an escaped marker.
            for index, key in enumerate(keys):
                if value.startswith(key, pos):
                    pieces.append(magic + chr(index // 256) + chr(index % 256))
                    pos += len(key)
                    handled = True
                    break
            if not handled:
                pieces.append(value[pos])
                pos += 1
        out.append(''.join(pieces))
    return rlp.encode(out)
| 184 | + |
| 185 | + |
def decompress_db(ins):
    """Rebuild a database from the output of compress_db.

    The RLP list is decoded, every entry is expanded (a doubled marker
    becomes a literal marker, a marker plus index becomes the hash of the
    referenced entry), and each expanded value is stored in a fresh EphemDB
    under its own hash.  The EphemDB is returned.
    """
    entries = rlp.decode(ins)
    restored = [None] * len(entries)

    def restore(index):
        # Entries are expanded lazily and cached, because a reference may
        # point at an entry that has not been expanded yet.
        if restored[index] is not None:
            return restored[index]
        packed = entries[index]
        pieces = []
        cursor = 0
        while cursor < len(packed):
            if packed[cursor:cursor + 2] != magic:
                pieces.append(packed[cursor])
                cursor += 1
                continue
            if packed[cursor + 2:cursor + 4] == magic:
                pieces.append(magic)
            else:
                ref = ord(packed[cursor + 2]) * 256 + ord(packed[cursor + 3])
                pieces.append(hashfunc(restore(ref)))
            cursor += 4
        restored[index] = ''.join(pieces)
        return restored[index]

    for index in range(len(entries)):
        restore(index)

    result = EphemDB()
    for value in restored:
        result.put(hashfunc(value), value)
    return result
| 216 | + |
| 217 | + |
| 218 | +# Convert a merkle branch directly into RLP (ie. remove |
| 219 | +# the hashing indirection). As it turns out, this is a |
| 220 | +# really compact way to represent a branch |
def compress_branch(db, root):
    """Inline a Merkle branch into a single RLP blob.

    Starting from *root* (resolved through dbget), every child that has
    length HASHLEN and is a key of ``db.kv`` is replaced by the node it
    refers to, recursively.  Children that are not keys of ``db.kv`` are
    kept as they are.  The fully inlined structure is returned RLP-encoded.

    The expansion builds new lists instead of assigning into the nodes it
    visits, so an inline *root* passed by the caller is left unmodified
    (the former shallow copy still shared its nested lists).
    """
    def expand(node):
        # Strings are leaves; only lists have children to resolve.
        if not isinstance(node, list):
            return node
        expanded = []
        for child in node:
            if len(child) == HASHLEN and child in db.kv:
                child = dbget(child, db)
            expanded.append(expand(child))
        return expanded

    return rlp.encode(expand(dbget(root, db)))
| 234 | + |
| 235 | + |
def decompress_branch(branch):
    """Rebuild a node database from the output of compress_branch.

    The RLP blob is decoded and walked bottom-up: each list has its
    children processed first and is then passed through dbput, which stores
    it in a fresh EphemDB when its encoding requires it.  The EphemDB is
    returned.
    """
    decoded = rlp.decode(branch)
    db = EphemDB()

    def store(node):
        if not isinstance(node, list):
            return node
        children = [store(child) for child in node]
        return dbput(children, db)

    store(decoded)
    return db
| 247 | + |
| 248 | + |
| 249 | +# Test with n nodes |
def test(n):
    """Build a trie with *n* entries and print proof-size statistics.

    Entry ``i`` maps the bits of ``rlp.encode(hashfunc(str(i)))`` to
    ``hashfunc('v' + str(i))``.  For up to the first 100 entries the value
    is read back through a ListeningDB, and both compression schemes are
    round-tripped and checked with assertions.

    Returns ``(db, x)``: the backing EphemDB and the last root copy used.
    """
    def bit_key(i):
        # Key of entry i as a list of 0/1 ints.
        return [int(a) for a in encode_bin(rlp.encode(hashfunc(str(i))))]

    db = EphemDB()
    x = ''
    for i in range(n):
        x = update(x, db, bit_key(i), hashfunc('v' + str(i)))
    # Single-argument print(...) produces the same output on Python 2 and
    # keeps the module importable on Python 3.
    print(x)
    print(sum([len(val) for key, val in db.db.items()]))
    l1 = ListeningDB(db)
    samples = min(n, 100)
    o = 0
    p = 0
    q = 0
    ecks = x
    for i in range(samples):
        x = copy.deepcopy(ecks)
        key = bit_key(i)
        v = hashfunc('v' + str(i))
        l2 = ListeningDB(l1)
        v2 = get(x, l2, key)
        assert v == v2
        o += sum([len(val) for _, val in l2.kv.items()])
        cdb = compress_db(l2)
        p += len(cdb)
        assert decompress_db(cdb).kv == l2.kv
        cbr = compress_branch(l2, x)
        print(rlp.decode(cbr))
        q += len(cbr)
        dbranch = decompress_branch(cbr)
        assert v == get(x, dbranch, key)
        # for k in l2.kv:
        #     assert k in dbranch.kv
    # With n == 0 no proofs were sampled; avoid dividing by zero.
    divisor = max(samples, 1)
    print('Total db size: %d' % sum([len(val) for key, val in l1.kv.items()]))
    print('Avg proof size: %d' % (o // divisor))
    print('Avg compressed proof size: %d' % (p // divisor))
    print('Avg branch size: %d' % (q // divisor))
    print('Compressed db size: %d' % len(compress_db(l1)))
    return db, x
0 commit comments