Skip to content

Commit 5a7362c

Browse files
authored
Simplify process_swap_request (cashubtc#631)
* Simplify process_swap_request * Fix occasional test_swap_to_send wallet errors
1 parent 393c95e commit 5a7362c

File tree

5 files changed

+50
-73
lines changed

5 files changed

+50
-73
lines changed

crates/cdk-integration-tests/src/init_pure_tests.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,7 @@ pub async fn create_and_start_test_mint() -> anyhow::Result<Arc<Mint>> {
157157

158158
let mut mint_builder = MintBuilder::new();
159159

160-
let database = cdk_sqlite::mint::memory::empty()
161-
.await
162-
.expect("valid db instance");
160+
let database = cdk_sqlite::mint::memory::empty().await?;
163161

164162
let localstore = Arc::new(database);
165163
mint_builder = mint_builder.with_localstore(localstore.clone());
@@ -216,9 +214,7 @@ pub async fn create_test_wallet_for_mint(mint: Arc<Mint>) -> anyhow::Result<Arc<
216214
let seed = Mnemonic::generate(12)?.to_seed_normalized("");
217215
let mint_url = "http://aa".to_string();
218216
let unit = CurrencyUnit::Sat;
219-
let localstore = cdk_sqlite::wallet::memory::empty()
220-
.await
221-
.expect("valid db instance");
217+
let localstore = cdk_sqlite::wallet::memory::empty().await?;
222218
let mut wallet = Wallet::new(&mint_url, unit, Arc::new(localstore), &seed, None)?;
223219

224220
wallet.set_client(connector);

crates/cdk-integration-tests/tests/fake_wallet.rs

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -996,17 +996,10 @@ async fn test_fake_mint_swap_spend_after_fail() -> Result<()> {
996996

997997
match response {
998998
Err(err) => match err {
999-
cdk::Error::TokenAlreadySpent => (),
1000-
err => {
1001-
bail!(
1002-
"Wrong mint error returned expected already spent: {}",
1003-
err.to_string()
1004-
);
1005-
}
999+
cdk::Error::TransactionUnbalanced(_, _, _) => (),
1000+
err => bail!("Wrong mint error returned expected TransactionUnbalanced, got: {err}"),
10061001
},
1007-
Ok(_) => {
1008-
bail!("Should not have allowed swap with unbalanced");
1009-
}
1002+
Ok(_) => bail!("Should not have allowed swap with unbalanced"),
10101003
}
10111004

10121005
let pre_mint = PreMintSecrets::random(active_keyset_id, 100.into(), &SplitTarget::None)?;
@@ -1076,14 +1069,10 @@ async fn test_fake_mint_melt_spend_after_fail() -> Result<()> {
10761069

10771070
match response {
10781071
Err(err) => match err {
1079-
cdk::Error::TokenAlreadySpent => (),
1080-
err => {
1081-
bail!("Wrong mint error returned: {}", err.to_string());
1082-
}
1072+
cdk::Error::TransactionUnbalanced(_, _, _) => (),
1073+
err => bail!("Wrong mint error returned expected TransactionUnbalanced, got: {err}"),
10831074
},
1084-
Ok(_) => {
1085-
bail!("Should not have allowed to mint with multiple units");
1086-
}
1075+
Ok(_) => bail!("Should not have allowed swap with unbalanced"),
10871076
}
10881077

10891078
let input_amount: u64 = proofs.total_amount()?.into();

crates/cdk-integration-tests/tests/mint.rs

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
33
use std::collections::{HashMap, HashSet};
44
use std::sync::Arc;
5-
use std::time::Duration;
65

76
use anyhow::{bail, Result};
87
use bip39::Mnemonic;
@@ -21,7 +20,6 @@ use cdk::util::unix_time;
2120
use cdk::Mint;
2221
use cdk_fake_wallet::FakeWallet;
2322
use cdk_sqlite::mint::memory;
24-
use tokio::time::sleep;
2523

2624
pub const MINT_URL: &str = "http://127.0.0.1:8088";
2725

@@ -215,20 +213,12 @@ pub async fn test_p2pk_swap() -> Result<()> {
215213

216214
let swap_request = SwapRequest::new(proofs.clone(), pre_swap.blinded_messages());
217215

216+
// Listen for status updates on all input proof pks
218217
let public_keys_to_listen: Vec<_> = swap_request
219218
.inputs
220-
.ys()
221-
.expect("key")
222-
.into_iter()
223-
.enumerate()
224-
.filter_map(|(key, pk)| {
225-
if key % 2 == 0 {
226-
// Only expect messages from every other key
227-
Some(pk.to_string())
228-
} else {
229-
None
230-
}
231-
})
219+
.ys()?
220+
.iter()
221+
.map(|pk| pk.to_string())
232222
.collect();
233223

234224
let mut listener = mint
@@ -265,29 +255,23 @@ pub async fn test_p2pk_swap() -> Result<()> {
265255

266256
assert!(attempt_swap.is_ok());
267257

268-
sleep(Duration::from_millis(10)).await;
269-
270258
let mut msgs = HashMap::new();
271259
while let Ok((sub_id, msg)) = listener.try_recv() {
272260
assert_eq!(sub_id, "test".into());
273261
match msg {
274262
NotificationPayload::ProofState(ProofState { y, state, .. }) => {
275-
let pk = y.to_string();
276-
msgs.get_mut(&pk)
277-
.map(|x: &mut Vec<State>| {
278-
x.push(state);
279-
})
280-
.unwrap_or_else(|| {
281-
msgs.insert(pk, vec![state]);
282-
});
263+
msgs.entry(y.to_string())
264+
.or_insert_with(Vec::new)
265+
.push(state);
283266
}
284267
_ => bail!("Wrong message received"),
285268
}
286269
}
287270

288271
for keys in public_keys_to_listen {
289272
let statuses = msgs.remove(&keys).expect("some events");
290-
assert_eq!(statuses, vec![State::Pending, State::Pending, State::Spent]);
273+
// Every input pk receives two state updates, as there are only two state transitions
274+
assert_eq!(statuses, vec![State::Pending, State::Spent]);
291275
}
292276

293277
assert!(listener.try_recv().is_err(), "no other event is happening");

crates/cdk-sqlite/src/wallet/memory.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@ use super::WalletSqliteDatabase;
66

77
/// Creates a new in-memory [`WalletSqliteDatabase`] instance
88
pub async fn empty() -> Result<WalletSqliteDatabase, Error> {
9-
let db = WalletSqliteDatabase::new(":memory:").await?;
9+
let db = WalletSqliteDatabase {
10+
pool: sqlx::sqlite::SqlitePool::connect(":memory:")
11+
.await
12+
.map_err(|e| Error::Database(Box::new(e)))?,
13+
};
1014
db.migrate().await;
1115
Ok(db)
1216
}

crates/cdk/src/mint/swap.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,38 +12,22 @@ impl Mint {
1212
&self,
1313
swap_request: SwapRequest,
1414
) -> Result<SwapResponse, Error> {
15-
let input_ys = swap_request.inputs.ys()?;
16-
17-
self.localstore
18-
.add_proofs(swap_request.inputs.clone(), None)
19-
.await?;
20-
self.check_ys_spendable(&input_ys, State::Pending).await?;
21-
2215
if let Err(err) = self
2316
.verify_transaction_balanced(&swap_request.inputs, &swap_request.outputs)
2417
.await
2518
{
26-
tracing::debug!("Attempt to swap unbalanced transaction: {}", err);
27-
self.localstore.remove_proofs(&input_ys, None).await?;
19+
tracing::debug!("Attempt to swap unbalanced transaction, aborting: {err}");
2820
return Err(err);
2921
};
3022

31-
let EnforceSigFlag {
32-
sig_flag,
33-
pubkeys,
34-
sigs_required,
35-
} = enforce_sig_flag(swap_request.inputs.clone());
23+
self.validate_sig_flag(&swap_request).await?;
3624

37-
if sig_flag.eq(&SigFlag::SigAll) {
38-
let pubkeys = pubkeys.into_iter().collect();
39-
for blinded_message in &swap_request.outputs {
40-
if let Err(err) = blinded_message.verify_p2pk(&pubkeys, sigs_required) {
41-
tracing::info!("Could not verify p2pk in swap request");
42-
self.localstore.remove_proofs(&input_ys, None).await?;
43-
return Err(err.into());
44-
}
45-
}
46-
}
25+
// After swap request is fully validated, add the new proofs to DB
26+
let input_ys = swap_request.inputs.ys()?;
27+
self.localstore
28+
.add_proofs(swap_request.inputs.clone(), None)
29+
.await?;
30+
self.check_ys_spendable(&input_ys, State::Pending).await?;
4731

4832
let mut promises = Vec::with_capacity(swap_request.outputs.len());
4933

@@ -74,4 +58,24 @@ impl Mint {
7458

7559
Ok(SwapResponse::new(promises))
7660
}
61+
62+
async fn validate_sig_flag(&self, swap_request: &SwapRequest) -> Result<(), Error> {
63+
let EnforceSigFlag {
64+
sig_flag,
65+
pubkeys,
66+
sigs_required,
67+
} = enforce_sig_flag(swap_request.inputs.clone());
68+
69+
if sig_flag.eq(&SigFlag::SigAll) {
70+
let pubkeys = pubkeys.into_iter().collect();
71+
for blinded_message in &swap_request.outputs {
72+
if let Err(err) = blinded_message.verify_p2pk(&pubkeys, sigs_required) {
73+
tracing::info!("Could not verify p2pk in swap request");
74+
return Err(err.into());
75+
}
76+
}
77+
}
78+
79+
Ok(())
80+
}
7781
}

0 commit comments

Comments
 (0)