Skip to content

Commit d9a9adf

Browse files
authored
feat(subs): await subscription ID & ignore first recv message (#123)
1 parent 62bf0fc commit d9a9adf

File tree

4 files changed

+268
-125
lines changed

4 files changed

+268
-125
lines changed

src/c/mod.rs

Lines changed: 131 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use std::net::SocketAddr;
66
use std::ops::Deref;
77
use std::os::raw::c_char;
88
use std::str::FromStr;
9-
use std::sync::atomic::{AtomicU64, Ordering};
109
use std::sync::Arc;
1110
use std::time::Duration;
1211

@@ -45,6 +44,7 @@ use starknet_crypto::{poseidon_hash_many, Felt};
4544
use stream_cancel::{StreamExt as _, Tripwire};
4645
use tokio::net::TcpListener;
4746
use tokio::runtime::Runtime;
47+
use tokio::sync::oneshot;
4848
use tokio::time::sleep;
4949
use tokio_stream::StreamExt;
5050
use torii_client::Client as TClient;
@@ -842,15 +842,15 @@ pub unsafe extern "C" fn client_on_entity_state_update(
842842
callback: unsafe extern "C" fn(types::FieldElement, CArray<Struct>),
843843
) -> Result<*mut Subscription> {
844844
let client = Arc::new(unsafe { &*client });
845-
let subscription_id = Arc::new(AtomicU64::new(0));
846845
let (trigger, tripwire) = Tripwire::new();
847846

848-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
847+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
848+
let mut sub_id_tx = Some(sub_id_tx);
849+
849850
let clause: Option<torii_proto::Clause> = clause.map(|c| c.into()).into();
850851

851852
// Spawn a new thread to handle the stream and reconnections
852853
let client_clone = client.clone();
853-
let subscription_id_clone = Arc::clone(&subscription_id);
854854
RUNTIME.spawn(async move {
855855
let mut backoff = Duration::from_secs(1);
856856
let max_backoff = Duration::from_secs(60);
@@ -863,10 +863,14 @@ pub unsafe extern "C" fn client_on_entity_state_update(
863863
let mut rcv = rcv.take_until_if(tripwire.clone());
864864

865865
while let Some(Ok((id, entity))) = rcv.next().await {
866-
subscription_id_clone.store(id, Ordering::SeqCst);
867-
let key: types::FieldElement = entity.hashed_keys.into();
868-
let models: Vec<Struct> = entity.models.into_iter().map(|e| e.into()).collect();
869-
callback(key, models.into());
866+
if let Some(tx) = sub_id_tx.take() {
867+
tx.send(id).expect("Failed to send subscription ID");
868+
} else {
869+
let key: types::FieldElement = entity.hashed_keys.into();
870+
let models: Vec<Struct> =
871+
entity.models.into_iter().map(|e| e.into()).collect();
872+
callback(key, models.into());
873+
}
870874
}
871875
}
872876

@@ -880,6 +884,18 @@ pub unsafe extern "C" fn client_on_entity_state_update(
880884
}
881885
});
882886

887+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
888+
Ok(id) => id,
889+
Err(_) => {
890+
return Result::Err(Error {
891+
message: CString::new("Failed to establish entity subscription")
892+
.unwrap()
893+
.into_raw(),
894+
});
895+
}
896+
};
897+
898+
let subscription = Subscription { id: subscription_id, trigger };
883899
Result::Ok(Box::into_raw(Box::new(subscription)))
884900
}
885901

