Skip to content

Commit 049ebbe

Browse files
committed
joinstr: implem a notification system to notify consumer that the pool state have beem updated
1 parent 2c53ce3 commit 049ebbe

File tree

3 files changed

+108
-25
lines changed

3 files changed

+108
-25
lines changed

rust/joinstr/src/interface.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ pub fn initiate_coinjoin(config: PoolConfig, peer: PeerConfig) -> Result<Txid, E
140140
initiator.set_coin(coin)?;
141141
initiator.set_address(addr)?;
142142

143-
initiator.start_coinjoin_blocking(None, Some(signer.clone()))?;
143+
initiator.start_coinjoin_blocking(None, Some(signer.clone()), || {})?;
144144

145145
let txid = initiator
146146
.final_tx()
@@ -200,7 +200,7 @@ pub fn join_coinjoin(pool: Pool, peer: PeerConfig) -> Result<String /* Txid */,
200200
let client = Client::new(&url, port)?;
201201
signer.set_client(client);
202202

203-
joinstr_peer.start_coinjoin_blocking(None, Some(signer.clone()))?;
203+
joinstr_peer.start_coinjoin_blocking(None, Some(signer.clone()), || {})?;
204204

205205
let txid = joinstr_peer
206206
.final_tx()

rust/joinstr/src/joinstr/mod.rs

Lines changed: 103 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,9 @@ impl Joinstr<'_> {
497497
/// Start the round of output registration, will block until enough output
498498
/// registered or if some error occur.
499499
///
500+
/// # Arguments
501+
/// * `notif` - A callback function called every time the pool state is updated.
502+
///
500503
/// # Errors
501504
///
502505
/// This function will return an error if:
@@ -506,7 +509,10 @@ impl Joinstr<'_> {
506509
/// - the nostr client do not have private keys
507510
/// - timeout elapsed
508511
/// - peer count do not match
509-
fn register_outputs(&mut self) -> Result<(), Error> {
512+
fn register_outputs<N>(&mut self, notif: N) -> Result<(), Error>
513+
where
514+
N: Fn(),
515+
{
510516
let inner = self.inner.lock().expect("poisoned");
511517
inner.pool_exists()?;
512518
let (expired, start_early) = inner.start_timeline()?;
@@ -541,6 +547,7 @@ impl Joinstr<'_> {
541547
}
542548
peers.insert(npub);
543549
inner.peers.push(npub);
550+
notif();
544551
log::debug!(
545552
"Coordinator({}).register_outputs(): receive Join({}) request. \n peers: {}",
546553
inner.client.name,
@@ -583,7 +590,7 @@ impl Joinstr<'_> {
583590
let mut inner = self.inner.lock().expect("poisoned");
584591
if let Some(output) = inner.output.as_ref() {
585592
coinjoin.add_output(output.clone());
586-
inner.register_output()?;
593+
inner.register_output(&notif)?;
587594
}
588595
drop(inner);
589596

@@ -613,6 +620,7 @@ impl Joinstr<'_> {
613620
// TODO: we must error if outputs > peers
614621
// TODO: check address network
615622
inner.outputs.push(o.assume_checked());
623+
notif();
616624
}
617625
// FIXME: here it can be some cases where, because network timing, we can
618626
// receive a signed input before the output registration round ended, we should
@@ -647,12 +655,16 @@ impl Joinstr<'_> {
647655
));
648656
}
649657
self.inner.lock().expect("poisoined").coinjoin = Some(coinjoin);
658+
notif();
650659
Ok(())
651660
}
652661

