Skip to content
This repository was archived by the owner on May 23, 2023. It is now read-only.

Commit b0437c0

Browse files
authored
Merge pull request #423 from ethereum/snapshot
State snapshot module
2 parents 1d6991c + b8b2e26 commit b0437c0

File tree

2 files changed

+245
-0
lines changed

2 files changed

+245
-0
lines changed

ethereum/pruning_trie.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,49 @@ def to_dict(self):
884884
res[key] = value
885885
return res
886886

887+
def iter_branch(self):
888+
for key_str, value in self._iter_branch(self.root_node):
889+
if key_str:
890+
nibbles = [int(x) for x in key_str.split(b'+')]
891+
else:
892+
nibbles = []
893+
key = nibbles_to_bin(without_terminator(nibbles))
894+
yield key, value
895+
896+
def _iter_branch(self, node):
897+
'''yield (key, value) stored in this and the descendant nodes
898+
:param node: node in form of list, or BLANK_NODE
899+
900+
.. note::
901+
Here key is in full form, rather than key of the individual node
902+
'''
903+
if node == BLANK_NODE:
904+
raise StopIteration
905+
906+
node_type = self._get_node_type(node)
907+
908+
if is_key_value_type(node_type):
909+
nibbles = without_terminator(unpack_to_nibbles(node[0]))
910+
key = b'+'.join([to_string(x) for x in nibbles])
911+
if node_type == NODE_TYPE_EXTENSION:
912+
sub_tree = self._iter_branch(self._decode_to_node(node[1]))
913+
else:
914+
sub_tree = [(to_string(NIBBLE_TERMINATOR), node[1])]
915+
916+
# prepend key of this node to the keys of children
917+
for sub_key, sub_value in sub_tree:
918+
full_key = (key + b'+' + sub_key).strip(b'+')
919+
yield (full_key, sub_value)
920+
921+
elif node_type == NODE_TYPE_BRANCH:
922+
for i in range(16):
923+
sub_tree = self._iter_branch(self._decode_to_node(node[i]))
924+
for sub_key, sub_value in sub_tree:
925+
full_key = (str_to_bytes(str(i)) + b'+' + sub_key).strip(b'+')
926+
yield (full_key, sub_value)
927+
if node[16]:
928+
yield (to_string(NIBBLE_TERMINATOR), node[-1])
929+
887930
def get(self, key):
888931
return self._get(self.root_node, bin_to_nibbles(to_string(key)))
889932

