Skip to content

Commit 320f2a6

Browse files
fix(trie): PST: Fix update_leaf atomicity, remove update_leaves revealed tracking, fix callback calling (paradigmxyz#21573)
1 parent 70bfdaf commit 320f2a6

File tree

3 files changed

+120
-39
lines changed

3 files changed

+120
-39
lines changed

crates/trie/sparse-parallel/src/trie.rs

Lines changed: 111 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,6 @@ pub struct ParallelSparseTrie {
123123
update_actions_buffers: Vec<Vec<SparseTrieUpdatesAction>>,
124124
/// Thresholds controlling when parallelism is enabled for different operations.
125125
parallelism_thresholds: ParallelismThresholds,
126-
/// Tracks proof targets already requested via `update_leaves` to avoid duplicate callbacks
127-
/// across retry calls. Key is (`leaf_path`, `min_depth`).
128-
requested_proof_targets: alloy_primitives::map::HashSet<(Nibbles, u8)>,
129126
/// Metrics for the parallel sparse trie.
130127
#[cfg(feature = "metrics")]
131128
metrics: crate::metrics::ParallelSparseTrieMetrics,
@@ -144,7 +141,6 @@ impl Default for ParallelSparseTrie {
144141
branch_node_masks: BranchNodeMasksMap::default(),
145142
update_actions_buffers: Vec::default(),
146143
parallelism_thresholds: Default::default(),
147-
requested_proof_targets: Default::default(),
148144
#[cfg(feature = "metrics")]
149145
metrics: Default::default(),
150146
}
@@ -1182,7 +1178,7 @@ impl SparseTrieExt for ParallelSparseTrie {
11821178
fn update_leaves(
11831179
&mut self,
11841180
updates: &mut alloy_primitives::map::B256Map<reth_trie_sparse::LeafUpdate>,
1185-
mut proof_required_fn: impl FnMut(Nibbles, u8),
1181+
mut proof_required_fn: impl FnMut(B256, u8),
11861182
) -> SparseTrieResult<()> {
11871183
use reth_trie_sparse::{provider::NoRevealProvider, LeafUpdate};
11881184

@@ -1204,10 +1200,9 @@ impl SparseTrieExt for ParallelSparseTrie {
12041200
Ok(()) => {}
12051201
Err(e) => {
12061202
if let Some(path) = Self::get_retriable_path(&e) {
1203+
let target_key = Self::nibbles_to_padded_b256(&path);
12071204
let min_len = (path.len() as u8).min(64);
1208-
if self.requested_proof_targets.insert((full_path, min_len)) {
1209-
proof_required_fn(full_path, min_len);
1210-
}
1205+
proof_required_fn(target_key, min_len);
12111206
updates.insert(key, LeafUpdate::Changed(value));
12121207
} else {
12131208
return Err(e);
@@ -1219,10 +1214,9 @@ impl SparseTrieExt for ParallelSparseTrie {
12191214
if let Err(e) = self.update_leaf(full_path, value.clone(), NoRevealProvider)
12201215
{
12211216
if let Some(path) = Self::get_retriable_path(&e) {
1217+
let target_key = Self::nibbles_to_padded_b256(&path);
12221218
let min_len = (path.len() as u8).min(64);
1223-
if self.requested_proof_targets.insert((full_path, min_len)) {
1224-
proof_required_fn(full_path, min_len);
1225-
}
1219+
proof_required_fn(target_key, min_len);
12261220
updates.insert(key, LeafUpdate::Changed(value));
12271221
} else {
12281222
return Err(e);
@@ -1234,10 +1228,9 @@ impl SparseTrieExt for ParallelSparseTrie {
12341228
// Touched is read-only: check if path is accessible, request proof if blinded.
12351229
match self.find_leaf(&full_path, None) {
12361230
Err(LeafLookupError::BlindedNode { path, .. }) => {
1231+
let target_key = Self::nibbles_to_padded_b256(&path);
12371232
let min_len = (path.len() as u8).min(64);
1238-
if self.requested_proof_targets.insert((full_path, min_len)) {
1239-
proof_required_fn(full_path, min_len);
1240-
}
1233+
proof_required_fn(target_key, min_len);
12411234
updates.insert(key, LeafUpdate::Touched);
12421235
}
12431236
// Path is fully revealed (exists or proven non-existent), no action needed.
@@ -1263,14 +1256,6 @@ impl ParallelSparseTrie {
12631256
self.updates.is_some()
12641257
}
12651258

1266-
/// Clears the set of already-requested proof targets.
1267-
///
1268-
/// Call this when reusing the trie for a new payload to ensure proof callbacks
1269-
/// are emitted fresh.
1270-
pub fn clear_requested_proof_targets(&mut self) {
1271-
self.requested_proof_targets.clear();
1272-
}
1273-
12741259
/// Returns true if parallelism should be enabled for revealing the given number of nodes.
12751260
/// Will always return false in nostd builds.
12761261
const fn is_reveal_parallelism_enabled(&self, num_nodes: usize) -> bool {
@@ -1303,6 +1288,14 @@ impl ParallelSparseTrie {
13031288
}
13041289
}
13051290

1291+
/// Converts a nibbles path to a B256, right-padding with zeros to 64 nibbles.
1292+
fn nibbles_to_padded_b256(path: &Nibbles) -> B256 {
1293+
let packed = path.pack();
1294+
let mut bytes = [0u8; 32];
1295+
bytes[..packed.len()].copy_from_slice(&packed);
1296+
B256::from(bytes)
1297+
}
1298+
13061299
/// Rolls back a partial update by removing the value, removing any inserted nodes,
13071300
/// and restoring any modified original node.
13081301
/// This ensures `update_leaf` is atomic - either it succeeds completely or leaves the trie
@@ -2110,6 +2103,9 @@ impl SparseSubtrie {
21102103
///
21112104
/// If an update requires revealing a blinded node, an error is returned if the blinded
21122105
/// provider returns an error.
2106+
///
2107+
/// This method is atomic: if an error occurs during structural changes, all modifications
2108+
/// are rolled back and the trie state is unchanged.
21132109
pub fn update_leaf(
21142110
&mut self,
21152111
full_path: Nibbles,
@@ -2118,21 +2114,46 @@ impl SparseSubtrie {
21182114
retain_updates: bool,
21192115
) -> SparseTrieResult<Option<(Nibbles, BranchNodeMasks)>> {
21202116
debug_assert!(full_path.starts_with(&self.path));
2121-
let existing = self.inner.values.insert(full_path, value);
2122-
if existing.is_some() {
2123-
// trie structure unchanged, return immediately
2117+
2118+
// Check if value already exists - if so, just update it (no structural changes needed)
2119+
if let Entry::Occupied(mut e) = self.inner.values.entry(full_path) {
2120+
e.insert(value);
21242121
return Ok(None)
21252122
}
21262123

21272124
// Here we are starting at the root of the subtrie, and traversing from there.
21282125
let mut current = Some(self.path);
21292126
let mut revealed = None;
2127+
2128+
// Track inserted nodes and modified original for rollback on error
2129+
let mut inserted_nodes: Vec<Nibbles> = Vec::new();
2130+
let mut modified_original: Option<(Nibbles, SparseNode)> = None;
2131+
21302132
while let Some(current_path) = current {
2131-
match self.update_next_node(current_path, &full_path, retain_updates)? {
2133+
// Save original node for potential rollback (only if not already saved)
2134+
if modified_original.is_none() &&
2135+
let Some(node) = self.nodes.get(&current_path)
2136+
{
2137+
modified_original = Some((current_path, node.clone()));
2138+
}
2139+
2140+
let step_result = self.update_next_node(current_path, &full_path, retain_updates);
2141+
2142+
// Handle errors from update_next_node - rollback and propagate
2143+
if let Err(e) = step_result {
2144+
self.rollback_leaf_insert(&full_path, &inserted_nodes, modified_original.take());
2145+
return Err(e);
2146+
}
2147+
2148+
match step_result? {
21322149
LeafUpdateStep::Continue { next_node } => {
21332150
current = Some(next_node);
2151+
// Clear modified_original since we haven't actually modified anything yet
2152+
modified_original = None;
21342153
}
2135-
LeafUpdateStep::Complete { reveal_path, .. } => {
2154+
LeafUpdateStep::Complete { inserted_nodes: new_inserted, reveal_path } => {
2155+
inserted_nodes.extend(new_inserted);
2156+
21362157
if let Some(reveal_path) = reveal_path &&
21372158
self.nodes.get(&reveal_path).expect("node must exist").is_hash()
21382159
{
@@ -2142,10 +2163,29 @@ impl SparseSubtrie {
21422163
leaf_full_path = ?full_path,
21432164
"Extension node child not revealed in update_leaf, falling back to db",
21442165
);
2145-
if let Some(RevealedNode { node, tree_mask, hash_mask }) =
2146-
provider.trie_node(&reveal_path)?
2147-
{
2148-
let decoded = TrieNode::decode(&mut &node[..])?;
2166+
let revealed_node = match provider.trie_node(&reveal_path) {
2167+
Ok(node) => node,
2168+
Err(e) => {
2169+
self.rollback_leaf_insert(
2170+
&full_path,
2171+
&inserted_nodes,
2172+
modified_original.take(),
2173+
);
2174+
return Err(e);
2175+
}
2176+
};
2177+
if let Some(RevealedNode { node, tree_mask, hash_mask }) = revealed_node {
2178+
let decoded = match TrieNode::decode(&mut &node[..]) {
2179+
Ok(d) => d,
2180+
Err(e) => {
2181+
self.rollback_leaf_insert(
2182+
&full_path,
2183+
&inserted_nodes,
2184+
modified_original.take(),
2185+
);
2186+
return Err(e.into());
2187+
}
2188+
};
21492189
trace!(
21502190
target: "trie::parallel_sparse",
21512191
?reveal_path,
@@ -2155,14 +2195,26 @@ impl SparseSubtrie {
21552195
"Revealing child (from lower)",
21562196
);
21572197
let masks = BranchNodeMasks::from_optional(hash_mask, tree_mask);
2158-
self.reveal_node(reveal_path, &decoded, masks)?;
2198+
if let Err(e) = self.reveal_node(reveal_path, &decoded, masks) {
2199+
self.rollback_leaf_insert(
2200+
&full_path,
2201+
&inserted_nodes,
2202+
modified_original.take(),
2203+
);
2204+
return Err(e);
2205+
}
21592206

21602207
debug_assert_eq!(
21612208
revealed, None,
21622209
"Only a single blinded node should be revealed during update_leaf"
21632210
);
21642211
revealed = masks.map(|masks| (reveal_path, masks));
21652212
} else {
2213+
self.rollback_leaf_insert(
2214+
&full_path,
2215+
&inserted_nodes,
2216+
modified_original.take(),
2217+
);
21662218
return Err(SparseTrieErrorKind::NodeNotFoundInProvider {
21672219
path: reveal_path,
21682220
}
@@ -2178,9 +2230,36 @@ impl SparseSubtrie {
21782230
}
21792231
}
21802232

2233+
// Only insert the value after all structural changes succeed
2234+
self.inner.values.insert(full_path, value);
2235+
21812236
Ok(revealed)
21822237
}
21832238

2239+
/// Rollback structural changes made during a failed leaf insert.
2240+
///
2241+
/// This removes any nodes that were inserted and restores the original node
2242+
/// that was modified, ensuring atomicity of `update_leaf`.
2243+
fn rollback_leaf_insert(
2244+
&mut self,
2245+
full_path: &Nibbles,
2246+
inserted_nodes: &[Nibbles],
2247+
modified_original: Option<(Nibbles, SparseNode)>,
2248+
) {
2249+
// Remove any values that may have been inserted
2250+
self.inner.values.remove(full_path);
2251+
2252+
// Remove all inserted nodes
2253+
for node_path in inserted_nodes {
2254+
self.nodes.remove(node_path);
2255+
}
2256+
2257+
// Restore the original node that was modified
2258+
if let Some((path, original_node)) = modified_original {
2259+
self.nodes.insert(path, original_node);
2260+
}
2261+
}
2262+
21842263
/// Processes the current node, returning what to do next in the leaf update process.
21852264
///
21862265
/// This will add or update any nodes in the trie as necessary.

crates/trie/sparse/src/traits.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,18 @@ pub trait SparseTrieExt: SparseTrie {
281281
/// Once that proof is calculated and revealed via [`SparseTrie::reveal_nodes`], the same
282282
/// `updates` map can be reused to retry the update.
283283
///
284-
/// Proof targets are deduplicated by `(full_path, min_len)` across all calls to this method.
285-
/// The callback will only be invoked once per unique target, even across retry loops.
286-
/// A deeper blinded node (higher `min_len`) for the same path is considered a new target.
284+
/// The callback receives `(key, min_len)` where `key` is the full 32-byte hashed key
285+
/// (right-padded with zeros from the blinded path) and `min_len` is the minimum depth
286+
/// at which proof nodes should be returned.
287+
///
288+
/// The callback may be invoked multiple times for the same target across retry loops.
289+
/// Callers should deduplicate if needed.
287290
///
288291
/// [`LeafUpdate::Touched`] behaves identically except it does not modify the leaf value.
289292
fn update_leaves(
290293
&mut self,
291294
updates: &mut B256Map<LeafUpdate>,
292-
proof_required_fn: impl FnMut(Nibbles, u8),
295+
proof_required_fn: impl FnMut(B256, u8),
293296
) -> SparseTrieResult<()>;
294297
}
295298

crates/trie/sparse/src/trie.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -301,14 +301,13 @@ impl<T: SparseTrieExt + Default> RevealableSparseTrie<T> {
301301
pub fn update_leaves(
302302
&mut self,
303303
updates: &mut B256Map<LeafUpdate>,
304-
mut proof_required_fn: impl FnMut(Nibbles, u8),
304+
mut proof_required_fn: impl FnMut(B256, u8),
305305
) -> SparseTrieResult<()> {
306306
match self {
307307
Self::Blind(_) => {
308308
// Nothing is revealed - emit proof targets for all keys with min_len = 0
309309
for key in updates.keys() {
310-
let full_path = Nibbles::unpack(*key);
311-
proof_required_fn(full_path, 0);
310+
proof_required_fn(*key, 0);
312311
}
313312
// All updates remain in the map for retry after proofs are fetched
314313
Ok(())

0 commit comments

Comments
 (0)