Skip to content
2 changes: 1 addition & 1 deletion lib/lambda_ethereum_consensus/fork_choice/handlers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ defmodule LambdaEthereumConsensus.ForkChoice.Handlers do
end)

with {:ok, new_state_info} <-
StateTransition.verified_transition(state_info.beacon_state, block_info),
StateTransition.verified_transition(state_info, block_info),
{:ok, _execution_status} <- Task.await(payload_verification_task) do
seconds_per_slot = ChainSpec.get("SECONDS_PER_SLOT")
intervals_per_slot = Constants.intervals_per_slot()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,23 @@ defmodule LambdaEthereumConsensus.StateTransition do

import LambdaEthereumConsensus.Utils, only: [map_ok: 2]

@spec verified_transition(BeaconState.t(), BlockInfo.t()) ::
@spec verified_transition(StateInfo.t() | BeaconState.t(), BlockInfo.t()) ::
{:ok, StateInfo.t()} | {:error, String.t()}
def verified_transition(beacon_state, block_info) do
beacon_state
|> transition(block_info.signed_block)
def verified_transition(%StateInfo{} = state_info, block_info) do
previous_roots = %{
# We store the roots indexed by slot number to ensure slot matches when reusing them.
state_info.beacon_state.slot => %{
state_root: state_info.root,
block_root: state_info.block_root
}
}

verified_transition(state_info.beacon_state, block_info, previous_roots)
end

def verified_transition(%BeaconState{} = state, block_info, previous_roots \\ %{}) do
state
|> transition(block_info.signed_block, previous_roots)
# Verify signature
|> map_ok(fn st ->
if block_signature_valid?(st, block_info.signed_block) do
Expand All @@ -31,10 +43,10 @@ defmodule LambdaEthereumConsensus.StateTransition do
end
end)
|> map_ok(fn new_state ->
with {:ok, state_info} <-
with {:ok, new_state_info} <-
StateInfo.from_beacon_state(new_state, block_root: block_info.root) do
if block_info.signed_block.message.state_root == state_info.root do
{:ok, state_info}
if block_info.signed_block.message.state_root == new_state_info.root do
{:ok, new_state_info}
else
{:error, "mismatched state roots"}
end
Expand All @@ -43,25 +55,27 @@ defmodule LambdaEthereumConsensus.StateTransition do
end

@spec transition(BeaconState.t(), SignedBeaconBlock.t()) :: {:ok, BeaconState.t()}
def transition(beacon_state, signed_block) do
def transition(beacon_state, signed_block, previous_roots \\ %{}) do
block = signed_block.message

beacon_state
# Process slots (including those with no blocks) since block
|> process_slots(block.slot)
|> process_slots(block.slot, previous_roots)
# Process block
|> map_ok(&process_block(&1, block))
end

def process_slots(%BeaconState{slot: old_slot}, slot) when old_slot >= slot,
def process_slots(state, slot, previous_roots \\ %{})

def process_slots(%BeaconState{slot: old_slot}, slot, _previous_roots) when old_slot >= slot,
do: {:error, "slot is older than state"}

def process_slots(%BeaconState{slot: old_slot} = state, slot) do
def process_slots(%BeaconState{slot: old_slot} = state, slot, previous_roots) do
slots_per_epoch = ChainSpec.get("SLOTS_PER_EPOCH")

Enum.reduce((old_slot + 1)..slot//1, {:ok, state}, fn next_slot, acc ->
acc
|> map_ok(&process_slot/1)
|> map_ok(&apply_process_slot(&1, previous_roots))
# Process epoch on the start slot of the next epoch
|> map_ok(&maybe_process_epoch(&1, rem(next_slot, slots_per_epoch)))
|> map_ok(&{:ok, %BeaconState{&1 | slot: next_slot}})
Expand All @@ -71,11 +85,28 @@ defmodule LambdaEthereumConsensus.StateTransition do
defp maybe_process_epoch(%BeaconState{} = state, 0), do: process_epoch(state)
defp maybe_process_epoch(%BeaconState{} = state, _slot_in_epoch), do: {:ok, state}

defp process_slot(%BeaconState{} = state) do
defp apply_process_slot(state, previous_roots) do
Metrics.span_operation(:process_slot, nil, nil, fn -> process_slot(state, previous_roots) end)
end

defp process_slot(%BeaconState{} = state, previous_roots) do
start_time = System.monotonic_time(:millisecond)

slot_previous_roots = Map.get(previous_roots, state.slot, nil)

# Cache state root
previous_state_root = Ssz.hash_tree_root!(state)
previous_state_root =
if slot_previous_roots do
Logger.debug("Slot #{state.slot}: previous state root in cache",
root: slot_previous_roots.state_root
)

slot_previous_roots.state_root
else
Logger.warning("Slot #{state.slot}: no previous state root in cache")
Ssz.hash_tree_root!(state)
end

slots_per_historical_root = ChainSpec.get("SLOTS_PER_HISTORICAL_ROOT")
cache_index = rem(state.slot, slots_per_historical_root)
roots = List.replace_at(state.state_roots, cache_index, previous_state_root)
Expand All @@ -95,7 +126,18 @@ defmodule LambdaEthereumConsensus.StateTransition do
end

# Cache block root
previous_block_root = Ssz.hash_tree_root!(state.latest_block_header)
previous_block_root =
if slot_previous_roots do
Logger.debug("Slot #{state.slot}, previous block root in cache",
root: slot_previous_roots.block_root
)

slot_previous_roots.block_root
else
Logger.warning("Slot #{state.slot}, no previous block root in cache")
Ssz.hash_tree_root!(state.latest_block_header)
end

roots = List.replace_at(state.block_roots, cache_index, previous_block_root)

end_time = System.monotonic_time(:millisecond)
Expand Down
Loading