@@ -291,13 +291,18 @@ impl WantsInputs {
291291 . map ( |input| input. sequence )
292292 . unwrap_or_default ( ) ;
293293
294- // Collect existing PSBT outpoints so duplicate inputs are filtered out .
294+ // Collect existing PSBT outpoints to detect duplicate inputs.
295295 let mut seen_outpoints: HashSet < _ > =
296296 self . payjoin_psbt . unsigned_tx . input . iter ( ) . map ( |txin| txin. previous_output ) . collect ( ) ;
297- let inputs: Vec < _ > = inputs
298- . into_iter ( )
299- . filter ( |input| seen_outpoints. insert ( input. txin . previous_output ) )
300- . collect ( ) ;
297+ let inputs: Vec < _ > = inputs. into_iter ( ) . collect ( ) ;
298+ for input in & inputs {
299+ if !seen_outpoints. insert ( input. txin . previous_output ) {
300+ return Err ( InternalInputContributionError :: DuplicateInput (
301+ input. txin . previous_output ,
302+ )
303+ . into ( ) ) ;
304+ }
305+ }
301306
302307 // Insert contributions at random indices for privacy
303308 let mut rng = rand:: thread_rng ( ) ;
@@ -659,11 +664,17 @@ mod tests {
659664 let wants_inputs = wants_inputs. contribute_inputs ( vec ! [ input_pair_1. clone( ) ] ) . unwrap ( ) ;
660665 assert_eq ! ( wants_inputs. receiver_inputs. len( ) , 1 ) ;
661666 assert_eq ! ( wants_inputs. receiver_inputs[ 0 ] , input_pair_1) ;
662- // Contribute the same input again (should be filtered out) and a new input.
663- let wants_inputs = wants_inputs
667+ // Contribute the same input again (should error) and a new input.
668+ let duplicate_input = wants_inputs
669+ . clone ( )
664670 . contribute_inputs ( vec ! [ input_pair_2. clone( ) , input_pair_1. clone( ) ] )
665- . unwrap ( ) ;
666- // Only input_pair_2 should be added input_pair_1 is a duplicate and should be filtered out hence the length is 2.
671+ . unwrap_err ( ) ;
672+ assert_eq ! (
673+ duplicate_input,
674+ InputContributionError :: from( InternalInputContributionError :: DuplicateInput ( ot1) )
675+ ) ;
676+ // Contribute only the new input
677+ let wants_inputs = wants_inputs. contribute_inputs ( vec ! [ input_pair_2. clone( ) ] ) . unwrap ( ) ;
667678 assert_eq ! ( wants_inputs. receiver_inputs. len( ) , 2 ) ;
668679 assert_eq ! ( wants_inputs. receiver_inputs[ 0 ] , input_pair_1) ;
669680 assert_eq ! ( wants_inputs. receiver_inputs[ 1 ] , input_pair_2) ;
0 commit comments