Skip to content

Commit da21b9d

Browse files
authored
refactor: Improve SPV shutdown handling with CancellationToken (#187)
1 parent ffd3d2e commit da21b9d

File tree

4 files changed

+79
-46
lines changed

4 files changed

+79
-46
lines changed

dash-spv-ffi/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ dashcore = { path = "../dash", package = "dashcore" }
1717
libc = "0.2"
1818
once_cell = "1.19"
1919
tokio = { version = "1", features = ["full"] }
20+
tokio-util = "0.7"
2021
serde = { version = "1.0", features = ["derive"] }
2122
serde_json = "1.0"
2223
log = "0.4"

dash-spv-ffi/src/client.rs

Lines changed: 17 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@ use std::collections::HashMap;
1717
use std::ffi::{CStr, CString};
1818
use std::os::raw::{c_char, c_void};
1919
use std::str::FromStr;
20-
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
20+
use std::sync::atomic::{AtomicU64, Ordering};
2121
use std::sync::{Arc, Mutex};
2222
use std::time::Duration;
2323
use tokio::runtime::Runtime;
2424
use tokio::sync::mpsc::{error::TryRecvError, UnboundedReceiver};
25+
use tokio_util::sync::CancellationToken;
2526

2627
/// Global callback registry for thread-safe callback management
2728
static CALLBACK_REGISTRY: Lazy<Arc<Mutex<CallbackRegistry>>> =
@@ -104,12 +105,6 @@ struct SyncCallbackData {
104105
_marker: std::marker::PhantomData<()>,
105106
}
106107

107-
async fn wait_for_shutdown_signal(signal: Arc<AtomicBool>) {
108-
while !signal.load(Ordering::Relaxed) {
109-
tokio::time::sleep(Duration::from_millis(50)).await;
110-
}
111-
}
112-
113108
/// FFIDashSpvClient structure
114109
type InnerClient = DashSpvClient<
115110
key_wallet_manager::wallet_manager::WalletManager<
@@ -126,7 +121,7 @@ pub struct FFIDashSpvClient {
126121
event_callbacks: Arc<Mutex<FFIEventCallbacks>>,
127122
active_threads: Arc<Mutex<Vec<std::thread::JoinHandle<()>>>>,
128123
sync_callbacks: Arc<Mutex<Option<SyncCallbackData>>>,
129-
shutdown_signal: Arc<AtomicBool>,
124+
shutdown_token: CancellationToken,
130125
// Stored event receiver for pull-based draining (no background thread by default)
131126
event_rx: Arc<Mutex<Option<UnboundedReceiver<dash_spv::types::SpvEvent>>>>,
132127
}
@@ -197,7 +192,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_new(
197192
event_callbacks: Arc::new(Mutex::new(FFIEventCallbacks::default())),
198193
active_threads: Arc::new(Mutex::new(Vec::new())),
199194
sync_callbacks: Arc::new(Mutex::new(None)),
200-
shutdown_signal: Arc::new(AtomicBool::new(false)),
195+
shutdown_token: CancellationToken::new(),
201196
event_rx: Arc::new(Mutex::new(None)),
202197
};
203198
Box::into_raw(Box::new(ffi_client))
@@ -378,8 +373,8 @@ pub unsafe extern "C" fn dash_spv_ffi_client_drain_events(client: *mut FFIDashSp
378373
FFIErrorCode::Success as i32
379374
}
380375

381-
fn stop_client_internal(client: &FFIDashSpvClient) -> Result<(), dash_spv::SpvError> {
382-
client.shutdown_signal.store(true, Ordering::Relaxed);
376+
fn stop_client_internal(client: &mut FFIDashSpvClient) -> Result<(), dash_spv::SpvError> {
377+
client.shutdown_token.cancel();
383378

384379
// Ensure callbacks are cleared so no further progress/completion notifications fire.
385380
{
@@ -411,7 +406,7 @@ fn stop_client_internal(client: &FFIDashSpvClient) -> Result<(), dash_spv::SpvEr
411406
res
412407
});
413408

414-
client.shutdown_signal.store(false, Ordering::Relaxed);
409+
client.shutdown_token = CancellationToken::new();
415410

416411
result
417412
}
@@ -525,7 +520,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_start(client: *mut FFIDashSpvClient
525520
pub unsafe extern "C" fn dash_spv_ffi_client_stop(client: *mut FFIDashSpvClient) -> i32 {
526521
null_check!(client);
527522

528-
let client = &(*client);
523+
let client = &mut (*client);
529524
match stop_client_internal(client) {
530525
Ok(()) => FFIErrorCode::Success as i32,
531526
Err(e) => {
@@ -785,7 +780,6 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
785780
let inner = client.inner.clone();
786781
let runtime = client.runtime.clone();
787782
let sync_callbacks = client.sync_callbacks.clone();
788-
let shutdown_signal = client.shutdown_signal.clone();
789783

790784
// Take progress receiver from client
791785
let progress_receiver = {
@@ -797,7 +791,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
797791
if let Some(mut receiver) = progress_receiver {
798792
let runtime_handle = runtime.handle().clone();
799793
let sync_callbacks_clone = sync_callbacks.clone();
800-
let shutdown_signal_clone = shutdown_signal.clone();
794+
let shutdown_token_monitor = client.shutdown_token.clone();
801795

802796
let handle = std::thread::spawn(move || {
803797
runtime_handle.block_on(async move {
@@ -859,7 +853,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
859853
None => break,
860854
}
861855
}
862-
_ = wait_for_shutdown_signal(shutdown_signal_clone.clone()) => {
856+
_ = shutdown_token_monitor.cancelled() => {
863857
break;
864858
}
865859
}
@@ -874,15 +868,12 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
874868
// Spawn sync task in a separate thread with safe callback access
875869
let runtime_handle = runtime.handle().clone();
876870
let sync_callbacks_clone = sync_callbacks.clone();
877-
let shutdown_signal_for_thread = shutdown_signal.clone();
878-
let stop_triggered_for_thread = Arc::new(AtomicBool::new(false));
871+
let shutdown_token_sync = client.shutdown_token.clone();
879872
let sync_handle = std::thread::spawn(move || {
880-
let stop_triggered_for_callback = stop_triggered_for_thread.clone();
873+
let shutdown_token_callback = shutdown_token_sync.clone();
881874
// Run monitoring loop
882875
let monitor_result = runtime_handle.block_on({
883876
let inner = inner.clone();
884-
let shutdown_signal_for_thread = shutdown_signal_for_thread.clone();
885-
let stop_triggered_for_thread = stop_triggered_for_callback.clone();
886877
async move {
887878
let mut spv_client = {
888879
let mut guard = inner.lock().unwrap();
@@ -903,8 +894,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
903894
Ok(inner) => inner,
904895
Err(_) => Ok(()),
905896
},
906-
_ = wait_for_shutdown_signal(shutdown_signal_for_thread.clone()) => {
907-
stop_triggered_for_thread.store(true, Ordering::Relaxed);
897+
_ = shutdown_token_sync.cancelled() => {
908898
abort_handle.abort();
909899
match monitor_future.as_mut().await {
910900
Ok(inner) => inner,
@@ -930,7 +920,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
930920
..
931921
}) = registry.unregister(callback_data.callback_id)
932922
{
933-
if stop_triggered_for_callback.load(Ordering::Relaxed) {
923+
if shutdown_token_callback.is_cancelled() {
934924
let msg = CString::new("Sync stopped by request").unwrap_or_else(|_| {
935925
CString::new("Sync stopped").expect("hardcoded string is safe")
936926
});
@@ -984,7 +974,7 @@ pub unsafe extern "C" fn dash_spv_ffi_client_sync_to_tip_with_progress(
984974
pub unsafe extern "C" fn dash_spv_ffi_client_cancel_sync(client: *mut FFIDashSpvClient) -> i32 {
985975
null_check!(client);
986976

987-
let client = &(*client);
977+
let client = &mut (*client);
988978

989979
match stop_client_internal(client) {
990980
Ok(()) => FFIErrorCode::Success as i32,
@@ -1318,8 +1308,8 @@ pub unsafe extern "C" fn dash_spv_ffi_client_destroy(client: *mut FFIDashSpvClie
13181308
if !client.is_null() {
13191309
let client = Box::from_raw(client);
13201310

1321-
// Set shutdown signal to stop all threads
1322-
client.shutdown_signal.store(true, Ordering::Relaxed);
1311+
// Cancel shutdown token to stop all threads
1312+
client.shutdown_token.cancel();
13231313

13241314
// Clean up any registered callbacks
13251315
if let Some(ref callback_data) = *client.sync_callbacks.lock().unwrap() {

dash-spv/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ clap = { version = "4.0", features = ["derive"] }
2323

2424
# Async runtime
2525
tokio = { version = "1.0", features = ["full"] }
26+
tokio-util = "0.7"
2627
async-trait = "0.1"
2728

2829
# Error handling

dash-spv/src/network/manager.rs

Lines changed: 60 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
use std::collections::{HashMap, HashSet};
44
use std::net::SocketAddr;
55
use std::path::PathBuf;
6-
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
6+
use std::sync::atomic::{AtomicUsize, Ordering};
77
use std::sync::Arc;
88
use std::time::{Duration, SystemTime};
99
use tokio::sync::{mpsc, Mutex};
@@ -14,6 +14,7 @@ use async_trait::async_trait;
1414
use dashcore::network::constants::ServiceFlags;
1515
use dashcore::network::message::NetworkMessage;
1616
use dashcore::Network;
17+
use tokio_util::sync::CancellationToken;
1718

1819
use crate::client::config::MempoolStrategy;
1920
use crate::client::ClientConfig;
@@ -43,8 +44,8 @@ pub struct PeerNetworkManager {
4344
reputation_manager: Arc<PeerReputationManager>,
4445
/// Network type
4546
network: Network,
46-
/// Shutdown signal
47-
shutdown: Arc<AtomicBool>,
47+
/// Shutdown token
48+
shutdown_token: CancellationToken,
4849
/// Channel for incoming messages
4950
message_tx: mpsc::Sender<(SocketAddr, NetworkMessage)>,
5051
message_rx: Arc<Mutex<mpsc::Receiver<(SocketAddr, NetworkMessage)>>>,
@@ -109,7 +110,7 @@ impl PeerNetworkManager {
109110
peer_store: Arc::new(peer_store),
110111
reputation_manager,
111112
network: config.network,
112-
shutdown: Arc::new(AtomicBool::new(false)),
113+
shutdown_token: CancellationToken::new(),
113114
message_tx,
114115
message_rx: Arc::new(Mutex::new(message_rx)),
115116
tasks: Arc::new(Mutex::new(JoinSet::new())),
@@ -204,7 +205,7 @@ impl PeerNetworkManager {
204205
let network = self.network;
205206
let message_tx = self.message_tx.clone();
206207
let addrv2_handler = self.addrv2_handler.clone();
207-
let shutdown = self.shutdown.clone();
208+
let shutdown_token = self.shutdown_token.clone();
208209
let reputation_manager = self.reputation_manager.clone();
209210
let mempool_strategy = self.mempool_strategy;
210211
let user_agent = self.user_agent.clone();
@@ -245,7 +246,7 @@ impl PeerNetworkManager {
245246
pool.clone(),
246247
message_tx,
247248
addrv2_handler,
248-
shutdown,
249+
shutdown_token,
249250
reputation_manager.clone(),
250251
connected_peer_count.clone(),
251252
)
@@ -287,19 +288,19 @@ impl PeerNetworkManager {
287288
pool: Arc<ConnectionPool>,
288289
message_tx: mpsc::Sender<(SocketAddr, NetworkMessage)>,
289290
addrv2_handler: Arc<AddrV2Handler>,
290-
shutdown: Arc<AtomicBool>,
291+
shutdown_token: CancellationToken,
291292
reputation_manager: Arc<PeerReputationManager>,
292293
connected_peer_count: Arc<AtomicUsize>,
293294
) {
294295
tokio::spawn(async move {
295296
log::debug!("Starting peer reader loop for {}", addr);
296297
let mut loop_iteration = 0;
297298

298-
while !shutdown.load(Ordering::Relaxed) {
299+
loop {
299300
loop_iteration += 1;
300301

301302
// Check shutdown signal first with detailed logging
302-
if shutdown.load(Ordering::Relaxed) {
303+
if shutdown_token.is_cancelled() {
303304
log::info!("Breaking peer reader loop for {} - shutdown signal received (iteration {})", addr, loop_iteration);
304305
break;
305306
}
@@ -326,7 +327,15 @@ impl PeerNetworkManager {
326327

327328
// Now get write lock only for the duration of the read
328329
let mut conn_guard = conn.write().await;
329-
conn_guard.receive_message().await
330+
tokio::select! {
331+
message = conn_guard.receive_message() => {
332+
message
333+
},
334+
_ = shutdown_token.cancelled() => {
335+
log::info!("Breaking peer reader loop for {} - shutdown signal received while reading (iteration {})", addr, loop_iteration);
336+
break;
337+
}
338+
}
330339
};
331340

332341
match msg_result {
@@ -576,7 +585,7 @@ impl PeerNetworkManager {
576585
let pool = self.pool.clone();
577586
let discovery = self.discovery.clone();
578587
let network = self.network;
579-
let shutdown = self.shutdown.clone();
588+
let shutdown_token = self.shutdown_token.clone();
580589
let addrv2_handler = self.addrv2_handler.clone();
581590
let peer_store = self.peer_store.clone();
582591
let reputation_manager = self.reputation_manager.clone();
@@ -599,7 +608,7 @@ impl PeerNetworkManager {
599608

600609
let mut tasks = self.tasks.lock().await;
601610
tasks.spawn(async move {
602-
while !shutdown.load(Ordering::Relaxed) {
611+
while !shutdown_token.is_cancelled() {
603612
// Clean up disconnected peers
604613
pool.cleanup_disconnected().await;
605614

@@ -612,7 +621,13 @@ impl PeerNetworkManager {
612621
for addr in initial_peers.iter() {
613622
if !pool.is_connected(addr).await && !pool.is_connecting(addr).await {
614623
log::info!("Reconnecting to exclusive peer: {}", addr);
615-
connect_fn(*addr).await;
624+
tokio::select! {
625+
_= connect_fn(*addr) => {},
626+
_ = shutdown_token.cancelled() => {
627+
log::info!("Maintenance loop shutting down during connection attempt (exclusive)");
628+
break;
629+
}
630+
}
616631
}
617632
}
618633
} else {
@@ -642,7 +657,13 @@ impl PeerNetworkManager {
642657

643658
for addr in best_peers {
644659
if !pool.is_connected(&addr).await && !pool.is_connecting(&addr).await {
645-
connect_fn(addr).await;
660+
tokio::select! {
661+
_= connect_fn(addr) => {},
662+
_ = shutdown_token.cancelled() => {
663+
log::info!("Maintenance loop shutting down during connection attempt (min peers)");
664+
break;
665+
}
666+
}
646667
attempted += 1;
647668
if attempted >= needed {
648669
break;
@@ -661,11 +682,23 @@ impl PeerNetworkManager {
661682
});
662683
if elapsed >= DNS_DISCOVERY_DELAY {
663684
log::info!("Using DNS discovery after {}s delay", elapsed.as_secs());
664-
let dns_peers = discovery.discover_peers(network).await;
685+
let dns_peers = tokio::select! {
686+
peers = discovery.discover_peers(network) => peers,
687+
_ = shutdown_token.cancelled() => {
688+
log::info!("Maintenance loop shutting down during DNS discovery");
689+
break;
690+
}
691+
};
665692
let mut dns_attempted = 0;
666693
for addr in dns_peers.into_iter() {
667694
if !pool.is_connected(&addr).await && !pool.is_connecting(&addr).await {
668-
connect_fn(addr).await;
695+
tokio::select! {
696+
_= connect_fn(addr) => {},
697+
_ = shutdown_token.cancelled() => {
698+
log::info!("Maintenance loop shutting down during connection attempt (dns)");
699+
break;
700+
}
701+
}
669702
dns_attempted += 1;
670703
if dns_attempted >= needed {
671704
break;
@@ -719,7 +752,15 @@ impl PeerNetworkManager {
719752
}
720753
}
721754

722-
time::sleep(MAINTENANCE_INTERVAL).await;
755+
tokio::select! {
756+
_ = time::sleep(MAINTENANCE_INTERVAL) => {
757+
log::debug!("Maintenance interval elapsed");
758+
}
759+
_ = shutdown_token.cancelled() => {
760+
log::info!("Maintenance loop shutting down");
761+
break;
762+
}
763+
}
723764
}
724765
});
725766
}
@@ -942,7 +983,7 @@ impl PeerNetworkManager {
942983
/// Shutdown the network manager
943984
pub async fn shutdown(&self) {
944985
log::info!("Shutting down peer network manager");
945-
self.shutdown.store(true, Ordering::Relaxed);
986+
self.shutdown_token.cancel();
946987

947988
// Save known peers before shutdown
948989
let addresses = self.addrv2_handler.get_addresses_for_peer(MAX_ADDR_TO_STORE).await;
@@ -983,7 +1024,7 @@ impl Clone for PeerNetworkManager {
9831024
peer_store: self.peer_store.clone(),
9841025
reputation_manager: self.reputation_manager.clone(),
9851026
network: self.network,
986-
shutdown: self.shutdown.clone(),
1027+
shutdown_token: self.shutdown_token.clone(),
9871028
message_tx: self.message_tx.clone(),
9881029
message_rx: self.message_rx.clone(),
9891030
tasks: self.tasks.clone(),

0 commit comments

Comments
 (0)