From f6858868a6827b58c9abc461f662be86fefaafd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Felipe=20Torres=20Gonz=C3=A1lez?= Date: Mon, 10 Nov 2025 15:21:42 +0100 Subject: [PATCH] Async ready lib --- src/errors.rs | 3 + src/lib.rs | 4 +- src/stores/memory_store.rs | 32 +++++--- src/tree.rs | 147 +++++++++++++++++++++++++++++-------- tests/tree.rs | 12 +-- 5 files changed, 149 insertions(+), 49 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 45ce060..4add2e4 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -11,4 +11,7 @@ pub enum MerkleError { #[error("Levels and indices must have the same length")] LengthMismatch { levels: usize, indices: usize }, + + #[error("Lock was poisoned: {0}")] + LockPoisoned(String), } diff --git a/src/lib.rs b/src/lib.rs index 77b20c7..24fd04d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,8 +31,8 @@ fn main() { .unwrap(); println!("root: {:?}", tree.root().unwrap()); - println!("num leaves: {:?}", tree.num_leaves()); - println!("proof: {:?}", tree.proof(0).unwrap().proof); + println!("num leaves: {:?}", tree.num_leaves().unwrap()); + println!("proof: {:?}", tree.proof(0).unwrap().read().unwrap().proof); } ``` diff --git a/src/stores/memory_store.rs b/src/stores/memory_store.rs index e25fb0a..6ef1ed8 100644 --- a/src/stores/memory_store.rs +++ b/src/stores/memory_store.rs @@ -3,11 +3,16 @@ //! Simple in-memory store implementation. use crate::{MerkleError, Node, Store}; -use std::collections::HashMap; +use std::{collections::HashMap, sync::RwLock}; /// Simple in-memory store implementation using a `HashMap`. #[derive(Default)] pub struct MemoryStore { + inner: RwLock, +} + +#[derive(Default)] +struct MemoryStoreInner { store: HashMap<(u32, u64), Node>, num_leaves: u64, } @@ -26,27 +31,36 @@ impl Store for MemoryStore { indices: indices.len(), }); } - - // The memory store doesnt really allow batch reads, so just get all the - // indexes/levels one by one. + let inner = self.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MemoryStore: {}", e)) + })?; let result = levels .iter() .zip(indices) - .map(|(&lvl, &idx)| self.store.get(&(lvl, idx)).cloned()) + .map(|(&lvl, &idx)| inner.store.get(&(lvl, idx)).cloned()) .collect(); - Ok(result) } fn put(&mut self, items: &[(u32, u64, Node)]) -> Result<(), MerkleError> { + let mut inner = self.inner.write().map_err(|e| { + MerkleError::LockPoisoned(format!( + "Failed to acquire write lock on MemoryStore: {}", + e + )) + })?; for (level, index, hash) in items { - self.store.insert((*level, *index), *hash); + inner.store.insert((*level, *index), *hash); } let counter = items.iter().filter(|(level, _, _)| *level == 0).count(); - self.num_leaves += counter as u64; + inner.num_leaves += counter as u64; Ok(()) } fn get_num_leaves(&self) -> u64 { - self.num_leaves + // For get_num_leaves, we use expect since it's a simple getter and lock poisoning + // would indicate a serious bug. Using expect provides a clearer panic message. + self.inner.read() + .expect("MemoryStore lock was poisoned - this indicates a panic occurred while holding the lock") + .num_leaves } } diff --git a/src/tree.rs b/src/tree.rs index 99cc585..16484a2 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -5,12 +5,40 @@ use crate::hasher::{Hasher, Keccak256Hasher}; use crate::{MerkleError, Node, Store}; use core::ops::Index; -use std::collections::HashMap; +use std::{collections::HashMap, sync::RwLock}; #[cfg(feature = "memory_store")] use crate::stores::MemoryStore; pub struct MerkleProof { + inner: RwLock>, +} + +impl MerkleProof { + /// Get a read lock on the proof data + pub fn read( + &self, + ) -> Result>, MerkleError> { + self.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e)) + }) + } + + /// Get a write lock on the proof data + pub fn write( + &self, + ) -> Result>, MerkleError> { + self.inner.write().map_err(|e| { + MerkleError::LockPoisoned(format!( + "Failed to acquire write lock on MerkleProof: {}", + e + )) + }) + } +} + +// Make MerkleProofInner public so the read/write methods can return it +pub struct MerkleProofInner { pub proof: [Node; DEPTH], pub leaf: Node, pub index: u64, @@ -18,6 +46,14 @@ pub struct MerkleProof { } pub struct MerkleTree +where + H: Hasher, + S: Store, +{ + inner: RwLock>, +} + +struct MerkleTreeInner where H: Hasher, S: Store, @@ -78,21 +114,27 @@ where last: hasher.hash(&zero[DEPTH - 1], &zero[DEPTH - 1]), }; Self { - hasher, - store, - zeros, + inner: RwLock::new(MerkleTreeInner { + hasher, + store, + zeros, + }), } } - pub fn add_leaves(&mut self, leaves: &[Node]) -> Result<(), MerkleError> { + pub fn add_leaves(&self, leaves: &[Node]) -> Result<(), MerkleError> { // Early return if leaves.is_empty() { return Ok(()); } + let mut inner = self.inner.write().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire write lock on MerkleTree: {}", e)) + })?; + // Error if leaves do not fit in the tree // TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic. - if self.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) { + if inner.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) { return Err(MerkleError::TreeFull { depth: DEPTH as u32, capacity: 1 << DEPTH as u64, @@ -107,7 +149,7 @@ where let mut cache: HashMap<(u32, u64), Node> = HashMap::new(); for (offset, leaf) in leaves.iter().enumerate() { - let mut idx = self.store.get_num_leaves() + offset as u64; + let mut idx = inner.store.get_num_leaves() + offset as u64; let mut h = *leaf; // Store the leaf @@ -132,7 +174,7 @@ where // Batch-fetch the missing siblings and insert them in cache. if fetch_len != 0 { - let fetched = self.store.get( + let fetched = inner.store.get( &levels_to_fetch[..fetch_len], &indices_to_fetch[..fetch_len], )?; @@ -150,7 +192,7 @@ where let sib_hash = cache .get(&(level as u32, sibling_idx)) .copied() - .unwrap_or(self.zeros[level]); + .unwrap_or(inner.zeros[level]); let (left, right) = if idx & 1 == 1 { (sib_hash, h) @@ -158,7 +200,7 @@ where (h, sib_hash) }; - h = self.hasher.hash(&left, &right); + h = inner.hasher.hash(&left, &right); idx >>= 1; batch.push(((level + 1) as u32, idx, h)); @@ -167,19 +209,22 @@ where } // Update all values in a single batch - self.store.put(&batch)?; + inner.store.put(&batch)?; Ok(()) } pub fn root(&self) -> Result { - Ok(self + let inner = self.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)) + })?; + Ok(inner .store .get(&[DEPTH as u32], &[0])? .into_iter() .next() .ok_or_else(|| MerkleError::StoreError("root fetch returned empty vector".into()))? - .unwrap_or(self.zeros[DEPTH])) + .unwrap_or(inner.zeros[DEPTH])) } pub fn proof(&self, leaf_idx: u64) -> Result, MerkleError> { @@ -201,6 +246,10 @@ where }); } + let inner = self.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)) + })?; + // Build level/index lists for siblings plus the leaf. // TODO: Can't do arithmetic here with DEPTH meaning there is no // easy way to put this in the stack. Unfortunately the array size @@ -222,40 +271,66 @@ where indices.push(leaf_idx); // Batch fetch all requested nodes. - let fetched = self.store.get(&levels, &indices)?; + let fetched = inner.store.get(&levels, &indices)?; // The first DEPTH items are the siblings. let mut proof = [Node::ZERO; DEPTH]; for (d, opt) in fetched.iter().take(DEPTH).enumerate() { - proof[d] = opt.unwrap_or(self.zeros[d]); + proof[d] = opt.unwrap_or(inner.zeros[d]); } // The last item is the leaf itself. - let leaf_hash = fetched.last().copied().flatten().unwrap_or(self.zeros[0]); + let leaf_hash = fetched.last().copied().flatten().unwrap_or(inner.zeros[0]); + + // Release the lock before calling root() to avoid deadlock + let root = { + drop(inner); + self.root()? + }; Ok(MerkleProof { - proof, - leaf: leaf_hash, - index: leaf_idx, - root: self.root()?, + inner: RwLock::new(MerkleProofInner { + proof, + leaf: leaf_hash, + index: leaf_idx, + root, + }), }) } pub fn verify_proof(&self, proof: &MerkleProof) -> Result { - let mut computed_hash = proof.leaf; - for (j, sibling_hash) in proof.proof.iter().enumerate() { - let (left, right) = if proof.index & (1 << j) == 0 { + let proof_inner = proof.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e)) + })?; + let tree_inner = self.inner.read().map_err(|e| { + MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e)) + })?; + let mut computed_hash = proof_inner.leaf; + let idx = proof_inner.index; + let root = proof_inner.root; + for (j, sibling_hash) in proof_inner.proof.iter().enumerate() { + let (left, right) = if idx & (1 << j) == 0 { (computed_hash, *sibling_hash) } else { (*sibling_hash, computed_hash) }; - computed_hash = self.hasher.hash(&left, &right); + computed_hash = tree_inner.hasher.hash(&left, &right); } - Ok(computed_hash == proof.root) + Ok(computed_hash == root) } - pub fn num_leaves(&self) -> u64 { - self.store.get_num_leaves() + pub fn num_leaves(&self) -> Result { + Ok(self + .inner + .read() + .map_err(|e| { + MerkleError::LockPoisoned(format!( + "Failed to acquire read lock on MerkleTree: {}", + e + )) + })? + .store + .get_num_leaves()) } } @@ -312,10 +387,14 @@ mod tests { to_node!("0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757"), ]; - for (i, zero) in tree.zeros.front.iter().enumerate() { + let inner = tree + .inner + .read() + .expect("Lock should not be poisoned in test"); + for (i, zero) in inner.zeros.front.iter().enumerate() { assert_eq!(zero, &expected_zeros[i]); } - assert_eq!(tree.zeros.last, expected_zeros[32]); + assert_eq!(inner.zeros.last, expected_zeros[32]); } #[cfg(feature = "memory_store")] @@ -366,10 +445,14 @@ mod tests { to_node!("0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9"), ]; - for (i, zero) in tree.zeros.front.iter().enumerate() { + let inner = tree + .inner + .read() + .expect("Lock should not be poisoned in test"); + for (i, zero) in inner.zeros.front.iter().enumerate() { assert_eq!(zero, &expected_zeros[i]); } - assert_eq!(tree.zeros.last, expected_zeros[32]); + assert_eq!(inner.zeros.last, expected_zeros[32]); } #[cfg(feature = "memory_store")] @@ -377,7 +460,7 @@ mod tests { fn test_tree_full_error() { let hasher = Keccak256Hasher; let store = MemoryStore::default(); - let mut tree = MerkleTree::::new(hasher, store); + let tree = MerkleTree::::new(hasher, store); tree.add_leaves(&(0..8).map(|_| Node::ZERO).collect::>()) .unwrap(); diff --git a/tests/tree.rs b/tests/tree.rs index bb43464..9471877 100644 --- a/tests/tree.rs +++ b/tests/tree.rs @@ -32,7 +32,7 @@ fn dir_size(path: &Path) -> u64 { #[cfg(feature = "memory_store")] #[test] fn test_merkle_tree_keccak_32_memory() { - let mut tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default()); + let tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default()); // create 10k leaves. let leaves = (0..10_000) @@ -44,21 +44,21 @@ fn test_merkle_tree_keccak_32_memory() { tree.add_leaves(&[*i]).unwrap(); } - assert_eq!(tree.num_leaves(), 10_000); + assert_eq!(tree.num_leaves().unwrap(), 10_000); assert_eq!( tree.root().unwrap(), to_node!("0x532c79f3ea0f4873946d1b14770eaa1c157255a003e73da987b858cc287b0482") ); // reset the tree. - let mut tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default()); + let tree: MerkleTree32 = MerkleTree::new(Keccak256Hasher, MemoryStore::default()); // same but add them in batches of 1_000. for batch in leaves.chunks(1_000) { tree.add_leaves(&batch).unwrap(); } - assert_eq!(tree.num_leaves(), 10_000); + assert_eq!(tree.num_leaves().unwrap(), 10_000); assert_eq!( tree.root().unwrap(), to_node!("0x532c79f3ea0f4873946d1b14770eaa1c157255a003e73da987b858cc287b0482") @@ -67,7 +67,7 @@ fn test_merkle_tree_keccak_32_memory() { // Get proofs for each leaf and verify them. for i in 0..10_000 { let proof = tree.proof(i).unwrap(); - assert_eq!(proof.proof.len(), 32); + assert_eq!(proof.read().unwrap().proof.len(), 32); assert_eq!(tree.verify_proof(&proof).unwrap(), true); } @@ -104,7 +104,7 @@ fn test_disk_space() { S: Store, F: FnOnce() -> S, { - let mut tree: MerkleTree = + let tree: MerkleTree = MerkleTree::new(Keccak256Hasher, new_store()); for _ in 0..NUM_BATCHES {