Skip to content

Commit 0cf6e65

Browse files
perf: state root cache for slot n+1 (#1378)
Co-authored-by: Esteban Dimitroff Hodi <[email protected]>
1 parent ddc0d74 commit 0cf6e65

File tree

2 files changed

+58
-16
lines changed

2 files changed

+58
-16
lines changed

lib/lambda_ethereum_consensus/fork_choice/handlers.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ defmodule LambdaEthereumConsensus.ForkChoice.Handlers do
220220
end)
221221

222222
with {:ok, new_state_info} <-
223-
StateTransition.verified_transition(state_info.beacon_state, block_info),
223+
StateTransition.verified_transition(state_info, block_info),
224224
{:ok, _execution_status} <- Task.await(payload_verification_task) do
225225
seconds_per_slot = ChainSpec.get("SECONDS_PER_SLOT")
226226
intervals_per_slot = Constants.intervals_per_slot()

lib/lambda_ethereum_consensus/state_transition/state_transition.ex

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,23 @@ defmodule LambdaEthereumConsensus.StateTransition do
1717

1818
import LambdaEthereumConsensus.Utils, only: [map_ok: 2]
1919

20-
@spec verified_transition(BeaconState.t(), BlockInfo.t()) ::
20+
@spec verified_transition(StateInfo.t() | BeaconState.t(), BlockInfo.t()) ::
2121
{:ok, StateInfo.t()} | {:error, String.t()}
22-
def verified_transition(beacon_state, block_info) do
23-
beacon_state
24-
|> transition(block_info.signed_block)
22+
def verified_transition(%StateInfo{} = state_info, block_info) do
23+
previous_roots = %{
24+
# We store the roots indexed by slot number to ensure slot matches when reusing them.
25+
state_info.beacon_state.slot => %{
26+
state_root: state_info.root,
27+
block_root: state_info.block_root
28+
}
29+
}
30+
31+
verified_transition(state_info.beacon_state, block_info, previous_roots)
32+
end
33+
34+
def verified_transition(%BeaconState{} = state, block_info, previous_roots \\ %{}) do
35+
state
36+
|> transition(block_info.signed_block, previous_roots)
2537
# Verify signature
2638
|> map_ok(fn st ->
2739
if block_signature_valid?(st, block_info.signed_block) do
@@ -31,10 +43,10 @@ defmodule LambdaEthereumConsensus.StateTransition do
3143
end
3244
end)
3345
|> map_ok(fn new_state ->
34-
with {:ok, state_info} <-
46+
with {:ok, new_state_info} <-
3547
StateInfo.from_beacon_state(new_state, block_root: block_info.root) do
36-
if block_info.signed_block.message.state_root == state_info.root do
37-
{:ok, state_info}
48+
if block_info.signed_block.message.state_root == new_state_info.root do
49+
{:ok, new_state_info}
3850
else
3951
{:error, "mismatched state roots"}
4052
end
@@ -43,25 +55,27 @@ defmodule LambdaEthereumConsensus.StateTransition do
4355
end
4456

4557
@spec transition(BeaconState.t(), SignedBeaconBlock.t()) :: {:ok, BeaconState.t()}
46-
def transition(beacon_state, signed_block) do
58+
def transition(beacon_state, signed_block, previous_roots \\ %{}) do
4759
block = signed_block.message
4860

4961
beacon_state
5062
# Process slots (including those with no blocks) since block
51-
|> process_slots(block.slot)
63+
|> process_slots(block.slot, previous_roots)
5264
# Process block
5365
|> map_ok(&process_block(&1, block))
5466
end
5567

56-
def process_slots(%BeaconState{slot: old_slot}, slot) when old_slot >= slot,
68+
def process_slots(state, slot, previous_roots \\ %{})
69+
70+
def process_slots(%BeaconState{slot: old_slot}, slot, _previous_roots) when old_slot >= slot,
5771
do: {:error, "slot is older than state"}
5872

59-
def process_slots(%BeaconState{slot: old_slot} = state, slot) do
73+
def process_slots(%BeaconState{slot: old_slot} = state, slot, previous_roots) do
6074
slots_per_epoch = ChainSpec.get("SLOTS_PER_EPOCH")
6175

6276
Enum.reduce((old_slot + 1)..slot//1, {:ok, state}, fn next_slot, acc ->
6377
acc
64-
|> map_ok(&process_slot/1)
78+
|> map_ok(&apply_process_slot(&1, previous_roots))
6579
# Process epoch on the start slot of the next epoch
6680
|> map_ok(&maybe_process_epoch(&1, rem(next_slot, slots_per_epoch)))
6781
|> map_ok(&{:ok, %BeaconState{&1 | slot: next_slot}})
@@ -71,11 +85,28 @@ defmodule LambdaEthereumConsensus.StateTransition do
7185
defp maybe_process_epoch(%BeaconState{} = state, 0), do: process_epoch(state)
7286
defp maybe_process_epoch(%BeaconState{} = state, _slot_in_epoch), do: {:ok, state}
7387

74-
defp process_slot(%BeaconState{} = state) do
88+
defp apply_process_slot(state, previous_roots) do
89+
Metrics.span_operation(:process_slot, nil, nil, fn -> process_slot(state, previous_roots) end)
90+
end
91+
92+
defp process_slot(%BeaconState{} = state, previous_roots) do
7593
start_time = System.monotonic_time(:millisecond)
7694

95+
slot_previous_roots = Map.get(previous_roots, state.slot, nil)
96+
7797
# Cache state root
78-
previous_state_root = Ssz.hash_tree_root!(state)
98+
previous_state_root =
99+
if slot_previous_roots do
100+
Logger.debug("Slot #{state.slot}: previous state root in cache",
101+
root: slot_previous_roots.state_root
102+
)
103+
104+
slot_previous_roots.state_root
105+
else
106+
Logger.warning("Slot #{state.slot}: no previous state root in cache")
107+
Ssz.hash_tree_root!(state)
108+
end
109+
79110
slots_per_historical_root = ChainSpec.get("SLOTS_PER_HISTORICAL_ROOT")
80111
cache_index = rem(state.slot, slots_per_historical_root)
81112
roots = List.replace_at(state.state_roots, cache_index, previous_state_root)
@@ -95,7 +126,18 @@ defmodule LambdaEthereumConsensus.StateTransition do
95126
end
96127

97128
# Cache block root
98-
previous_block_root = Ssz.hash_tree_root!(state.latest_block_header)
129+
previous_block_root =
130+
if slot_previous_roots do
131+
Logger.debug("Slot #{state.slot}, previous block root in cache",
132+
root: slot_previous_roots.block_root
133+
)
134+
135+
slot_previous_roots.block_root
136+
else
137+
Logger.warning("Slot #{state.slot}, no previous block root in cache")
138+
Ssz.hash_tree_root!(state.latest_block_header)
139+
end
140+
99141
roots = List.replace_at(state.block_roots, cache_index, previous_block_root)
100142

101143
end_time = System.monotonic_time(:millisecond)

0 commit comments

Comments
 (0)