@@ -901,11 +917,7 @@ pub unsafe extern "C" fn client_update_entity_subscription(
901917
) -> Result<bool> {
902918
let clause: Option<torii_proto::Clause> = clause.map(|c| c.into()).into();
903919

904-
match RUNTIME.block_on(
905-
(*client)
906-
.inner
907-
.update_entity_subscription((*subscription).id.load(Ordering::SeqCst), clause),
908-
) {
920+
match RUNTIME.block_on((*client).inner.update_entity_subscription((*subscription).id, clause)) {
909921
Ok(_) => Result::Ok(true),
910922
Err(e) => Result::Err(e.into()),
911923
}
@@ -930,14 +942,12 @@ pub unsafe extern "C" fn client_on_event_message_update(
930942
let client = Arc::new(unsafe { &*client });
931943
let clause: Option<torii_proto::Clause> = clause.map(|c| c.into()).into();
932944

933-
let subscription_id = Arc::new(AtomicU64::new(0));
934945
let (trigger, tripwire) = Tripwire::new();
935-
936-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
946+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
947+
let mut sub_id_tx = Some(sub_id_tx);
937948

938949
// Spawn a new thread to handle the stream and reconnections
939950
let client_clone = client.clone();
940-
let subscription_id_clone = Arc::clone(&subscription_id);
941951
RUNTIME.spawn(async move {
942952
let mut backoff = Duration::from_secs(1);
943953
let max_backoff = Duration::from_secs(60);
@@ -950,10 +960,14 @@ pub unsafe extern "C" fn client_on_event_message_update(
950960
let mut rcv = rcv.take_until_if(tripwire.clone());
951961

952962
while let Some(Ok((id, entity))) = rcv.next().await {
953-
subscription_id_clone.store(id, Ordering::SeqCst);
954-
let key: types::FieldElement = entity.hashed_keys.into();
955-
let models: Vec<Struct> = entity.models.into_iter().map(|e| e.into()).collect();
956-
callback(key, models.into());
963+
if let Some(tx) = sub_id_tx.take() {
964+
tx.send(id).expect("Failed to send subscription ID");
965+
} else {
966+
let key: types::FieldElement = entity.hashed_keys.into();
967+
let models: Vec<Struct> =
968+
entity.models.into_iter().map(|e| e.into()).collect();
969+
callback(key, models.into());
970+
}
957971
}
958972
}
959973

@@ -967,6 +981,18 @@ pub unsafe extern "C" fn client_on_event_message_update(
967981
}
968982
});
969983

984+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
985+
Ok(id) => id,
986+
Err(_) => {
987+
return Result::Err(Error {
988+
message: CString::new("Failed to establish event message subscription")
989+
.unwrap()
990+
.into_raw(),
991+
});
992+
}
993+
};
994+
995+
let subscription = Subscription { id: subscription_id, trigger };
970996
Result::Ok(Box::into_raw(Box::new(subscription)))
971997
}
972998

@@ -988,11 +1014,9 @@ pub unsafe extern "C" fn client_update_event_message_subscription(
9881014
) -> Result<bool> {
9891015
let clause: Option<torii_proto::Clause> = clause.map(|c| c.into()).into();
9901016

991-
match RUNTIME.block_on(
992-
(*client)
993-
.inner
994-
.update_event_message_subscription((*subscription).id.load(Ordering::SeqCst), clause),
995-
) {
1017+
match RUNTIME
1018+
.block_on((*client).inner.update_event_message_subscription((*subscription).id, clause))
1019+
{
9961020
Ok(_) => Result::Ok(true),
9971021
Err(e) => Result::Err(e.into()),
9981022
}
@@ -1023,11 +1047,10 @@ pub unsafe extern "C" fn client_on_starknet_event(
10231047
clauses.iter().map(|c| c.clone().into()).collect::<Vec<_>>()
10241048
};
10251049

1026-
let subscription_id = Arc::new(AtomicU64::new(0));
1050+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
1051+
let mut sub_id_tx = Some(sub_id_tx);
10271052
let (trigger, tripwire) = Tripwire::new();
10281053

1029-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
1030-
10311054
// Spawn a new thread to handle the stream and reconnections
10321055
let client_clone = client.clone();
10331056
RUNTIME.spawn(async move {
@@ -1043,7 +1066,12 @@ pub unsafe extern "C" fn client_on_starknet_event(
10431066
let mut rcv = rcv.take_until_if(tripwire.clone());
10441067

10451068
while let Some(Ok(event)) = rcv.next().await {
1046-
callback(event.into());
1069+
if let Some(tx) = sub_id_tx.take() {
1070+
tx.send(0).expect("Failed to send subscription ID");
1071+
} else {
1072+
let event: Event = event.into();
1073+
callback(event);
1074+
}
10471075
}
10481076
}
10491077

@@ -1057,6 +1085,17 @@ pub unsafe extern "C" fn client_on_starknet_event(
10571085
}
10581086
});
10591087

1088+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
1089+
Ok(id) => id,
1090+
Err(_) => {
1091+
return Result::Err(Error {
1092+
message: CString::new("Failed to establish event subscription").unwrap().into_raw(),
1093+
});
1094+
}
1095+
};
1096+
1097+
let subscription = Subscription { id: subscription_id, trigger };
1098+
10601099
Result::Ok(Box::into_raw(Box::new(subscription)))
10611100
}
10621101

@@ -1151,11 +1190,10 @@ pub unsafe extern "C" fn client_on_token_update(
11511190
ids.iter().map(|f| f.clone().into()).collect::<Vec<U256>>()
11521191
};
11531192

1154-
let subscription_id = Arc::new(AtomicU64::new(0));
1193+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
1194+
let mut sub_id_tx = Some(sub_id_tx);
11551195
let (trigger, tripwire) = Tripwire::new();
11561196

1157-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
1158-
11591197
// Spawn a new thread to handle the stream and reconnections
11601198
let client_clone = client.clone();
11611199
RUNTIME.spawn(async move {
@@ -1174,9 +1212,13 @@ pub unsafe extern "C" fn client_on_token_update(
11741212
let mut rcv = rcv.take_until_if(tripwire.clone());
11751213

11761214
while let Some(Ok((id, token))) = rcv.next().await {
1177-
subscription_id.store(id, Ordering::SeqCst);
1178-
let token: Token = token.into();
1179-
callback(token);
1215+
// Our first message will be the subscription ID
1216+
if let Some(tx) = sub_id_tx.take() {
1217+
tx.send(id).expect("Failed to send subscription ID");
1218+
} else {
1219+
let token: Token = token.into();
1220+
callback(token);
1221+
}
11801222
}
11811223
}
11821224

@@ -1190,6 +1232,17 @@ pub unsafe extern "C" fn client_on_token_update(
11901232
}
11911233
});
11921234

1235+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
1236+
Ok(id) => id,
1237+
Err(_) => {
1238+
return Result::Err(Error {
1239+
message: CString::new("Failed to establish token subscription").unwrap().into_raw(),
1240+
});
1241+
}
1242+
};
1243+
1244+
let subscription = Subscription { id: subscription_id, trigger };
1245+
11931246
Result::Ok(Box::into_raw(Box::new(subscription)))
11941247
}
11951248

