Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,7 @@ pub enum MerkleError {

#[error("Levels and indices must have the same length")]
LengthMismatch { levels: usize, indices: usize },

#[error("Lock was poisoned: {0}")]
LockPoisoned(String),
}
4 changes: 2 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ fn main() {
.unwrap();

println!("root: {:?}", tree.root().unwrap());
println!("num leaves: {:?}", tree.num_leaves());
println!("proof: {:?}", tree.proof(0).unwrap().proof);
println!("num leaves: {:?}", tree.num_leaves().unwrap());
println!("proof: {:?}", tree.proof(0).unwrap().read().unwrap().proof);
}
```

Expand Down
32 changes: 23 additions & 9 deletions src/stores/memory_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
//! Simple in-memory store implementation.

use crate::{MerkleError, Node, Store};
use std::collections::HashMap;
use std::{collections::HashMap, sync::RwLock};

/// Simple in-memory store implementation using a `HashMap`.
#[derive(Default)]
pub struct MemoryStore {
inner: RwLock<MemoryStoreInner>,
}

#[derive(Default)]
struct MemoryStoreInner {
store: HashMap<(u32, u64), Node>,
num_leaves: u64,
}
Expand All @@ -26,27 +31,36 @@ impl Store for MemoryStore {
indices: indices.len(),
});
}

// The memory store doesnt really allow batch reads, so just get all the
// indexes/levels one by one.
let inner = self.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MemoryStore: {}", e))
})?;
let result = levels
.iter()
.zip(indices)
.map(|(&lvl, &idx)| self.store.get(&(lvl, idx)).cloned())
.map(|(&lvl, &idx)| inner.store.get(&(lvl, idx)).cloned())
.collect();

Ok(result)
}

fn put(&mut self, items: &[(u32, u64, Node)]) -> Result<(), MerkleError> {
let mut inner = self.inner.write().map_err(|e| {
MerkleError::LockPoisoned(format!(
"Failed to acquire write lock on MemoryStore: {}",
e
))
})?;
for (level, index, hash) in items {
self.store.insert((*level, *index), *hash);
inner.store.insert((*level, *index), *hash);
}
let counter = items.iter().filter(|(level, _, _)| *level == 0).count();
self.num_leaves += counter as u64;
inner.num_leaves += counter as u64;
Ok(())
}
fn get_num_leaves(&self) -> u64 {
self.num_leaves
// For get_num_leaves, we use expect since it's a simple getter and lock poisoning
// would indicate a serious bug. Using expect provides a clearer panic message.
self.inner.read()
.expect("MemoryStore lock was poisoned - this indicates a panic occurred while holding the lock")
.num_leaves
}
}
147 changes: 115 additions & 32 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,55 @@
use crate::hasher::{Hasher, Keccak256Hasher};
use crate::{MerkleError, Node, Store};
use core::ops::Index;
use std::collections::HashMap;
use std::{collections::HashMap, sync::RwLock};

#[cfg(feature = "memory_store")]
use crate::stores::MemoryStore;

pub struct MerkleProof<const DEPTH: usize> {
inner: RwLock<MerkleProofInner<DEPTH>>,
}

impl<const DEPTH: usize> MerkleProof<DEPTH> {
/// Get a read lock on the proof data
pub fn read(
&self,
) -> Result<std::sync::RwLockReadGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
self.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e))
})
}

/// Get a write lock on the proof data
pub fn write(
&self,
) -> Result<std::sync::RwLockWriteGuard<'_, MerkleProofInner<DEPTH>>, MerkleError> {
self.inner.write().map_err(|e| {
MerkleError::LockPoisoned(format!(
"Failed to acquire write lock on MerkleProof: {}",
e
))
})
}
}

// Make MerkleProofInner public so the read/write methods can return it
pub struct MerkleProofInner<const DEPTH: usize> {
pub proof: [Node; DEPTH],
pub leaf: Node,
pub index: u64,
pub root: Node,
}

pub struct MerkleTree<H, S, const DEPTH: usize>
where
H: Hasher,
S: Store,
{
inner: RwLock<MerkleTreeInner<H, S, DEPTH>>,
}

struct MerkleTreeInner<H, S, const DEPTH: usize>
where
H: Hasher,
S: Store,
Expand Down Expand Up @@ -78,21 +114,27 @@ where
last: hasher.hash(&zero[DEPTH - 1], &zero[DEPTH - 1]),
};
Self {
hasher,
store,
zeros,
inner: RwLock::new(MerkleTreeInner {
hasher,
store,
zeros,
}),
}
}

pub fn add_leaves(&mut self, leaves: &[Node]) -> Result<(), MerkleError> {
pub fn add_leaves(&self, leaves: &[Node]) -> Result<(), MerkleError> {
// Early return
if leaves.is_empty() {
return Ok(());
}

let mut inner = self.inner.write().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire write lock on MerkleTree: {}", e))
})?;

// Error if leaves do not fit in the tree
// TODO: Avoid calculating this. Calculate it at init or do the shifting with the generic.
if self.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
if inner.store.get_num_leaves() + leaves.len() as u64 > (1 << DEPTH as u64) {
return Err(MerkleError::TreeFull {
depth: DEPTH as u32,
capacity: 1 << DEPTH as u64,
Expand All @@ -107,7 +149,7 @@ where
let mut cache: HashMap<(u32, u64), Node> = HashMap::new();

for (offset, leaf) in leaves.iter().enumerate() {
let mut idx = self.store.get_num_leaves() + offset as u64;
let mut idx = inner.store.get_num_leaves() + offset as u64;
let mut h = *leaf;

// Store the leaf
Expand All @@ -132,7 +174,7 @@ where

// Batch-fetch the missing siblings and insert them in cache.
if fetch_len != 0 {
let fetched = self.store.get(
let fetched = inner.store.get(
&levels_to_fetch[..fetch_len],
&indices_to_fetch[..fetch_len],
)?;
Expand All @@ -150,15 +192,15 @@ where
let sib_hash = cache
.get(&(level as u32, sibling_idx))
.copied()
.unwrap_or(self.zeros[level]);
.unwrap_or(inner.zeros[level]);

let (left, right) = if idx & 1 == 1 {
(sib_hash, h)
} else {
(h, sib_hash)
};

h = self.hasher.hash(&left, &right);
h = inner.hasher.hash(&left, &right);
idx >>= 1;

batch.push(((level + 1) as u32, idx, h));
Expand All @@ -167,19 +209,22 @@ where
}

// Update all values in a single batch
self.store.put(&batch)?;
inner.store.put(&batch)?;

Ok(())
}

pub fn root(&self) -> Result<Node, MerkleError> {
Ok(self
let inner = self.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
})?;
Ok(inner
.store
.get(&[DEPTH as u32], &[0])?
.into_iter()
.next()
.ok_or_else(|| MerkleError::StoreError("root fetch returned empty vector".into()))?
.unwrap_or(self.zeros[DEPTH]))
.unwrap_or(inner.zeros[DEPTH]))
}

pub fn proof(&self, leaf_idx: u64) -> Result<MerkleProof<DEPTH>, MerkleError> {
Expand All @@ -201,6 +246,10 @@ where
});
}

let inner = self.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
})?;

// Build level/index lists for siblings plus the leaf.
// TODO: Can't do arithmetic here with DEPTH meaning there is no
// easy way to put this in the stack. Unfortunately the array size
Expand All @@ -222,40 +271,66 @@ where
indices.push(leaf_idx);

// Batch fetch all requested nodes.
let fetched = self.store.get(&levels, &indices)?;
let fetched = inner.store.get(&levels, &indices)?;

// The first DEPTH items are the siblings.
let mut proof = [Node::ZERO; DEPTH];
for (d, opt) in fetched.iter().take(DEPTH).enumerate() {
proof[d] = opt.unwrap_or(self.zeros[d]);
proof[d] = opt.unwrap_or(inner.zeros[d]);
}

// The last item is the leaf itself.
let leaf_hash = fetched.last().copied().flatten().unwrap_or(self.zeros[0]);
let leaf_hash = fetched.last().copied().flatten().unwrap_or(inner.zeros[0]);

// Release the lock before calling root() to avoid deadlock
let root = {
drop(inner);
self.root()?
};

Ok(MerkleProof {
proof,
leaf: leaf_hash,
index: leaf_idx,
root: self.root()?,
inner: RwLock::new(MerkleProofInner {
proof,
leaf: leaf_hash,
index: leaf_idx,
root,
}),
})
}

pub fn verify_proof(&self, proof: &MerkleProof<DEPTH>) -> Result<bool, MerkleError> {
let mut computed_hash = proof.leaf;
for (j, sibling_hash) in proof.proof.iter().enumerate() {
let (left, right) = if proof.index & (1 << j) == 0 {
let proof_inner = proof.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleProof: {}", e))
})?;
let tree_inner = self.inner.read().map_err(|e| {
MerkleError::LockPoisoned(format!("Failed to acquire read lock on MerkleTree: {}", e))
})?;
let mut computed_hash = proof_inner.leaf;
let idx = proof_inner.index;
let root = proof_inner.root;
for (j, sibling_hash) in proof_inner.proof.iter().enumerate() {
let (left, right) = if idx & (1 << j) == 0 {
(computed_hash, *sibling_hash)
} else {
(*sibling_hash, computed_hash)
};
computed_hash = self.hasher.hash(&left, &right);
computed_hash = tree_inner.hasher.hash(&left, &right);
}
Ok(computed_hash == proof.root)
Ok(computed_hash == root)
}

pub fn num_leaves(&self) -> u64 {
self.store.get_num_leaves()
pub fn num_leaves(&self) -> Result<u64, MerkleError> {
Ok(self
.inner
.read()
.map_err(|e| {
MerkleError::LockPoisoned(format!(
"Failed to acquire read lock on MerkleTree: {}",
e
))
})?
.store
.get_num_leaves())
}
}

Expand Down Expand Up @@ -312,10 +387,14 @@ mod tests {
to_node!("0x27ae5ba08d7291c96c8cbddcc148bf48a6d68c7974b94356f53754ef6171d757"),
];

for (i, zero) in tree.zeros.front.iter().enumerate() {
let inner = tree
.inner
.read()
.expect("Lock should not be poisoned in test");
for (i, zero) in inner.zeros.front.iter().enumerate() {
assert_eq!(zero, &expected_zeros[i]);
}
assert_eq!(tree.zeros.last, expected_zeros[32]);
assert_eq!(inner.zeros.last, expected_zeros[32]);
}

#[cfg(feature = "memory_store")]
Expand Down Expand Up @@ -366,18 +445,22 @@ mod tests {
to_node!("0x2f68a1c58e257e42a17a6c61dff5551ed560b9922ab119d5ac8e184c9734ead9"),
];

for (i, zero) in tree.zeros.front.iter().enumerate() {
let inner = tree
.inner
.read()
.expect("Lock should not be poisoned in test");
for (i, zero) in inner.zeros.front.iter().enumerate() {
assert_eq!(zero, &expected_zeros[i]);
}
assert_eq!(tree.zeros.last, expected_zeros[32]);
assert_eq!(inner.zeros.last, expected_zeros[32]);
}

#[cfg(feature = "memory_store")]
#[test]
fn test_tree_full_error() {
let hasher = Keccak256Hasher;
let store = MemoryStore::default();
let mut tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);
let tree = MerkleTree::<Keccak256Hasher, MemoryStore, 3>::new(hasher, store);

tree.add_leaves(&(0..8).map(|_| Node::ZERO).collect::<Vec<Node>>())
.unwrap();
Expand Down
Loading
Loading