Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions blockchain/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ func (b *Blockchain) deprecatedStore(
}

state := core.NewState(txn)
if err := state.Update(block.Number, stateUpdate, newClasses, false, true); err != nil {
if err := state.Update(block.Number, stateUpdate, newClasses, false); err != nil {
return err
}
if err := core.WriteBlockHeader(txn, block.Header); err != nil {
Expand Down Expand Up @@ -404,11 +404,11 @@ func (b *Blockchain) store(
return err
}

st, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil)
st, err := b.StateFactory.NewState(stateUpdate.OldRoot, nil, batch)
if err != nil {
return err
}
if err := st.Update(block.Number, stateUpdate, newClasses, false, true); err != nil {
if err := st.Update(block.Number, stateUpdate, newClasses, false); err != nil {
return err
}

Expand Down Expand Up @@ -638,7 +638,7 @@ func (b *Blockchain) HeadState() (core.CommonStateReader, StateCloser, error) {
return nil, nil, err
}

state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn)
state, err := b.StateFactory.NewState(header.GlobalStateRoot, txn, nil)

return state, noopStateCloser, err
}
Expand Down Expand Up @@ -792,7 +792,7 @@ func (b *Blockchain) getReverseStateDiff() (core.StateDiff, error) {
if err != nil {
return ret, err
}
state, err := state.New(stateUpdate.NewRoot, b.stateDB)
state, err := state.New(stateUpdate.NewRoot, b.stateDB, nil)
if err != nil {
return ret, err
}
Expand Down Expand Up @@ -873,7 +873,7 @@ func (b *Blockchain) revertHead(batch db.Batch) error {
return err
}

state, err := state.New(stateUpdate.NewRoot, b.stateDB)
state, err := state.New(stateUpdate.NewRoot, b.stateDB, batch)
if err != nil {
return err
}
Expand Down Expand Up @@ -940,7 +940,7 @@ func (b *Blockchain) Simulate(
txn := b.database.NewIndexedBatch()
defer txn.Close()

if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, false); err != nil {
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
return SimulateResult{}, err
}

Expand Down Expand Up @@ -975,7 +975,7 @@ func (b *Blockchain) Finalise(
) error {
if !b.StateFactory.UseNewState() {
err := b.database.Update(func(txn db.IndexedBatch) error {
if err := b.updateStateRoots(txn, block, stateUpdate, newClasses, true); err != nil {
if err := b.updateStateRoots(txn, nil, block, stateUpdate, newClasses); err != nil {
return err
}
commitments, err := b.updateBlockHash(block, stateUpdate)
Expand Down Expand Up @@ -1011,7 +1011,7 @@ func (b *Blockchain) Finalise(
}

err := b.database.Write(func(batch db.Batch) error {
if err := b.updateStateRoots(nil, block, stateUpdate, newClasses, true); err != nil {
if err := b.updateStateRoots(nil, batch, block, stateUpdate, newClasses); err != nil {
return err
}
commitments, err := b.updateBlockHash(block, stateUpdate)
Expand Down Expand Up @@ -1047,10 +1047,10 @@ func (b *Blockchain) Finalise(
// updateStateRoots computes and updates state roots in the block and state update
func (b *Blockchain) updateStateRoots(
txn db.IndexedBatch,
batch db.Batch,
block *core.Block,
stateUpdate *core.StateUpdate,
newClasses map[felt.Felt]core.ClassDefinition,
flushChanges bool,
) error {
var height uint64
var err error
Expand All @@ -1069,7 +1069,7 @@ func (b *Blockchain) updateStateRoots(
stateRoot = &felt.Zero
}

state, err := b.StateFactory.NewState(stateRoot, txn)
state, err := b.StateFactory.NewState(stateRoot, txn, batch)
if err != nil {
return err
}
Expand All @@ -1082,7 +1082,7 @@ func (b *Blockchain) updateStateRoots(
stateUpdate.OldRoot = &oldStateRoot

// Apply state update
if err = state.Update(block.Number, stateUpdate, newClasses, true, flushChanges); err != nil {
if err = state.Update(block.Number, stateUpdate, newClasses, true); err != nil {
return err
}

Expand Down
1 change: 0 additions & 1 deletion core/common_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ type CommonState interface {
update *StateUpdate,
declaredClasses map[felt.Felt]ClassDefinition,
skipVerifyNewRoot bool,
flushChanges bool,
) error
Revert(blockNum uint64, update *StateUpdate) error
Commitment() (felt.Felt, error)
Expand Down
1 change: 0 additions & 1 deletion core/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ func (s *State) Update(
update *StateUpdate,
declaredClasses map[felt.Felt]ClassDefinition,
skipVerifyNewRoot bool,
flushChanges bool, // TODO(maksym): added to satisfy the interface, but not used
) error {
err := s.verifyStateUpdateRoot(update.OldRoot)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion core/state/history.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type stateHistory struct {
}

func NewStateHistory(blockNum uint64, stateRoot *felt.Felt, db *StateDB) (stateHistory, error) {
state, err := New(stateRoot, db)
state, err := New(stateRoot, db, nil)
if err != nil {
return stateHistory{}, err
}
Expand Down
14 changes: 8 additions & 6 deletions core/state/history_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,11 @@ func TestStateHistoryClassOperations(t *testing.T) {
NewRoot: &felt.Zero,
StateDiff: &core.StateDiff{},
}
state, err := New(&felt.Zero, stateDB)
require.NoError(t, err)
err = state.Update(0, stateUpdate, classes, false, true)
batch := stateDB.disk.NewBatch()
state, err := New(&felt.Zero, stateDB, batch)
require.NoError(t, err)
require.NoError(t, state.Update(0, stateUpdate, classes, false))
require.NoError(t, batch.Write())
stateComm, err := state.Commitment()
require.NoError(t, err)

Expand All @@ -147,10 +148,11 @@ func TestStateHistoryClassOperations(t *testing.T) {
class2Hash: class2,
}

state, err = New(&stateComm, stateDB)
require.NoError(t, err)
err = state.Update(1, stateUpdate, classes2, false, true)
batch = stateDB.disk.NewBatch()
state, err = New(&stateComm, stateDB, batch)
require.NoError(t, err)
require.NoError(t, state.Update(1, stateUpdate, classes2, false))
require.NoError(t, batch.Write())

historyBlock0, err := NewStateHistory(0, &felt.Zero, stateDB)
require.NoError(t, err)
Expand Down
134 changes: 61 additions & 73 deletions core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@
classTrie *trie2.Trie

stateObjects map[felt.Felt]*stateObject

batch db.Batch
}

func New(stateRoot *felt.Felt, db *StateDB) (*State, error) {
func New(stateRoot *felt.Felt, db *StateDB, batch db.Batch) (*State, error) {
contractTrie, err := db.ContractTrie(stateRoot)
if err != nil {
return nil, err
Expand All @@ -62,6 +64,7 @@
contractTrie: contractTrie,
classTrie: classTrie,
stateObjects: make(map[felt.Felt]*stateObject),
batch: batch,
}, nil
}

Expand Down Expand Up @@ -135,7 +138,7 @@

func (s *State) ContractStorageTrie(addr *felt.Felt) (core.CommonTrie, error) {
// todo: remove felt cast
return s.db.ContractStorageTrie((*felt.Felt)(&s.initRoot), addr)

Check failure on line 141 in core/state/state.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary conversion (unconvert)
}

func (s *State) CompiledClassHash(
Expand Down Expand Up @@ -179,7 +182,6 @@
update *core.StateUpdate,
declaredClasses map[felt.Felt]core.ClassDefinition,
skipVerifyNewRoot bool,
flushChanges bool,
) error {
if err := s.verifyComm(update.OldRoot); err != nil {
return err
Expand Down Expand Up @@ -248,10 +250,8 @@
deployedContracts: update.StateDiff.ReplacedClasses,
})

if flushChanges {
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, true); err != nil {
return err
}
if s.batch != nil {
return s.flush(blockNum, &stateUpdate, dirtyClasses, true)
}

return nil
Expand Down Expand Up @@ -318,19 +318,19 @@
if !newComm.Equal(update.OldRoot) {
return fmt.Errorf("state commitment mismatch: %v (expected) != %v (actual)", update.OldRoot, &newComm)
}

if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
return err
if s.batch != nil {
if err := s.flush(blockNum, &stateUpdate, dirtyClasses, false); err != nil {
return err
}
if err := s.deleteHistory(blockNum, update.StateDiff); err != nil {
return err
}
}

if err := s.db.stateCache.PopLayer(update.NewRoot, update.OldRoot); err != nil {
return err
}

if err := s.deleteHistory(blockNum, update.StateDiff); err != nil {
return err
}

return nil
}

Expand Down Expand Up @@ -484,7 +484,7 @@

su := stateUpdate{
// todo: remove felt cast
prevComm: felt.Felt(s.initRoot),

Check failure on line 487 in core/state/state.go

View workflow job for this annotation

GitHub Actions / lint

unnecessary conversion (unconvert)
curComm: newComm,
contractNodes: mergedContractNodes,
}
Expand All @@ -503,74 +503,64 @@
classes map[felt.Felt]core.ClassDefinition,
storeHistory bool,
) error {
p := pool.New().WithMaxGoroutines(runtime.GOMAXPROCS(0)).WithErrors()

p.Go(func() error {
return s.db.triedb.Update(
(*felt.StateRootHash)(&update.curComm),
(*felt.StateRootHash)(&update.prevComm),
blockNum,
update.classNodes,
update.contractNodes,
)
})

batch := s.db.disk.NewBatch()
p.Go(func() error {
for addr, obj := range s.stateObjects {
if obj == nil { // marked as deleted
if err := DeleteContract(batch, &addr); err != nil {
return err
}

// TODO(weiihann): handle hash-based, and there should be better ways of doing this
err := trieutils.DeleteStorageNodesByPath(batch, (*felt.Address)(&addr))
if err != nil {
return err
}
} else { // updated
if err := WriteContract(batch, &addr, obj.contract); err != nil {
return err
}
if err := s.db.triedb.Update(
(*felt.StateRootHash)(&update.curComm),
(*felt.StateRootHash)(&update.prevComm),
blockNum,
update.classNodes,
update.contractNodes,
s.batch,
); err != nil {
return err
}

if storeHistory {
for key, val := range obj.dirtyStorage {
if err := WriteStorageHistory(batch, &addr, &key, blockNum, val); err != nil {
return err
}
}
for addr, obj := range s.stateObjects {
if obj == nil { // marked as deleted
if err := DeleteContract(s.batch, &addr); err != nil {
return err
}

if err := WriteNonceHistory(batch, &addr, blockNum, &obj.contract.Nonce); err != nil {
return err
}
// TODO(weiihann): handle hash-based, and there should be better ways of doing this
err := trieutils.DeleteStorageNodesByPath(s.batch, (*felt.Address)(&addr))
if err != nil {
return err
}
} else { // updated
if err := WriteContract(s.batch, &addr, obj.contract); err != nil {
return err
}

if err := WriteClassHashHistory(batch, &addr, blockNum, &obj.contract.ClassHash); err != nil {
if storeHistory {
for key, val := range obj.dirtyStorage {
if err := WriteStorageHistory(s.batch, &addr, &key, blockNum, val); err != nil {
return err
}
}
}
}

for classHash, class := range classes {
if class == nil { // mark as deleted
if err := DeleteClass(batch, &classHash); err != nil {
if err := WriteNonceHistory(s.batch, &addr, blockNum, &obj.contract.Nonce); err != nil {
return err
}
} else {
if err := WriteClass(batch, &classHash, class, blockNum); err != nil {

if err := WriteClassHashHistory(s.batch, &addr, blockNum, &obj.contract.ClassHash); err != nil {
return err
}
}
}
}

return nil
})

if err := p.Wait(); err != nil {
return err
for classHash, class := range classes {
if class == nil { // mark as deleted
if err := DeleteClass(s.batch, &classHash); err != nil {
return err
}
} else {
if err := WriteClass(s.batch, &classHash, class, blockNum); err != nil {
return err
}
}
}

return batch.Write()
return nil
}

func (s *State) updateClassTrie(
Expand Down Expand Up @@ -772,39 +762,37 @@
}

func (s *State) deleteHistory(blockNum uint64, diff *core.StateDiff) error {
batch := s.db.disk.NewBatch()

for addr, storage := range diff.StorageDiffs {
for key := range storage {
if err := DeleteStorageHistory(batch, &addr, &key, blockNum); err != nil {
if err := DeleteStorageHistory(s.batch, &addr, &key, blockNum); err != nil {
return err
}
}
}

for addr := range diff.Nonces {
if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil {
if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil {
return err
}
}

for addr := range diff.ReplacedClasses {
if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil {
if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil {
return err
}
}

for addr := range diff.DeployedContracts {
if err := DeleteNonceHistory(batch, &addr, blockNum); err != nil {
if err := DeleteNonceHistory(s.batch, &addr, blockNum); err != nil {
return err
}

if err := DeleteClassHashHistory(batch, &addr, blockNum); err != nil {
if err := DeleteClassHashHistory(s.batch, &addr, blockNum); err != nil {
return err
}
}

return batch.Write()
return nil
}

func (s *State) compareContracts(a, b felt.Felt) int {
Expand Down
Loading
Loading