@@ -1351,11 +1404,10 @@ pub unsafe extern "C" fn on_indexer_update(
13511404
Some(unsafe { (*contract_address).clone().into() })
13521405
};
13531406

1354-
let subscription_id = Arc::new(AtomicU64::new(0));
1407+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
1408+
let mut sub_id_tx = Some(sub_id_tx);
13551409
let (trigger, tripwire) = Tripwire::new();
13561410

1357-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
1358-
13591411
// Spawn a new thread to handle the stream and reconnections
13601412
let client_clone = client.clone();
13611413
RUNTIME.spawn(async move {
@@ -1370,7 +1422,11 @@ pub unsafe extern "C" fn on_indexer_update(
13701422
let mut rcv = rcv.take_until_if(tripwire.clone());
13711423

13721424
while let Some(Ok(update)) = rcv.next().await {
1373-
callback(update.into());
1425+
if let Some(tx) = sub_id_tx.take() {
1426+
tx.send(0).expect("Failed to send subscription ID");
1427+
} else {
1428+
callback(update.into());
1429+
}
13741430
}
13751431
}
13761432

@@ -1384,6 +1440,19 @@ pub unsafe extern "C" fn on_indexer_update(
13841440
}
13851441
});
13861442

1443+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
1444+
Ok(id) => id,
1445+
Err(_) => {
1446+
return Result::Err(Error {
1447+
message: CString::new("Failed to establish indexer subscription")
1448+
.unwrap()
1449+
.into_raw(),
1450+
});
1451+
}
1452+
};
1453+
1454+
let subscription = Subscription { id: subscription_id, trigger };
1455+
13871456
Result::Ok(Box::into_raw(Box::new(subscription)))
13881457
}
13891458

