diff --git a/lib/lambda_ethereum_consensus/fork_choice/handlers.ex b/lib/lambda_ethereum_consensus/fork_choice/handlers.ex index df12d606f..86f29038b 100644 --- a/lib/lambda_ethereum_consensus/fork_choice/handlers.ex +++ b/lib/lambda_ethereum_consensus/fork_choice/handlers.ex @@ -220,7 +220,7 @@ defmodule LambdaEthereumConsensus.ForkChoice.Handlers do end) with {:ok, new_state_info} <- - StateTransition.verified_transition(state_info.beacon_state, block_info), + StateTransition.verified_transition(state_info, block_info), {:ok, _execution_status} <- Task.await(payload_verification_task) do seconds_per_slot = ChainSpec.get("SECONDS_PER_SLOT") intervals_per_slot = Constants.intervals_per_slot() diff --git a/lib/lambda_ethereum_consensus/state_transition/state_transition.ex b/lib/lambda_ethereum_consensus/state_transition/state_transition.ex index 486bbb190..c20fdf891 100644 --- a/lib/lambda_ethereum_consensus/state_transition/state_transition.ex +++ b/lib/lambda_ethereum_consensus/state_transition/state_transition.ex @@ -17,11 +17,23 @@ defmodule LambdaEthereumConsensus.StateTransition do import LambdaEthereumConsensus.Utils, only: [map_ok: 2] - @spec verified_transition(BeaconState.t(), BlockInfo.t()) :: + @spec verified_transition(StateInfo.t() | BeaconState.t(), BlockInfo.t()) :: {:ok, StateInfo.t()} | {:error, String.t()} - def verified_transition(beacon_state, block_info) do - beacon_state - |> transition(block_info.signed_block) + def verified_transition(%StateInfo{} = state_info, block_info) do + previous_roots = %{ + # We store the roots indexed by slot number to ensure slot matches when reusing them. + state_info.beacon_state.slot => %{ + state_root: state_info.root, + block_root: state_info.block_root + } + } + + verified_transition(state_info.beacon_state, block_info, previous_roots) + end + + def verified_transition(%BeaconState{} = state, block_info, previous_roots \\ %{}) do + state + |> transition(block_info.signed_block, previous_roots) # Verify signature |> map_ok(fn st -> if block_signature_valid?(st, block_info.signed_block) do @@ -31,10 +43,10 @@ defmodule LambdaEthereumConsensus.StateTransition do end end) |> map_ok(fn new_state -> - with {:ok, state_info} <- + with {:ok, new_state_info} <- StateInfo.from_beacon_state(new_state, block_root: block_info.root) do - if block_info.signed_block.message.state_root == state_info.root do - {:ok, state_info} + if block_info.signed_block.message.state_root == new_state_info.root do + {:ok, new_state_info} else {:error, "mismatched state roots"} end @@ -43,25 +55,27 @@ defmodule LambdaEthereumConsensus.StateTransition do end @spec transition(BeaconState.t(), SignedBeaconBlock.t()) :: {:ok, BeaconState.t()} - def transition(beacon_state, signed_block) do + def transition(beacon_state, signed_block, previous_roots \\ %{}) do block = signed_block.message beacon_state # Process slots (including those with no blocks) since block - |> process_slots(block.slot) + |> process_slots(block.slot, previous_roots) # Process block |> map_ok(&process_block(&1, block)) end - def process_slots(%BeaconState{slot: old_slot}, slot) when old_slot >= slot, + def process_slots(state, slot, previous_roots \\ %{}) + + def process_slots(%BeaconState{slot: old_slot}, slot, _previous_roots) when old_slot >= slot, do: {:error, "slot is older than state"} - def process_slots(%BeaconState{slot: old_slot} = state, slot) do + def process_slots(%BeaconState{slot: old_slot} = state, slot, previous_roots) do slots_per_epoch = ChainSpec.get("SLOTS_PER_EPOCH") Enum.reduce((old_slot + 1)..slot//1, {:ok, state}, fn next_slot, acc -> acc - |> map_ok(&process_slot/1) + |> map_ok(&apply_process_slot(&1, previous_roots)) # Process epoch on the start slot of the next epoch |> map_ok(&maybe_process_epoch(&1, rem(next_slot, slots_per_epoch))) |> map_ok(&{:ok, %BeaconState{&1 | slot: next_slot}}) @@ -71,11 +85,28 @@ defmodule LambdaEthereumConsensus.StateTransition do defp maybe_process_epoch(%BeaconState{} = state, 0), do: process_epoch(state) defp maybe_process_epoch(%BeaconState{} = state, _slot_in_epoch), do: {:ok, state} - defp process_slot(%BeaconState{} = state) do + defp apply_process_slot(state, previous_roots) do + Metrics.span_operation(:process_slot, nil, nil, fn -> process_slot(state, previous_roots) end) + end + + defp process_slot(%BeaconState{} = state, previous_roots) do start_time = System.monotonic_time(:millisecond) + slot_previous_roots = Map.get(previous_roots, state.slot, nil) + # Cache state root - previous_state_root = Ssz.hash_tree_root!(state) + previous_state_root = + if slot_previous_roots do + Logger.debug("Slot #{state.slot}: previous state root in cache", + root: slot_previous_roots.state_root + ) + + slot_previous_roots.state_root + else + Logger.warning("Slot #{state.slot}: no previous state root in cache") + Ssz.hash_tree_root!(state) + end + slots_per_historical_root = ChainSpec.get("SLOTS_PER_HISTORICAL_ROOT") cache_index = rem(state.slot, slots_per_historical_root) roots = List.replace_at(state.state_roots, cache_index, previous_state_root) @@ -95,7 +126,18 @@ defmodule LambdaEthereumConsensus.StateTransition do end # Cache block root - previous_block_root = Ssz.hash_tree_root!(state.latest_block_header) + previous_block_root = + if slot_previous_roots do + Logger.debug("Slot #{state.slot}, previous block root in cache", + root: slot_previous_roots.block_root + ) + + slot_previous_roots.block_root + else + Logger.warning("Slot #{state.slot}, no previous block root in cache") + Ssz.hash_tree_root!(state.latest_block_header) + end + roots = List.replace_at(state.block_roots, cache_index, previous_block_root) end_time = System.monotonic_time(:millisecond)