ethereum/snapshot.py

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
import rlp
2+
from ethereum import blocks
3+
from ethereum.blocks import Account, BlockHeader, Block, CachedBlock
4+
from ethereum.utils import is_numeric, is_string, encode_hex, decode_hex, zpad, scan_bin, big_endian_to_int
5+
from ethereum.securetrie import SecureTrie
6+
from ethereum.trie import BLANK_NODE, BLANK_ROOT
7+
from ethereum.pruning_trie import Trie
8+
9+
10+
class FakeHeader(object):
11+
def __init__(self, number, hash, state_root, gas_limit, timestamp):
12+
self.number = number
13+
self.hash = hash
14+
self.state_root = state_root
15+
self.gas_limit = gas_limit
16+
self.timestamp = timestamp
17+
18+
19+
class FakeBlock(object):
20+
def __init__(self, env, header, chain_diff):
21+
self.env = env
22+
self.config = env.config
23+
self.header = header
24+
self.uncles = []
25+
self.number = header.number
26+
self.hash = header.hash
27+
self.gas_limit = header.gas_limit
28+
self.difficulty = header.difficulty
29+
self.timestamp = header.timestamp
30+
self._chain_diff = chain_diff
31+
32+
def chain_difficulty(self):
33+
return self._chain_diff
34+
35+
def has_parent(self):
36+
return False
37+
38+
def get_ancestor_list(self, n):
39+
if n == 0 or self.header.number == 0:
40+
return []
41+
p = FakeBlock(self.env, self.header, 0)
42+
return [p] + p.get_ancestor_list(n - 1)
43+
44+
45+
def create_snapshot(chain, recent=1024):
46+
env = chain.env
47+
assert recent > env.config['MAX_UNCLE_DEPTH']+2
48+
49+
head_block = chain.head
50+
base_block_hash = chain.index.get_block_by_number(max(head_block.number-recent, 0))
51+
base_block = chain.get(base_block_hash)
52+
53+
snapshot = create_env_snapshot(base_block)
54+
snapshot['base'] = create_base_snapshot(base_block)
55+
snapshot['blocks'] = create_blocks_snapshot(base_block, head_block)
56+
snapshot['alloc'] = create_state_snapshot(env, base_block.state)
57+
58+
return snapshot
59+
60+
61+
def create_env_snapshot(base):
62+
return {
63+
'chainDifficulty': snapshot_form(base.chain_difficulty())
64+
}
65+
66+
67+
def create_base_snapshot(base):
68+
return snapshot_form(rlp.encode(base.header))
69+
70+
71+
def create_state_snapshot(env, state_trie):
72+
alloc = dict()
73+
count = 0
74+
for addr, account_rlp in state_trie.iter_branch():
75+
alloc[encode_hex(addr)] = create_account_snapshot(env, account_rlp)
76+
count += 1
77+
print "[%d] created account snapshot %s" % (count, encode_hex(addr))
78+
return alloc
79+
80+
81+
def create_account_snapshot(env, rlpdata):
82+
account = get_account(env, rlpdata)
83+
storage_trie = SecureTrie(Trie(env.db, account.storage))
84+
storage = dict()
85+
for k, v in storage_trie.iter_branch():
86+
storage[encode_hex(k.lstrip('\x00') or '\x00')] = encode_hex(v)
87+
return {
88+
'nonce': snapshot_form(account.nonce),
89+
'balance': snapshot_form(account.balance),
90+
'code': encode_hex(account.code),
91+
'storage': storage
92+
}
93+
94+
95+
def create_blocks_snapshot(base, head):
96+
recent_blocks = list()
97+
block = head
98+
while True:
99+
recent_blocks.append(snapshot_form(rlp.encode(block)))
100+
if block.prevhash != base.hash:
101+
block = block.get_parent()
102+
else:
103+
break
104+
recent_blocks.reverse()
105+
return recent_blocks
106+
107+
108+
def load_snapshot(chain, snapshot):
109+
base_header = rlp.decode(scan_bin(snapshot['base']), BlockHeader)
110+
111+
limit = len(snapshot['blocks'])
112+
# first block is child of base block
113+
first_block_rlp = scan_bin(snapshot['blocks'][0])
114+
first_header_data = rlp.decode(first_block_rlp)[0]
115+
head_block_rlp = scan_bin(snapshot['blocks'][limit-1])
116+
head_header_data = rlp.decode(head_block_rlp)[0]
117+
118+
state = load_state(chain.env, snapshot['alloc'])
119+
assert state.root_hash == base_header.state_root
120+
121+
_get_block_header = blocks.get_block_header
122+
def get_block_header(db, blockhash):
123+
if blockhash == first_header_data[0]: # first block's prevhash
124+
return base_header
125+
return _get_block_header(db, blockhash)
126+
blocks.get_block_header = get_block_header
127+
128+
_get_block = blocks.get_block
129+
def get_block(env, blockhash):
130+
if blockhash == first_header_data[0]:
131+
return FakeBlock(env, get_block_header(env.db, blockhash), int(snapshot['chainDifficulty']))
132+
return _get_block(env, blockhash)
133+
blocks.get_block = get_block
134+
135+
def validate_uncles():
136+
return True
137+
138+
print "Start loading recent blocks from snapshot"
139+
first_block = rlp.decode(first_block_rlp, Block, env=chain.env)
140+
chain.index.add_block(first_block)
141+
chain._store_block(first_block)
142+
chain.blockchain.put('HEAD', first_block.hash)
143+
chain.blockchain.put(chain.index._block_by_number_key(first_block.number), first_block.hash)
144+
chain.blockchain.commit()
145+
chain._update_head_candidate()
146+
147+
count = 0
148+
for block_rlp in snapshot['blocks'][1:]:
149+
block_rlp = scan_bin(block_rlp)
150+
block = rlp.decode(block_rlp, Block, env=chain.env)
151+
if count < chain.env.config['MAX_UNCLE_DEPTH']+2:
152+
block.__setattr__('validate_uncles', validate_uncles)
153+
if not chain.add_block(block):
154+
print "Failed to load block #%d (%s), abort." % (block.number, encode_hex(block.hash)[:8])
155+
else:
156+
count += 1
157+
print "[%d] block #%d (%s) added" % (count, block.number, encode_hex(block.hash)[:8])
158+
print "Snapshot loaded."
159+
160+
161+
def load_state(env, alloc):
162+
db = env.db
163+
state = SecureTrie(Trie(db, BLANK_ROOT))
164+
count = 0
165+
print "Start loading state from snapshot"
166+
for addr in alloc:
167+
account = alloc[addr]
168+
acct = Account.blank_account(db, env.config['ACCOUNT_INITIAL_NONCE'])
169+
if len(account['storage']) > 0:
170+
t = SecureTrie(Trie(db, BLANK_ROOT))
171+
for k in account['storage']:
172+
v = account['storage'][k]
173+
enckey = zpad(decode_hex(k), 32)
174+
t.update(enckey, decode_hex(v))
175+
acct.storage = t.root_hash
176+
if account['nonce']:
177+
acct.nonce = int(account['nonce'])
178+
if account['balance']:
179+
acct.balance = int(account['balance'])
180+
if account['code']:
181+
acct.code = decode_hex(account['code'])
182+
state.update(decode_hex(addr), rlp.encode(acct))
183+
count += 1
184+
if count % 1000 == 0:
185+
db.commit()
186+
print "[%d] loaded account %s" % (count, addr)
187+
db.commit()
188+
return state
189+
190+
191+
def get_account(env, rlpdata):
192+
if rlpdata != BLANK_NODE:
193+
return rlp.decode(rlpdata, Account, db=env.db)
194+
else:
195+
return Account.blank_account(env.db, env.config['ACCOUNT_INITIAL_NONCE'])
196+
197+
198+
def snapshot_form(val):
199+
if is_numeric(val):
200+
return str(val)
201+
elif is_string(val):
202+
return b'0x' + encode_hex(val)

0 commit comments

Comments
 (0)