653662
/// Start the round of input registration, will block until enough input
654663
/// registered or if some error occur.
655664
///
665+
/// # Arguments
666+
/// * `notif` - A callback function called every time the pool state is updated.
667+
///
656668
/// # Errors
657669
///
658670
/// This function will return an error if:
@@ -662,7 +674,10 @@ impl Joinstr<'_> {
662674
/// - timeout expired
663675
/// - trying register an input error
664676
/// - trying finalize coinjoin error
665-
fn register_inputs(&mut self) -> Result<(), Error> {
677+
fn register_inputs<N>(&mut self, notif: N) -> Result<(), Error>
678+
where
679+
N: Fn(),
680+
{
666681
let inner = self.inner.lock().expect("poisoned");
667682
inner.pool_exists()?;
668683
inner.coinjoin_exists()?;
@@ -698,13 +713,13 @@ impl Joinstr<'_> {
698713
PoolMessage::Psbt(psbt) => {
699714
let input: InputDataSigned =
700715
psbt.try_into().map_err(|_| Error::PsbtToInput)?;
701-
inner.try_register_input(input)?;
716+
inner.try_register_input(input, &notif)?;
702717
if inner.try_finalize_coinjoin()? {
703718
break;
704719
}
705720
}
706721
PoolMessage::Input(input) => {
707-
inner.try_register_input(input)?;
722+
inner.try_register_input(input, &notif)?;
708723
if inner.try_finalize_coinjoin()? {
709724
break;
710725
}
@@ -770,21 +785,61 @@ impl Joinstr<'_> {
770785
let mut cloned = self.clone();
771786
let signer = signer.clone();
772787
thread::spawn(move || {
773-
if let Err(e) = cloned.start_coinjoin_blocking(pool, signer) {
788+
if let Err(e) = cloned.start_coinjoin_blocking(pool, signer, || {}) {
789+
let mut inner = cloned.inner.lock().expect("poisoned");
790+
inner.error = Some(format!("{:?}", e));
791+
inner.step = Step::Failed;
792+
}
793+
});
794+
}
795+
796+
/// Start a coinjoin process, followings steps will be processed:
797+
/// - if no `pool` arg is passed, a new pool will be initiated.
798+
/// - if a `pool` arg is passed, it will join the pool
799+
/// - run the outputs registration round
800+
/// - if a `signer` arg is passed, it will signed the input it owns.
801+
/// - run the inputs registration round
802+
/// - finalize the transaction
803+
/// - broadcast the transaction
804+
///
805+
/// # Arguments
806+
/// * `pool` - The pool we want join (optional)
807+
/// * `signer` - The signer to sign our input with (optional)
808+
/// * `notif` - A callback function called every time the pool state is updated.
809+
///
810+
/// # Errors
811+
///
812+
/// This function will return an error if any step return an error.
813+
pub fn start_coinjoin_with_notif<S, N>(
814+
&mut self,
815+
pool: Option<Pool>,
816+
signer: Option<S>,
817+
notif: N,
818+
) where
819+
S: JoinstrSigner + Sized + Sync + Clone + Send + 'static,
820+
Self: Sized + Send + 'static,
821+
N: Fn() + Send + 'static,
822+
{
823+
let mut cloned = self.clone();
824+
let signer = signer.clone();
825+
thread::spawn(move || {
826+
if let Err(e) = cloned.start_coinjoin_blocking(pool, signer, notif) {
774827
let mut inner = cloned.inner.lock().expect("poisoned");
775828
inner.error = Some(format!("{:?}", e));
776829
inner.step = Step::Failed;
777830
}
778831
});
779832
}
780833

781-
pub fn start_coinjoin_blocking<S>(
834+
pub fn start_coinjoin_blocking<S, N>(
782835
&mut self,
783836
pool: Option<Pool>,
784837
signer: Option<S>,
838+
notif: N,
785839
) -> Result<(), Error>
786840
where
787841
S: JoinstrSigner + Sync + Clone + Send + 'static,
842+
N: Fn(),
788843
{
789844
let mut inner = self.inner.lock().expect("poisoned");
790845

@@ -802,36 +857,42 @@ impl Joinstr<'_> {
802857
inner.post()?;
803858
drop(inner);
804859
}
860+
notif();
861+
805862
// register peers & outputs
806-
self.register_outputs()?;
863+
self.register_outputs(&notif)?;
807864

808865
self.inner
809866
.lock()
810867
.expect("poisoned")
811868
.generate_unsigned_tx()?;
812869

870+
notif();
871+
813872
rand_delay();
814873

815874
let mut inner = self.inner.lock().expect("poisoned");
816875
if inner.input.is_some() {
817876
if let Some(s) = signer {
818-
inner.register_input(&s)?;
877+
inner.register_input(&s, &notif)?;
819878
} else {
820879
return Err(Error::SignerMissing);
821880
}
822881
}
823882
drop(inner);
824883

825-
self.register_inputs()?;
884+
self.register_inputs(&notif)?;
826885

827886
self.inner.lock().expect("poisoned").broadcast_tx()?;
887+
notif();
828888

829889
Ok(())
830890
}
831891

832-
pub fn restart<S>(state: State, name: &str, signer: S) -> Result<Self, Error>
892+
pub fn restart<S, N>(state: State, name: &str, signer: S, notif: N) -> Result<Self, Error>
833893
where
834894
S: JoinstrSigner + Sized + Sync + Clone + Send + 'static,
895+
N: Fn() + Send + 'static,
835896
Self: Sized + Send + 'static,
836897
{
837898
let State {
@@ -957,13 +1018,15 @@ impl Joinstr<'_> {
9571018

9581019
drop(inner);
9591020

960-
fn restart_blocking<S>(
1021+
fn restart_blocking<S, N>(
9611022
mut j: Joinstr,
9621023
expected_peers: usize,
9631024
signer: S,
1025+
notif: N,
9641026
) -> Result<(), Error>
9651027
where
9661028
S: JoinstrSigner + Sync + Clone + Send + 'static,
1029+
N: Fn(),
9671030
{
9681031
let inner = j.inner.lock().expect("poisoned");
9691032
let joined = inner.peers.len() >= expected_peers;
@@ -973,21 +1036,22 @@ impl Joinstr<'_> {
9731036
drop(inner);
9741037

9751038
if !joined || !output_registered {
976-
j.register_outputs()?;
1039+
j.register_outputs(&notif)?;
9771040
}
9781041

9791042
if !inputs_registered {
9801043
j.inner.lock().expect("poisoned").generate_unsigned_tx()?;
1044+
notif();
9811045

9821046
rand_delay();
9831047

9841048
let mut inner = j.inner.lock().expect("poisoned");
9851049
if inner.input.is_some() {
986-
inner.register_input(&signer)?;
1050+
inner.register_input(&signer, &notif)?;
9871051
}
9881052
drop(inner);
9891053

990-
j.register_inputs()?;
1054+
j.register_inputs(&notif)?;
9911055

9921056
j.inner.lock().expect("poisoned").broadcast_tx()?;
9931057
}
@@ -998,7 +1062,7 @@ impl Joinstr<'_> {
9981062
let j3 = j.clone();
9991063

10001064
std::thread::spawn(move || {
1001-
if let Err(e) = restart_blocking(j2, expected_peers, signer) {
1065+
if let Err(e) = restart_blocking(j2, expected_peers, signer, &notif) {
10021066
let mut inner = j3.inner.lock().expect("poisoned");
10031067
inner.error = Some(format!("{:?}", e));
10041068
inner.step = Step::Failed;
@@ -1215,20 +1279,27 @@ impl<'a> JoinstrInner<'a> {
12151279

12161280
/// Register [`Joinstr::output`] address to the pool
12171281
///
1282+
/// # Arguments
1283+
/// * `notif` - A callback function called every time the pool state is updated.
1284+
///
12181285
/// # Errors
12191286
///
12201287
/// This function will return an error if:
12211288
/// - the pool not exists
12221289
/// - [`Joinstr::output`] is missing
12231290
/// - fails to send the nostr message
1224-
fn register_output(&mut self) -> Result<(), Error> {
1291+
fn register_output<N>(&mut self, notif: N) -> Result<(), Error>
1292+
where
1293+
N: Fn(),
1294+
{
12251295
if let Some(address) = &self.output {
12261296
// let msg = PoolMessage::Outputs(Outputs::single(address.as_unchecked().clone()));
12271297
let msg = PoolMessage::Output(address.as_unchecked().clone());
12281298
self.pool_exists()?;
12291299
let npub = self.pool_as_ref()?.public_key;
12301300
self.client.send_pool_message(&npub, msg)?;
12311301
self.outputs.push(address.clone());
1302+
notif();
12321303
// TODO: handle re-send if fails
12331304
Ok(())
12341305
} else {
@@ -1273,6 +1344,9 @@ impl<'a> JoinstrInner<'a> {
12731344

12741345
/// Try to sign / register / send our input.
12751346
///
1347+
/// # Arguments
1348+
/// * `notif` - A callback function called every time the pool state is updated.
1349+
///
12761350
/// # Errors
12771351
///
12781352
/// This function will return an error if:
@@ -1282,9 +1356,10 @@ impl<'a> JoinstrInner<'a> {
12821356
/// - the inner pool dont exists
12831357
/// - [`Joinstr::input`] is None
12841358
/// - sending the input fails
1285-
fn register_input<S>(&mut self, signer: &S) -> Result<(), Error>
1359+
fn register_input<S, N>(&mut self, signer: &S, notif: N) -> Result<(), Error>
12861360
where
12871361
S: JoinstrSigner,
1362+
N: Fn(),
12881363
{
12891364
let unsigned = match self.coinjoin_as_ref()?.unsigned_tx() {
12901365
Some(u) => u,
@@ -1299,19 +1374,26 @@ impl<'a> JoinstrInner<'a> {
12991374
let npub = self.pool_as_ref()?.public_key;
13001375
self.client.send_pool_message(&npub, msg)?;
13011376
self.inputs.push(signed_input);
1377+
notif();
13021378
// TODO: handle re-send if fails
13031379
Ok(())
13041380
} else {
13051381
Err(Error::InputMissing)
13061382
}
13071383
}
13081384

1309-
// Try to register a received signed input to the inner [`CoinJoin`]
1385+
/// Try to register a received signed input to the inner [`CoinJoin`]
1386+
///
1387+
/// # Arguments
1388+
/// * `notif` - A callback function called every time the pool state is updated.
13101389
///
13111390
/// # Errors
13121391
///
13131392
/// This function will return an error if [`Joinstr::coinjoin`] is None
1314-
fn try_register_input(&mut self, input: InputDataSigned) -> Result<(), Error> {
1393+
fn try_register_input<N>(&mut self, input: InputDataSigned, notif: N) -> Result<(), Error>
1394+
where
1395+
N: Fn(),
1396+
{
13151397
self.coinjoin_exists()?;
13161398
log::debug!(
13171399
"Coordinator({}).register_input(): receive Inputs({:?}) request.",
@@ -1328,6 +1410,7 @@ impl<'a> JoinstrInner<'a> {
13281410
);
13291411
} else {
13301412
self.inputs.push(input);
1413+
notif();
13311414
}
13321415
}
13331416
Ok(())

rust/joinstr/tests/joinstr.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ fn simple_coinjoin() {
139139

140140
let coordinator_handle = thread::spawn(move || {
141141
coordinator
142-
.start_coinjoin_blocking(None, Option::<WpkhHotSigner>::None)
142+
.start_coinjoin_blocking(None, Option::<WpkhHotSigner>::None, || {})
143143
.unwrap();
144144
coordinator.final_tx()
145145
});
@@ -226,11 +226,11 @@ fn simple_coinjoin() {
226226
let signer_a = signer.clone();
227227
let pool_a = pool.clone();
228228
let _peer_a = thread::spawn(move || {
229-
let _ = peer_a.start_coinjoin_blocking(Some(pool_a), Some(signer_a));
229+
let _ = peer_a.start_coinjoin_blocking(Some(pool_a), Some(signer_a), || {});
230230
});
231231

232232
let _peer_b = thread::spawn(move || {
233-
let _ = peer_b.start_coinjoin_blocking(Some(pool), Some(signer));
233+
let _ = peer_b.start_coinjoin_blocking(Some(pool), Some(signer), || {});
234234
});
235235

236236
let final_tx = coordinator_handle.join().unwrap().unwrap();

0 commit comments

Comments
 (0)