Skip to content

Commit ff5cf2d

Browse files
committed
state atomicity
1 parent 3e01f64 commit ff5cf2d

File tree

25 files changed

+428
-254
lines changed

25 files changed

+428
-254
lines changed

blockchain/blockchain.go

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ func (b *Blockchain) deprecatedStore(
339339
}
340340

341341
state := core.NewState(txn)
342-
if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil {
342+
if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil {
343343
return err
344344
}
345345
if err := core.WriteBlockHeader(txn, block.Header); err != nil {
@@ -404,11 +404,11 @@ func (b *Blockchain) store(
404404
return err
405405
}
406406

407-
st, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil)
407+
st, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil, batch)
408408
if err != nil {
409409
return err
410410
}
411-
if err := st.Update(block.Number, stateUpdate, newClasses, false, true); err != nil {
411+
if err := st.Update(block.Number, stateUpdate, newClasses, false); err != nil {
412412
return err
413413
}
414414

@@ -638,7 +638,7 @@ func (b *Blockchain) HeadState() (core.CommonStateReader, StateCloser, error) {
638638
return nil, nil, err
639639
}
640640

641-
state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn)
641+
state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn, nil)
642642

643643
return state, noopStateCloser, err
644644
}
@@ -792,7 +792,7 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) {
792792
if err != nil {
793793
return ret, err
794794
}
795-
state, err := state.New(stateUpdate.NewRoot, b.stateDB)
795+
state, err := state.New(stateUpdate.NewRoot, b.stateDB, nil)
796796
if err != nil {
797797
return ret, err
798798
}
@@ -873,7 +873,7 @@ func (b *Blockchain) revertHead(batch db.Batch) error {
873873
return err
874874
}
875875

876-
state, err := state.New(stateUpdate.NewRoot, b.stateDB)
876+
state, err := state.New(stateUpdate.NewRoot, b.stateDB, batch)
877877
if err != nil {
878878
return err
879879
}
@@ -940,7 +940,7 @@ func (b *Blockchain) Simulate(
940940
txn := b.database.NewIndexedBatch()
941941
defer txn.Close()
942942

943-
if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, false); err != nil {
943+
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
944944
return SimulateResult{}, err
945945
}
946946

@@ -975,7 +975,7 @@ func (b *Blockchain) Finalise(
975975
) error {
976976
if !b.StateFactory.UseNewState() {
977977
err := b.database.Update(func(txn db.IndexedBatch) error {
978-
if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, true); err != nil {
978+
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
979979
return err
980980
}
981981
commitments, err := b.updateBlockHash(block, stateUpdate)
@@ -1011,7 +1011,7 @@ func (b *Blockchain) Finalise(
10111011
}
10121012

10131013
err := b.database.Write(func(batch db.Batch) error {
1014-
if err := b.updateStateRoots(nil, block, stateUpdate, newClasses, true); err != nil {
1014+
if err := b.updateStateRoots(nil, batch, block, stateUpdate, newClasses); err != nil {
10151015
return err
10161016
}
10171017
commitments, err := b.updateBlockHash(block, stateUpdate)
@@ -1047,10 +1047,10 @@ func (b *Blockchain) Finalise(
10471047
// updateStateRoots computes and updates state roots in the block and state update
10481048
func (b *Blockchain) updateStateRoots(
10491049
txn db.IndexedBatch,
1050+
batch db.Batch,
10501051
block *core.Block,
10511052
stateUpdate *core.StateUpdate,
10521053
newClasses map[felt.Felt]core.ClassDefinition,
1053-
flushChanges bool,
10541054
) error {
10551055
var height uint64
10561056
var err error
@@ -1069,7 +1069,7 @@ func (b *Blockchain) updateStateRoots(
10691069
stateRoot = &felt.Zero
10701070
}
10711071

1072-
state, err := b.StateFactory.NewState(stateRoot, txn)
1072+
state, err := b.StateFactory.NewState(stateRoot, txn, batch)
10731073
if err != nil {
10741074
return err
10751075
}
@@ -1082,7 +1082,7 @@ func (b *Blockchain) updateStateRoots(
10821082
stateUpdate.OldRoot = &oldStateRoot
10831083

10841084
// Apply state update
1085-
if err = state.Update(block.Number, stateUpdate, newClasses, true, flushChanges); err != nil {
1085+
if err = state.Update(block.Number, stateUpdate, newClasses, true); err != nil {
10861086
return err
10871087
}
10881088

core/common_state.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ type CommonState interface {
2121
update *StateUpdate,
2222
declaredClasses map[felt.Felt]ClassDefinition,
2323
skipVerifyNewRoot bool,
24-
flushChanges bool,
2524
) error
2625
Revert(blockNum uint64, update *StateUpdate) error
2726
Commitment() (felt.Felt, error)

core/state.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ func (s *State) Update(
202202
update *StateUpdate,
203203
declaredClasses map[felt.Felt]ClassDefinition,
204204
skipVerifyNewRoot bool,
205-
flushChanges bool, // TODO(maksym): added to satisfy the interface, but not used
206205
) error {
207206
err := s.verifyStateUpdateRoot(update.OldRoot)
208207
if err != nil {

core/state/history.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type stateHistory struct {
1717
}
1818

1919
func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (stateHistory, error) {
20-
state, err := New(stateRoot, db)
20+
state, err := New(stateRoot, db, nil)
2121
if err != nil {
2222
return stateHistory{}, err
2323
}

core/state/history_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,11 @@ func TestStateHistoryClassOperations(t *testing.T) {
131131
NewRoot: &felt.Zero,
132132
StateDiff: &core.StateDiff{},
133133
}
134-
state, err := New(&felt.Zero, stateDB)
135-
require.NoError(t, err)
136-
err = state.Update(0, stateUpdate, classes, false, true)
134+
batch := stateDB.disk.NewBatch()
135+
state, err := New(&felt.Zero, stateDB, batch)
137136
require.NoError(t, err)
137+
require.NoError(t, state.Update(0, stateUpdate, classes, false))
138+
require.NoError(t, batch.Write())
138139
stateComm, err := state.Commitment()
139140
require.NoError(t, err)
140141

@@ -147,10 +148,11 @@ func TestStateHistoryClassOperations(t *testing.T) {
147148
class2Hash: class2,
148149
}
149150

150-
state, err = New(&stateComm, stateDB)
151-
require.NoError(t, err)
152-
err = state.Update(1, stateUpdate, classes2, false, true)
151+
batch = stateDB.disk.NewBatch()
152+
state, err = New(&stateComm, stateDB, batch)
153153
require.NoError(t, err)
154+
require.NoError(t, state.Update(1, stateUpdate, classes2, false))
155+
require.NoError(t, batch.Write())
154156

155157
historyBlock0, err := NewStateHistory(0, &felt.Zero, stateDB)
156158
require.NoError(t, err)

core/state/state.go

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,11 @@ type State struct {
4343
classTrie *trie2.Trie
4444

4545
stateObjects map[felt.Felt]*stateObject
46+
47+
batch db.Batch
4648
}
4749

48-
func New(stateRoot *felt.Felt, db *StateDB) (*State, error) {
50+
func New(stateRoot *felt.Felt, db *StateDB, batch db.Batch) (*State, error) {
4951
contractTrie, err := db.ContractTrie(stateRoot)
5052
if err != nil {
5153
return nil, err
@@ -62,6 +64,7 @@ func New(stateRoot *felt.Felt, db *StateDB) (*State, error) {
6264
contractTrie: contractTrie,
6365
classTrie: classTrie,
6466
stateObjects: make(map[felt.Felt]*stateObject),
67+
batch: batch,
6568
}, nil
6669
}
6770

@@ -179,7 +182,6 @@ func (s *State) Update(
179182
update *core.StateUpdate,
180183
declaredClasses map[felt.Felt]core.ClassDefinition,
181184
skipVerifyNewRoot bool,
182-
flushChanges bool,
183185
) error {
184186
if err := s.verifyComm(update.OldRoot); err != nil {
185187
return err
@@ -248,10 +250,8 @@ func (s *State) Update(
248250
deployedContracts: update.StateDiff.ReplacedClasses,
249251
})
250252

251-
if flushChanges {
252-
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, true); err != nil {
253-
return err
254-
}
253+
if s.batch != nil {
254+
return s.flush(blockNum, &stateUpdate, dirtyClasses, true)
255255
}
256256

257257
return nil
@@ -318,19 +318,19 @@ func (s *State) Revert(blockNum uint64, update *core.StateUpdate) error {
318318
if !newComm.Equal(update.OldRoot) {
319319
return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.OldRoot, &newComm)
320320
}
321-
322-
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
323-
return err
321+
if s.batch != nil {
322+
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
323+
return err
324+
}
325+
if err := s.deleteHistory(blockNum, update.StateDiff); err != nil {
326+
return err
327+
}
324328
}
325329

326330
if err := s.db.stateCache.PopLayer(update.NewRoot, update.OldRoot); err != nil {
327331
return err
328332
}
329333

330-
if err := s.deleteHistory(blockNum, update.StateDiff); err != nil {
331-
return err
332-
}
333-
334334
return nil
335335
}
336336

@@ -503,74 +503,64 @@ func (s *State) flush(
503503
classes map[felt.Felt]core.ClassDefinition,
504504
storeHistory bool,
505505
) error {
506-
p := pool.New().WithMaxGoroutines(runtime.GOMAXPROCS(0)).WithErrors()
507-
508-
p.Go(func() error {
509-
return s.db.triedb.Update(
510-
(*felt.StateRootHash)(&update.curComm),
511-
(*felt.StateRootHash)(&update.prevComm),
512-
blockNum,
513-
update.classNodes,
514-
update.contractNodes,
515-
)
516-
})
517-
518-
batch := s.db.disk.NewBatch()
519-
p.Go(func() error {
520-
for addr, obj := range s.stateObjects {
521-
if obj == nil { // marked as deleted
522-
if err := DeleteContract(batch, &addr); err != nil {
523-
return err
524-
}
525-
526-
// TODO(weiihann): handle hash-based, and there should be better ways of doing this
527-
err := trieutils.DeleteStorageNodesByPath(batch, (*felt.Address)(&addr))
528-
if err != nil {
529-
return err
530-
}
531-
} else { // updated
532-
if err := WriteContract(batch, &addr, obj.contract); err != nil {
533-
return err
534-
}
506+
if err := s.db.triedb.Update(
507+
(*felt.StateRootHash)(&update.curComm),
508+
(*felt.StateRootHash)(&update.prevComm),
509+
blockNum,
510+
update.classNodes,
511+
update.contractNodes,
512+
s.batch,
513+
); err != nil {
514+
return err
515+
}
535516

536-
if storeHistory {
537-
for key, val := range obj.dirtyStorage {
538-
if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil {
539-
return err
540-
}
541-
}
517+
for addr, obj := range s.stateObjects {
518+
if obj == nil { // marked as deleted
519+
if err := DeleteContract(s.batch, &addr); err != nil {
520+
return err
521+
}
542522

543-
if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil {
544-
return err
545-
}
523+
// TODO(weiihann): handle hash-based, and there should be better ways of doing this
524+
err := trieutils.DeleteStorageNodesByPath(s.batch, (*felt.Address)(&addr))
525+
if err != nil {
526+
return err
527+
}
528+
} else { // updated
529+
if err := WriteContract(s.batch, &addr, obj.contract); err != nil {
530+
return err
531+
}
546532

547-
if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil {
533+
if storeHistory {
534+
for key, val := range obj.dirtyStorage {
535+
if err := WriteStorageHistory(s.batch, &addr, &key, blockNum, val); err != nil {
548536
return err
549537
}
550538
}
551-
}
552-
}
553539

554-
for classHash, class := range classes {
555-
if class == nil { // mark as deleted
556-
if err := DeleteClass(batch, &classHash); err != nil {
540+
if err := WriteNonceHistory(s.batch, &addr, blockNum, &obj.contract.Nonce); err != nil {
557541
return err
558542
}
559-
} else {
560-
if err := WriteClass(batch, &classHash, class, blockNum); err != nil {
543+
544+
if err := WriteClassHashHistory(s.batch, &addr, blockNum, &obj.contract.ClassHash); err != nil {
561545
return err
562546
}
563547
}
564548
}
549+
}
565550

566-
return nil
567-
})
568-
569-
if err := p.Wait(); err != nil {
570-
return err
551+
for classHash, class := range classes {
552+
if class == nil { // mark as deleted
553+
if err := DeleteClass(s.batch, &classHash); err != nil {
554+
return err
555+
}
556+
} else {
557+
if err := WriteClass(s.batch, &classHash, class, blockNum); err != nil {
558+
return err
559+
}
560+
}
571561
}
572562

573-
return batch.Write()
563+
return nil
574564
}
575565

576566
func (s *State) updateClassTrie(
@@ -772,39 +762,37 @@ func (s *State) valueAt(prefix []byte, blockNum uint64, cb func(val []byte) erro
772762
}
773763

774764
func (s *State) deleteHistory(blockNum uint64, diff *core.StateDiff) error {
775-
batch := s.db.disk.NewBatch()
776-
777765
for addr, storage := range diff.StorageDiffs {
778766
for key := range storage {
779-
if err := DeleteStorageHistory(batch, &addr, &key, blockNum); err != nil {
767+
if err := DeleteStorageHistory(s.batch, &addr, &key, blockNum); err != nil {
780768
return err
781769
}
782770
}
783771
}
784772

785773
for addr := range diff.Nonces {
786-
if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil {
774+
if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil {
787775
return err
788776
}
789777
}
790778

791779
for addr := range diff.ReplacedClasses {
792-
if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil {
780+
if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil {
793781
return err
794782
}
795783
}
796784

797785
for addr := range diff.DeployedContracts {
798-
if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil {
786+
if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil {
799787
return err
800788
}
801789

802-
if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil {
790+
if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil {
803791
return err
804792
}
805793
}
806794

807-
return batch.Write()
795+
return nil
808796
}
809797

810798
func (s *State) compareContracts(a, b felt.Felt) int {

0 commit comments

Comments
 (0)