Skip to content

Commit bd77795

Browse files
authored
Deduplicate PSBT inputs in contribute_inputs (#1254)
2 parents 6396e6a + f3b7028 commit bd77795

File tree

1 file changed

+11
-5
lines changed
  • payjoin/src/core/receive/common

1 file changed

+11
-5
lines changed

payjoin/src/core/receive/common/mod.rs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//! APIs to expose as relevant typestates.
44
55
use std::cmp::{max, min};
6+
use std::collections::HashSet;
67

78
use bitcoin::psbt::Psbt;
89
use bitcoin::secp256k1::rand::seq::SliceRandom;
@@ -290,7 +291,13 @@ impl WantsInputs {
290291
.map(|input| input.sequence)
291292
.unwrap_or_default();
292293

293-
let inputs = inputs.into_iter().collect::<Vec<_>>();
294+
// Collect existing PSBT outpoints so duplicate inputs are filtered out.
295+
let mut seen_outpoints: HashSet<_> =
296+
self.payjoin_psbt.unsigned_tx.input.iter().map(|txin| txin.previous_output).collect();
297+
let inputs: Vec<_> = inputs
298+
.into_iter()
299+
.filter(|input| seen_outpoints.insert(input.txin.previous_output))
300+
.collect();
294301

295302
// Insert contributions at random indices for privacy
296303
let mut rng = rand::thread_rng();
@@ -652,15 +659,14 @@ mod tests {
652659
let wants_inputs = wants_inputs.contribute_inputs(vec![input_pair_1.clone()]).unwrap();
653660
assert_eq!(wants_inputs.receiver_inputs.len(), 1);
654661
assert_eq!(wants_inputs.receiver_inputs[0], input_pair_1);
655-
// Contribute the same input again, and a new input.
656-
// TODO: if we ever decide to fix contribute duplicate inputs, we need to update this test.
662+
// Contribute the same input again (should be filtered out) and a new input.
657663
let wants_inputs = wants_inputs
658664
.contribute_inputs(vec![input_pair_2.clone(), input_pair_1.clone()])
659665
.unwrap();
660-
assert_eq!(wants_inputs.receiver_inputs.len(), 3);
666+
// Only input_pair_2 should be added input_pair_1 is a duplicate and should be filtered out hence the length is 2.
667+
assert_eq!(wants_inputs.receiver_inputs.len(), 2);
661668
assert_eq!(wants_inputs.receiver_inputs[0], input_pair_1);
662669
assert_eq!(wants_inputs.receiver_inputs[1], input_pair_2);
663-
assert_eq!(wants_inputs.receiver_inputs[2], input_pair_1);
664670
}
665671

666672
#[test]

0 commit comments

Comments
 (0)