@@ -1437,11 +1506,10 @@ pub unsafe extern "C" fn client_on_token_balance_update(
14371506
ids.iter().map(|f| f.clone().into()).collect::<Vec<U256>>()
14381507
};
14391508

1440-
let subscription_id = Arc::new(AtomicU64::new(0));
1509+
let (sub_id_tx, sub_id_rx) = oneshot::channel();
1510+
let mut sub_id_tx = Some(sub_id_tx);
14411511
let (trigger, tripwire) = Tripwire::new();
14421512

1443-
let subscription = Subscription { id: Arc::clone(&subscription_id), trigger };
1444-
14451513
// Spawn a new thread to handle the stream and reconnections
14461514
let client_clone = client.clone();
14471515
RUNTIME.spawn(async move {
@@ -1464,9 +1532,12 @@ pub unsafe extern "C" fn client_on_token_balance_update(
14641532
let mut rcv = rcv.take_until_if(tripwire.clone());
14651533

14661534
while let Some(Ok((id, balance))) = rcv.next().await {
1467-
subscription_id.store(id, Ordering::SeqCst);
1468-
let balance: TokenBalance = balance.into();
1469-
callback(balance);
1535+
if let Some(tx) = sub_id_tx.take() {
1536+
tx.send(id).expect("Failed to send subscription ID");
1537+
} else {
1538+
let balance: TokenBalance = balance.into();
1539+
callback(balance);
1540+
}
14701541
}
14711542
}
14721543

@@ -1480,6 +1551,19 @@ pub unsafe extern "C" fn client_on_token_balance_update(
14801551
}
14811552
});
14821553

1554+
let subscription_id = match RUNTIME.block_on(sub_id_rx) {
1555+
Ok(id) => id,
1556+
Err(_) => {
1557+
return Result::Err(Error {
1558+
message: CString::new("Failed to establish token balance subscription")
1559+
.unwrap()
1560+
.into_raw(),
1561+
});
1562+
}
1563+
};
1564+
1565+
let subscription = Subscription { id: subscription_id, trigger };
1566+
14831567
Result::Ok(Box::into_raw(Box::new(subscription)))
14841568
}
14851569

@@ -1530,7 +1614,7 @@ pub unsafe extern "C" fn client_update_token_balance_subscription(
15301614
};
15311615

15321616
match RUNTIME.block_on((*client).inner.update_token_balance_subscription(
1533-
(*subscription).id.load(Ordering::SeqCst),
1617+
(*subscription).id,
15341618
contract_addresses,
15351619
account_addresses,
15361620
token_ids,

src/types.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ use std::ffi::c_char;
33
use std::fs::File;
44
use std::io::{BufReader, BufWriter};
55
use std::path::PathBuf;
6-
use std::sync::atomic::AtomicU64;
76
use std::sync::Arc;
87

98
use serde::{Deserialize, Serialize};
@@ -106,8 +105,9 @@ pub struct ControllerAccount {
106105
pub(crate) account: account_sdk::account::session::account::SessionAccount,
107106
pub(crate) username: String,
108107
}
108+
109109
#[wasm_bindgen]
110110
pub struct Subscription {
111-
pub(crate) id: Arc<AtomicU64>,
111+
pub id: u64,
112112
pub(crate) trigger: Trigger,
113113
}

0 commit comments

Comments
 (0)