Skip to content

Commit e0a6f54

Browse files
authored
perf(trie): add HashedPostStateSorted::from_reverts (#20047)
1 parent 98e9a1d commit e0a6f54

File tree

11 files changed

+437
-124
lines changed

11 files changed

+437
-124
lines changed

crates/chain-state/src/deferred_trie.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ impl DeferredTrieData {
155155
ancestors: &[Self],
156156
) -> ComputedTrieData {
157157
// Sort the current block's hashed state and trie updates
158-
let sorted_hashed_state = Arc::new(hashed_state.clone().into_sorted());
158+
let sorted_hashed_state = Arc::new(hashed_state.clone_into_sorted());
159159
let sorted_trie_updates = Arc::new(trie_updates.clone().into_sorted());
160160

161161
// Merge trie data from ancestors (oldest -> newest so later state takes precedence)

crates/engine/tree/src/tree/payload_validator.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ where
647647

648648
// Extend state overlay with current block's sorted state.
649649
input.prefix_sets.extend(hashed_state.construct_prefix_sets());
650-
let sorted_hashed_state = hashed_state.clone().into_sorted();
650+
let sorted_hashed_state = hashed_state.clone_into_sorted();
651651
Arc::make_mut(&mut input.state).extend_ref(&sorted_hashed_state);
652652

653653
let TrieInputSorted { nodes, state, prefix_sets: prefix_sets_mut } = input;

crates/stages/stages/src/stages/merkle_changesets.rs

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ use reth_stages_api::{
1212
BlockErrorKind, ExecInput, ExecOutput, Stage, StageCheckpoint, StageError, StageId,
1313
UnwindInput, UnwindOutput,
1414
};
15-
use reth_trie::{updates::TrieUpdates, HashedPostState, KeccakKeyHasher, StateRoot, TrieInput};
15+
use reth_trie::{
16+
updates::TrieUpdates, HashedPostStateSorted, KeccakKeyHasher, StateRoot, TrieInputSorted,
17+
};
1618
use reth_trie_db::{DatabaseHashedPostState, DatabaseStateRoot};
17-
use std::ops::Range;
19+
use std::{ops::Range, sync::Arc};
1820
use tracing::{debug, error};
1921

2022
/// The `MerkleChangeSets` stage.
@@ -105,12 +107,12 @@ impl MerkleChangeSets {
105107
Ok(target_start..target_end)
106108
}
107109

108-
/// Calculates the trie updates given a [`TrieInput`], asserting that the resulting state root
109-
/// matches the expected one for the block.
110+
/// Calculates the trie updates given a [`TrieInputSorted`], asserting that the resulting state
111+
/// root matches the expected one for the block.
110112
fn calculate_block_trie_updates<Provider: DBProvider + HeaderProvider>(
111113
provider: &Provider,
112114
block_number: BlockNumber,
113-
input: TrieInput,
115+
input: TrieInputSorted,
114116
) -> Result<TrieUpdates, StageError> {
115117
let (root, trie_updates) =
116118
StateRoot::overlay_root_from_nodes_with_updates(provider.tx_ref(), input).map_err(
@@ -192,21 +194,21 @@ impl MerkleChangeSets {
192194
);
193195
let mut per_block_state_reverts = Vec::new();
194196
for block_number in target_range.clone() {
195-
per_block_state_reverts.push(HashedPostState::from_reverts::<KeccakKeyHasher>(
197+
per_block_state_reverts.push(HashedPostStateSorted::from_reverts::<KeccakKeyHasher>(
196198
provider.tx_ref(),
197199
block_number..=block_number,
198200
)?);
199201
}
200202

201203
// Helper to retrieve state revert data for a specific block from the pre-computed array
202-
let get_block_state_revert = |block_number: BlockNumber| -> &HashedPostState {
204+
let get_block_state_revert = |block_number: BlockNumber| -> &HashedPostStateSorted {
203205
let index = (block_number - target_start) as usize;
204206
&per_block_state_reverts[index]
205207
};
206208

207209
// Helper to accumulate state reverts from a given block to the target end
208-
let compute_cumulative_state_revert = |block_number: BlockNumber| -> HashedPostState {
209-
let mut cumulative_revert = HashedPostState::default();
210+
let compute_cumulative_state_revert = |block_number: BlockNumber| -> HashedPostStateSorted {
211+
let mut cumulative_revert = HashedPostStateSorted::default();
210212
for n in (block_number..target_end).rev() {
211213
cumulative_revert.extend_ref(get_block_state_revert(n))
212214
}
@@ -216,7 +218,7 @@ impl MerkleChangeSets {
216218
// To calculate the changeset for a block, we first need the TrieUpdates which are
217219
// generated as a result of processing the block. To get these we need:
218220
// 1) The TrieUpdates which revert the db's trie to _prior_ to the block
219-
// 2) The HashedPostState to revert the db's state to _after_ the block
221+
// 2) The HashedPostStateSorted to revert the db's state to _after_ the block
220222
//
221223
// To get (1) for `target_start` we need to do a big state root calculation which takes
222224
// into account all changes between that block and db tip. For each block after the
@@ -227,12 +229,15 @@ impl MerkleChangeSets {
227229
?target_start,
228230
"Computing trie state at starting block",
229231
);
230-
let mut input = TrieInput::default();
231-
input.state = compute_cumulative_state_revert(target_start);
232-
input.prefix_sets = input.state.construct_prefix_sets();
232+
let initial_state = compute_cumulative_state_revert(target_start);
233+
let initial_prefix_sets = initial_state.construct_prefix_sets();
234+
let initial_input =
235+
TrieInputSorted::new(Arc::default(), Arc::new(initial_state), initial_prefix_sets);
233236
// target_start will be >= 1, see `determine_target_range`.
234-
input.nodes =
235-
Self::calculate_block_trie_updates(provider, target_start - 1, input.clone())?;
237+
let mut nodes = Arc::new(
238+
Self::calculate_block_trie_updates(provider, target_start - 1, initial_input)?
239+
.into_sorted(),
240+
);
236241

237242
for block_number in target_range {
238243
debug!(
@@ -242,21 +247,24 @@ impl MerkleChangeSets {
242247
);
243248
// Revert the state so that this block has been just processed, meaning we take the
244249
// cumulative revert of the subsequent block.
245-
input.state = compute_cumulative_state_revert(block_number + 1);
250+
let state = Arc::new(compute_cumulative_state_revert(block_number + 1));
251+
252+
// Construct prefix sets from only this block's `HashedPostStateSorted`, because we only
253+
// care about trie updates which occurred as a result of this block being processed.
254+
let prefix_sets = get_block_state_revert(block_number).construct_prefix_sets();
246255

247-
// Construct prefix sets from only this block's `HashedPostState`, because we only care
248-
// about trie updates which occurred as a result of this block being processed.
249-
input.prefix_sets = get_block_state_revert(block_number).construct_prefix_sets();
256+
let input = TrieInputSorted::new(Arc::clone(&nodes), state, prefix_sets);
250257

251258
// Calculate the trie updates for this block, then apply those updates to the reverts.
252259
// We calculate the overlay which will be passed into the next step using the trie
253260
// reverts prior to them being updated.
254261
let this_trie_updates =
255-
Self::calculate_block_trie_updates(provider, block_number, input.clone())?;
262+
Self::calculate_block_trie_updates(provider, block_number, input)?.into_sorted();
256263

257-
let trie_overlay = input.nodes.clone().into_sorted();
258-
input.nodes.extend_ref(&this_trie_updates);
259-
let this_trie_updates = this_trie_updates.into_sorted();
264+
let trie_overlay = Arc::clone(&nodes);
265+
let mut nodes_mut = Arc::unwrap_or_clone(nodes);
266+
nodes_mut.extend_ref(&this_trie_updates);
267+
nodes = Arc::new(nodes_mut);
260268

261269
// Write the changesets to the DB using the trie updates produced by the block, and the
262270
// trie reverts as the overlay.

crates/storage/provider/src/providers/state/historical.rs

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ use reth_trie::{
2121
proof::{Proof, StorageProof},
2222
updates::TrieUpdates,
2323
witness::TrieWitness,
24-
AccountProof, HashedPostState, HashedStorage, KeccakKeyHasher, MultiProof, MultiProofTargets,
25-
StateRoot, StorageMultiProof, StorageRoot, TrieInput,
24+
AccountProof, HashedPostState, HashedPostStateSorted, HashedStorage, KeccakKeyHasher,
25+
MultiProof, MultiProofTargets, StateRoot, StorageMultiProof, StorageRoot, TrieInput,
26+
TrieInputSorted,
2627
};
2728
use reth_trie_db::{
2829
DatabaseHashedPostState, DatabaseHashedStorage, DatabaseProof, DatabaseStateRoot,
@@ -118,7 +119,7 @@ impl<'b, Provider: DBProvider + BlockNumReader> HistoricalStateProviderRef<'b, P
118119
}
119120

120121
/// Retrieve revert hashed state for this history provider.
121-
fn revert_state(&self) -> ProviderResult<HashedPostState> {
122+
fn revert_state(&self) -> ProviderResult<HashedPostStateSorted> {
122123
if !self.lowest_available_blocks.is_account_history_available(self.block_number) ||
123124
!self.lowest_available_blocks.is_storage_history_available(self.block_number)
124125
{
@@ -133,7 +134,8 @@ impl<'b, Provider: DBProvider + BlockNumReader> HistoricalStateProviderRef<'b, P
133134
);
134135
}
135136

136-
Ok(HashedPostState::from_reverts::<KeccakKeyHasher>(self.tx(), self.block_number..)?)
137+
HashedPostStateSorted::from_reverts::<KeccakKeyHasher>(self.tx(), self.block_number..)
138+
.map_err(ProviderError::from)
137139
}
138140

139141
/// Retrieve revert hashed storage for this history provider and target address.
@@ -287,14 +289,15 @@ impl<Provider: DBProvider + BlockNumReader> StateRootProvider
287289
{
288290
fn state_root(&self, hashed_state: HashedPostState) -> ProviderResult<B256> {
289291
let mut revert_state = self.revert_state()?;
290-
revert_state.extend(hashed_state);
291-
StateRoot::overlay_root(self.tx(), revert_state)
292+
let hashed_state_sorted = hashed_state.into_sorted();
293+
revert_state.extend_ref(&hashed_state_sorted);
294+
StateRoot::overlay_root(self.tx(), &revert_state)
292295
.map_err(|err| ProviderError::Database(err.into()))
293296
}
294297

295298
fn state_root_from_nodes(&self, mut input: TrieInput) -> ProviderResult<B256> {
296-
input.prepend(self.revert_state()?);
297-
StateRoot::overlay_root_from_nodes(self.tx(), input)
299+
input.prepend(self.revert_state()?.into());
300+
StateRoot::overlay_root_from_nodes(self.tx(), TrieInputSorted::from_unsorted(input))
298301
.map_err(|err| ProviderError::Database(err.into()))
299302
}
300303

@@ -303,18 +306,22 @@ impl<Provider: DBProvider + BlockNumReader> StateRootProvider
303306
hashed_state: HashedPostState,
304307
) -> ProviderResult<(B256, TrieUpdates)> {
305308
let mut revert_state = self.revert_state()?;
306-
revert_state.extend(hashed_state);
307-
StateRoot::overlay_root_with_updates(self.tx(), revert_state)
309+
let hashed_state_sorted = hashed_state.into_sorted();
310+
revert_state.extend_ref(&hashed_state_sorted);
311+
StateRoot::overlay_root_with_updates(self.tx(), &revert_state)
308312
.map_err(|err| ProviderError::Database(err.into()))
309313
}
310314

311315
fn state_root_from_nodes_with_updates(
312316
&self,
313317
mut input: TrieInput,
314318
) -> ProviderResult<(B256, TrieUpdates)> {
315-
input.prepend(self.revert_state()?);
316-
StateRoot::overlay_root_from_nodes_with_updates(self.tx(), input)
317-
.map_err(|err| ProviderError::Database(err.into()))
319+
input.prepend(self.revert_state()?.into());
320+
StateRoot::overlay_root_from_nodes_with_updates(
321+
self.tx(),
322+
TrieInputSorted::from_unsorted(input),
323+
)
324+
.map_err(|err| ProviderError::Database(err.into()))
318325
}
319326
}
320327

@@ -367,7 +374,7 @@ impl<Provider: DBProvider + BlockNumReader> StateProofProvider
367374
address: Address,
368375
slots: &[B256],
369376
) -> ProviderResult<AccountProof> {
370-
input.prepend(self.revert_state()?);
377+
input.prepend(self.revert_state()?.into());
371378
let proof = <Proof<_, _> as DatabaseProof>::from_tx(self.tx());
372379
proof.overlay_account_proof(input, address, slots).map_err(ProviderError::from)
373380
}
@@ -377,13 +384,13 @@ impl<Provider: DBProvider + BlockNumReader> StateProofProvider
377384
mut input: TrieInput,
378385
targets: MultiProofTargets,
379386
) -> ProviderResult<MultiProof> {
380-
input.prepend(self.revert_state()?);
387+
input.prepend(self.revert_state()?.into());
381388
let proof = <Proof<_, _> as DatabaseProof>::from_tx(self.tx());
382389
proof.overlay_multiproof(input, targets).map_err(ProviderError::from)
383390
}
384391

385392
fn witness(&self, mut input: TrieInput, target: HashedPostState) -> ProviderResult<Vec<Bytes>> {
386-
input.prepend(self.revert_state()?);
393+
input.prepend(self.revert_state()?.into());
387394
TrieWitness::overlay_witness(self.tx(), input, target)
388395
.map_err(ProviderError::from)
389396
.map(|hm| hm.into_values().collect())

crates/storage/provider/src/providers/state/latest.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use reth_trie::{
1212
updates::TrieUpdates,
1313
witness::TrieWitness,
1414
AccountProof, HashedPostState, HashedStorage, KeccakKeyHasher, MultiProof, MultiProofTargets,
15-
StateRoot, StorageMultiProof, StorageRoot, TrieInput,
15+
StateRoot, StorageMultiProof, StorageRoot, TrieInput, TrieInputSorted,
1616
};
1717
use reth_trie_db::{
1818
DatabaseProof, DatabaseStateRoot, DatabaseStorageProof, DatabaseStorageRoot,
@@ -60,29 +60,32 @@ impl<Provider: BlockHashReader> BlockHashReader for LatestStateProviderRef<'_, P
6060

6161
impl<Provider: DBProvider + Sync> StateRootProvider for LatestStateProviderRef<'_, Provider> {
6262
fn state_root(&self, hashed_state: HashedPostState) -> ProviderResult<B256> {
63-
StateRoot::overlay_root(self.tx(), hashed_state)
63+
StateRoot::overlay_root(self.tx(), &hashed_state.into_sorted())
6464
.map_err(|err| ProviderError::Database(err.into()))
6565
}
6666

6767
fn state_root_from_nodes(&self, input: TrieInput) -> ProviderResult<B256> {
68-
StateRoot::overlay_root_from_nodes(self.tx(), input)
68+
StateRoot::overlay_root_from_nodes(self.tx(), TrieInputSorted::from_unsorted(input))
6969
.map_err(|err| ProviderError::Database(err.into()))
7070
}
7171

7272
fn state_root_with_updates(
7373
&self,
7474
hashed_state: HashedPostState,
7575
) -> ProviderResult<(B256, TrieUpdates)> {
76-
StateRoot::overlay_root_with_updates(self.tx(), hashed_state)
76+
StateRoot::overlay_root_with_updates(self.tx(), &hashed_state.into_sorted())
7777
.map_err(|err| ProviderError::Database(err.into()))
7878
}
7979

8080
fn state_root_from_nodes_with_updates(
8181
&self,
8282
input: TrieInput,
8383
) -> ProviderResult<(B256, TrieUpdates)> {
84-
StateRoot::overlay_root_from_nodes_with_updates(self.tx(), input)
85-
.map_err(|err| ProviderError::Database(err.into()))
84+
StateRoot::overlay_root_from_nodes_with_updates(
85+
self.tx(),
86+
TrieInputSorted::from_unsorted(input),
87+
)
88+
.map_err(|err| ProviderError::Database(err.into()))
8689
}
8790
}
8891

crates/storage/provider/src/providers/state/overlay.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use reth_trie::{
1414
hashed_cursor::{HashedCursorFactory, HashedPostStateCursorFactory},
1515
trie_cursor::{InMemoryTrieCursorFactory, TrieCursorFactory},
1616
updates::TrieUpdatesSorted,
17-
HashedPostState, HashedPostStateSorted, KeccakKeyHasher,
17+
HashedPostStateSorted, KeccakKeyHasher,
1818
};
1919
use reth_trie_db::{
2020
DatabaseHashedCursorFactory, DatabaseHashedPostState, DatabaseTrieCursorFactory,
@@ -234,13 +234,10 @@ where
234234
let _guard = debug_span!(target: "providers::state::overlay", "Retrieving hashed state reverts").entered();
235235

236236
let start = Instant::now();
237-
// TODO(mediocregopher) make from_reverts return sorted
238-
// https://github.com/paradigmxyz/reth/issues/19382
239-
let res = HashedPostState::from_reverts::<KeccakKeyHasher>(
237+
let res = HashedPostStateSorted::from_reverts::<KeccakKeyHasher>(
240238
provider.tx_ref(),
241239
from_block + 1..,
242-
)?
243-
.into_sorted();
240+
)?;
244241
retrieve_hashed_state_reverts_duration = start.elapsed();
245242
res
246243
};

crates/storage/provider/src/writer/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -917,7 +917,7 @@ mod tests {
917917
assert_eq!(
918918
StateRoot::overlay_root(
919919
tx,
920-
provider_factory.hashed_post_state(&state.bundle_state)
920+
&provider_factory.hashed_post_state(&state.bundle_state).into_sorted()
921921
)
922922
.unwrap(),
923923
state_root(expected.clone().into_iter().map(|(address, (account, storage))| (

0 commit comments

Comments
 (0)