Skip to content

Commit 88e8069

Browse files
committed
fix: run the validation and the insertion of a transaction in the mempool at the same time
Signed-off-by: Eric Torreborre <etorreborre@yahoo.com>
1 parent 8811af1 commit 88e8069

File tree

10 files changed

+104
-223
lines changed

10 files changed

+104
-223
lines changed

crates/amaru-mempool/src/strategies/dummy.rs

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
use std::{collections::BTreeSet, mem, pin::Pin};
1616

1717
use amaru_kernel::cbor;
18-
use amaru_ouroboros_traits::{
19-
CanValidateTransactions, Mempool, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
20-
TxSubmissionMempool,
21-
};
18+
use amaru_ouroboros_traits::{Mempool, MempoolSeqNo, TxId, TxOrigin, TxRejectReason, TxSubmissionMempool};
2219
use parking_lot::RwLock;
2320

2421
#[derive(Debug, Default)]
@@ -37,12 +34,6 @@ pub struct DummyMempoolInner<T> {
3734
transactions: Vec<T>,
3835
}
3936

40-
impl<Tx: cbor::Encode<()> + Send + Sync + 'static> CanValidateTransactions<Tx> for DummyMempool<Tx> {
41-
fn validate_transaction(&self, _tx: Tx) -> Result<(), TransactionValidationError> {
42-
Ok(())
43-
}
44-
}
45-
4637
impl<Tx: cbor::Encode<()> + Send + Sync + 'static> TxSubmissionMempool<Tx> for DummyMempool<Tx> {
4738
fn insert(&self, tx: Tx, _tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
4839
let tx_id = TxId::from(&tx);

crates/amaru-mempool/src/strategies/in_memory_mempool.rs

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ use std::{
2121

2222
use amaru_kernel::{cbor, to_cbor};
2323
use amaru_ouroboros_traits::{
24-
CanValidateTransactions, MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason,
25-
TxSubmissionMempool, mempool::Mempool,
24+
MempoolSeqNo, TransactionValidationError, TxId, TxOrigin, TxRejectReason, TxSubmissionMempool, mempool::Mempool,
2625
};
2726
use tokio::sync::Notify;
2827

@@ -31,39 +30,48 @@ use tokio::sync::Notify;
3130
/// It stores transactions in memory, indexed by their TxId and by a sequence number assigned
3231
/// at insertion time.
3332
///
34-
/// The validation of the transactions are delegated to a `CanValidateTransactions` implementation.
33+
/// The validation of transactions is delegated to an injected validator.
3534
///
3635
#[derive(Clone)]
3736
pub struct InMemoryMempool<Tx> {
3837
config: MempoolConfig,
3938
inner: Arc<parking_lot::RwLock<MempoolInner<Tx>>>,
40-
tx_validator: Arc<dyn CanValidateTransactions<Tx>>,
39+
tx_validator: Arc<dyn TxValidator<Tx>>,
4140
}
4241

43-
impl<Tx> Default for InMemoryMempool<Tx> {
42+
pub trait TxValidator<Tx>: Send + Sync {
43+
fn validate(&self, tx: &Tx) -> Result<(), TransactionValidationError>;
44+
}
45+
46+
impl<Tx, F> TxValidator<Tx> for F
47+
where
48+
F: Fn(&Tx) -> Result<(), TransactionValidationError> + Send + Sync,
49+
{
50+
fn validate(&self, tx: &Tx) -> Result<(), TransactionValidationError> {
51+
self(tx)
52+
}
53+
}
54+
55+
impl<Tx: 'static> Default for InMemoryMempool<Tx> {
4456
fn default() -> Self {
4557
Self::from_config(MempoolConfig::default())
4658
}
4759
}
4860

4961
impl<Tx> InMemoryMempool<Tx> {
50-
pub fn new(config: MempoolConfig, tx_validator: Arc<dyn CanValidateTransactions<Tx>>) -> Self {
62+
pub fn new(config: MempoolConfig, tx_validator: Arc<dyn TxValidator<Tx>>) -> Self {
5163
InMemoryMempool { config, inner: Arc::new(parking_lot::RwLock::new(MempoolInner::default())), tx_validator }
5264
}
65+
}
5366

67+
impl<Tx: 'static> InMemoryMempool<Tx> {
5468
pub fn from_config(config: MempoolConfig) -> Self {
55-
Self::new(config, Arc::new(DefaultCanValidateTransactions))
69+
Self::new(config, Arc::new(default_tx_validator::<Tx>))
5670
}
5771
}
5872

59-
/// A default transactions validator.
60-
#[derive(Clone, Debug, Default)]
61-
pub struct DefaultCanValidateTransactions;
62-
63-
impl<Tx> CanValidateTransactions<Tx> for DefaultCanValidateTransactions {
64-
fn validate_transaction(&self, _tx: Tx) -> Result<(), TransactionValidationError> {
65-
Ok(())
66-
}
73+
pub fn default_tx_validator<Tx>(_tx: &Tx) -> Result<(), TransactionValidationError> {
74+
Ok(())
6775
}
6876

6977
#[derive(Debug)]
@@ -169,14 +177,9 @@ impl MempoolConfig {
169177
}
170178
}
171179

172-
impl<Tx: Send + Sync + 'static> CanValidateTransactions<Tx> for InMemoryMempool<Tx> {
173-
fn validate_transaction(&self, tx: Tx) -> Result<(), TransactionValidationError> {
174-
self.tx_validator.validate_transaction(tx)
175-
}
176-
}
177-
178180
impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> TxSubmissionMempool<Tx> for InMemoryMempool<Tx> {
179181
fn insert(&self, tx: Tx, tx_origin: TxOrigin) -> Result<(TxId, MempoolSeqNo), TxRejectReason> {
182+
self.tx_validator.validate(&tx).map_err(TxRejectReason::Invalid)?;
180183
let mut inner = self.inner.write();
181184
let res = inner.insert(&self.config, tx, tx_origin);
182185
if res.is_ok() {
@@ -258,7 +261,7 @@ impl<Tx: Send + Sync + 'static + cbor::Encode<()> + Clone> Mempool<Tx> for InMem
258261

259262
#[cfg(test)]
260263
mod tests {
261-
use std::{ops::Deref, slice, str::FromStr, time::Duration};
264+
use std::{ops::Deref, slice, str::FromStr, sync::Arc, time::Duration};
262265

263266
use amaru_kernel::{Peer, cbor, cbor as minicbor};
264267
use assertables::assert_some_eq_x;
@@ -283,6 +286,16 @@ mod tests {
283286
Ok(())
284287
}
285288

289+
#[test]
290+
fn reject_invalid_transaction_on_insert() {
291+
let mempool = InMemoryMempool::new(MempoolConfig::default(), Arc::new(reject_tx));
292+
let tx = Tx::from_str("tx1").unwrap();
293+
294+
let result = mempool.insert(tx, TxOrigin::Local);
295+
296+
assert_eq!(result, Err(TxRejectReason::Invalid(anyhow::anyhow!("transaction rejected for testing").into())));
297+
}
298+
286299
// HELPERS
287300
#[derive(Debug, PartialEq, Eq, Clone, cbor::Encode, cbor::Decode)]
288301
struct Tx(#[n(0)] String);
@@ -300,4 +313,8 @@ mod tests {
300313
Ok(Tx(s.to_string()))
301314
}
302315
}
316+
317+
fn reject_tx(_tx: &Tx) -> Result<(), TransactionValidationError> {
318+
Err(anyhow::anyhow!("transaction rejected for testing").into())
319+
}
303320
}

crates/amaru-ouroboros-traits/src/can_validate_transactions/mock.rs

Lines changed: 0 additions & 25 deletions
This file was deleted.

crates/amaru-ouroboros-traits/src/can_validate_transactions/mod.rs

Lines changed: 0 additions & 85 deletions
This file was deleted.

crates/amaru-ouroboros-traits/src/lib.rs

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ pub use has_stake_distribution::{HasStakeDistribution, PoolSummary};
1818
pub mod can_validate_blocks;
1919
pub use can_validate_blocks::{BlockValidationError, CanValidateBlocks};
2020

21-
pub mod can_validate_transactions;
22-
pub use can_validate_transactions::{CanValidateTransactions, TransactionValidationError};
23-
2421
pub mod connections;
2522
pub use connections::*;
2623

crates/amaru-ouroboros-traits/src/mempool.rs

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
use std::{
16+
error::Error,
1617
fmt,
1718
fmt::{Display, Formatter},
1819
pin::Pin,
@@ -22,8 +23,6 @@ use std::{
2223
use amaru_kernel::{Hash, Hasher, Peer, TransactionId, cbor, size::TRANSACTION_BODY};
2324
use serde::{Deserialize, Serialize};
2425

25-
use crate::CanValidateTransactions;
26-
2726
/// An simple mempool interface to:
2827
///
2928
/// - Retrieve transactions to be included in a new block.
@@ -48,7 +47,7 @@ pub trait Mempool<Tx: Send + Sync + 'static>: TxSubmissionMempool<Tx> + Send + S
4847

4948
pub type ResourceMempool<Tx> = Arc<dyn TxSubmissionMempool<Tx>>;
5049

51-
pub trait TxSubmissionMempool<Tx: Send + Sync + 'static>: Send + Sync + CanValidateTransactions<Tx> {
50+
pub trait TxSubmissionMempool<Tx: Send + Sync + 'static>: Send + Sync {
5251
/// Insert a transaction into the mempool, specifying its origin.
5352
/// A TxOrigin::Local origin indicates the transaction was created on the current node,
5453
/// A TxOrigin::Remote(origin_peer) indicates the transaction was received from a remote peer.
@@ -97,6 +96,49 @@ pub trait TxSubmissionMempool<Tx: Send + Sync + 'static>: Send + Sync + CanValid
9796
fn last_seq_no(&self) -> MempoolSeqNo;
9897
}
9998

99+
#[derive(Debug, thiserror::Error)]
100+
#[error("TransactionValidationError: {0}")]
101+
pub struct TransactionValidationError(#[from] anyhow::Error);
102+
103+
impl TransactionValidationError {
104+
pub fn to_anyhow(self) -> anyhow::Error {
105+
self.0
106+
}
107+
108+
pub fn downcast<T: Error + fmt::Debug + Send + Sync + 'static>(self) -> Result<T, anyhow::Error> {
109+
self.0.downcast::<T>()
110+
}
111+
112+
pub fn downcast_ref<T: Error + fmt::Debug + Send + Sync + 'static>(&self) -> Option<&T> {
113+
self.0.downcast_ref::<T>()
114+
}
115+
}
116+
117+
impl Serialize for TransactionValidationError {
118+
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
119+
where
120+
S: serde::Serializer,
121+
{
122+
serializer.serialize_str(&self.0.to_string())
123+
}
124+
}
125+
126+
impl<'de> Deserialize<'de> for TransactionValidationError {
127+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128+
where
129+
D: serde::Deserializer<'de>,
130+
{
131+
let s = String::deserialize(deserializer)?;
132+
Ok(TransactionValidationError(anyhow::anyhow!(s)))
133+
}
134+
}
135+
136+
impl PartialEq for TransactionValidationError {
137+
fn eq(&self, other: &Self) -> bool {
138+
self.0.to_string() == other.0.to_string()
139+
}
140+
}
141+
100142
/// Sequence number assigned to a transaction when inserted into the mempool.
101143
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, Default)]
102144
pub struct MempoolSeqNo(pub u64);
@@ -111,14 +153,14 @@ impl MempoolSeqNo {
111153
}
112154
}
113155

114-
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, thiserror::Error, Serialize, Deserialize)]
156+
#[derive(Debug, PartialEq, thiserror::Error, Serialize, Deserialize)]
115157
pub enum TxRejectReason {
116158
#[error("Mempool is full")]
117159
MempoolFull,
118160
#[error("Transaction is a duplicate")]
119161
Duplicate,
120-
#[error("Transaction is invalid")]
121-
Invalid,
162+
#[error(transparent)]
163+
Invalid(#[from] TransactionValidationError),
122164
}
123165

124166
/// Origin of a transaction being inserted into the mempool:

0 commit comments

Comments
 (0)