@@ -13,7 +13,7 @@ use std::sync::Arc;
1313use tokio:: sync:: { mpsc, oneshot, Mutex } ;
1414use tokio:: task:: JoinHandle ;
1515use tokio:: time:: Duration ;
16- use tracing:: { debug, error, info, warn} ;
16+ use tracing:: { debug, error, info, trace , warn} ;
1717
1818/// Default channel buffer size
1919const DEFAULT_CHANNEL_SIZE : usize = 100 ;
@@ -43,6 +43,10 @@ pub struct ShardManager<H: WebSocketHandler> {
4343 /// Mutex to serialize start/stop operations and prevent race conditions.
4444 /// This is a tokio::Mutex so it can be held across await points.
4545 lifecycle_lock : Mutex < ( ) > ,
46+ /// Channel for connections to request hot switchover
47+ switchover_tx : mpsc:: Sender < usize > ,
48+ /// Receiver for switchover requests (wrapped in Mutex for interior mutability)
49+ switchover_rx : Mutex < mpsc:: Receiver < usize > > ,
4650}
4751
4852struct ManagerState < S : Clone + Eq + std:: hash:: Hash > {
@@ -73,6 +77,8 @@ impl<S: Clone + Eq + std::hash::Hash> Default for ManagerState<S> {
7377impl < H : WebSocketHandler > ShardManager < H > {
7478 /// Create a new shard manager
7579 pub fn new ( config : ShardManagerConfig , handler : H ) -> Self {
80+ // Channel for connections to request hot switchover (buffer 10 requests)
81+ let ( switchover_tx, switchover_rx) = mpsc:: channel ( 10 ) ;
7682 Self {
7783 handler : Arc :: new ( handler) ,
7884 config,
@@ -81,6 +87,8 @@ impl<H: WebSocketHandler> ShardManager<H> {
8187 shard_handles : RwLock :: new ( HashMap :: new ( ) ) ,
8288 next_shard_id : AtomicUsize :: new ( 0 ) ,
8389 lifecycle_lock : Mutex :: new ( ( ) ) ,
90+ switchover_tx,
91+ switchover_rx : Mutex :: new ( switchover_rx) ,
8492 }
8593 }
8694
@@ -89,6 +97,11 @@ impl<H: WebSocketHandler> ShardManager<H> {
8997 self . metrics . clone ( )
9098 }
9199
100+ /// Get a reference to the handler
101+ pub fn handler ( & self ) -> & Arc < H > {
102+ & self . handler
103+ }
104+
92105 /// Check if the manager is currently running
93106 pub fn is_running ( & self ) -> bool {
94107 self . state . read ( ) . is_running
@@ -202,6 +215,28 @@ impl<H: WebSocketHandler> ShardManager<H> {
202215 Ok ( ( ) )
203216 }
204217
218+ /// Process pending hot switchover requests from connections.
219+ ///
220+ /// Call this periodically (e.g., in a background task) to handle
221+ /// data timeout triggered switchovers. Returns the number of
222+ /// switchovers initiated.
223+ pub async fn process_switchover_requests ( & self ) -> usize {
224+ let mut count = 0 ;
225+ let mut rx = self . switchover_rx . lock ( ) . await ;
226+
227+ // Process all pending requests
228+ while let Ok ( shard_id) = rx. try_recv ( ) {
229+ info ! ( "[SHARD-{}] Processing hot switchover request" , shard_id) ;
230+ if let Err ( e) = self . hot_switchover ( shard_id) . await {
231+ warn ! ( "[SHARD-{}] Hot switchover failed: {}" , shard_id, e) ;
232+ } else {
233+ count += 1 ;
234+ }
235+ }
236+
237+ count
238+ }
239+
205240 /// Stop all shards gracefully
206241 ///
207242 /// This will close all connections and wait for tasks to complete.
@@ -287,6 +322,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
287322
288323 // Check if already subscribed
289324 if let Some ( & shard_id) = state. subscription_to_shard . get ( & item) {
325+ trace ! ( "[SHARD-{}] Subscription already exists, skipping" , shard_id) ;
290326 return SubscribeResult :: AlreadySubscribed { shard_id } ;
291327 }
292328
@@ -332,6 +368,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
332368 ( tx, sub_count)
333369 } ;
334370 state. subscription_to_shard . insert ( item. clone ( ) , shard_id) ;
371+ trace ! ( "[SHARD-{}] Added subscription (count: {})" , shard_id, add_result. 1 ) ;
335372 Ok ( ( shard_id, add_result. 0 , add_result. 1 ) )
336373 }
337374 None if self . config . auto_rebalance => {
@@ -397,6 +434,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
397434 ( tx, count)
398435 } ;
399436 state. subscription_to_shard . insert ( item. clone ( ) , new_id) ;
437+ trace ! ( "[SHARD-{}] Created new shard, added subscription (count: {})" , new_id, add_result. 1 ) ;
400438 ( new_id, add_result. 0 , add_result. 1 )
401439 }
402440 } ;
@@ -519,6 +557,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
519557 for sub in & subs {
520558 shard. remove_subscription ( sub) ;
521559 }
560+ trace ! ( "[SHARD-{}] Removed {} subscriptions" , shard_id, subs. len( ) ) ;
522561
523562 let tx = shard. command_tx . clone ( ) ;
524563 let count = shard. subscription_count ( ) ;
@@ -594,10 +633,17 @@ impl<H: WebSocketHandler> ShardManager<H> {
594633 let shard = metrics. get ( shard_id) ?;
595634
596635 if !shard. is_connected {
636+ trace ! ( "[SHARD-{}] Freshness check: disconnected" , shard_id) ;
597637 return None ;
598638 }
599639
600- shard. time_since_last_message
640+ let freshness = shard. time_since_last_message ;
641+ trace ! (
642+ "[SHARD-{}] Freshness check: {:?}" ,
643+ shard_id,
644+ freshness. map( |d| format!( "{}ms" , d. as_millis( ) ) ) . unwrap_or_else( || "no data" . to_string( ) )
645+ ) ;
646+ freshness
601647 }
602648
603649 /// Check if a subscription's data is considered fresh.
@@ -626,12 +672,26 @@ impl<H: WebSocketHandler> ShardManager<H> {
626672 max_staleness : Duration ,
627673 ) -> bool {
628674 let Some ( shard_id) = self . subscription_shard ( subscription) else {
675+ trace ! ( "Freshness check: subscription not found" ) ;
629676 return false ;
630677 } ;
631678
632679 match self . shard_freshness ( shard_id) {
633- Some ( since_last_message) => since_last_message <= max_staleness,
634- None => false ,
680+ Some ( since_last_message) => {
681+ let is_fresh = since_last_message <= max_staleness;
682+ trace ! (
683+ "[SHARD-{}] Subscription freshness: {}ms vs max {}ms -> {}" ,
684+ shard_id,
685+ since_last_message. as_millis( ) ,
686+ max_staleness. as_millis( ) ,
687+ if is_fresh { "FRESH" } else { "STALE" }
688+ ) ;
689+ is_fresh
690+ }
691+ None => {
692+ trace ! ( "[SHARD-{}] Subscription freshness: no data -> STALE" , shard_id) ;
693+ false
694+ }
635695 }
636696 }
637697
@@ -855,6 +915,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
855915 let backoff_config = self . config . backoff . clone ( ) ;
856916 let health_config = self . config . health . clone ( ) ;
857917 let metrics = self . metrics . clone ( ) ;
918+ let switchover_tx = self . switchover_tx . clone ( ) ;
858919
859920 let handle = tokio:: spawn ( async move {
860921 Self :: run_connection_with_recovery (
@@ -865,6 +926,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
865926 health_config,
866927 metrics,
867928 rx,
929+ switchover_tx,
868930 )
869931 . await
870932 } ) ;
@@ -899,18 +961,20 @@ impl<H: WebSocketHandler> ShardManager<H> {
899961 health_config : crate :: config:: HealthConfig ,
900962 metrics : Arc < Metrics > ,
901963 command_rx : mpsc:: Receiver < ConnectionCommand > ,
964+ switchover_tx : mpsc:: Sender < usize > ,
902965 ) {
903966 // Note: We can't easily restart the connection after channel closure,
904967 // so we just catch panics within this task and log them.
905968 // For full recovery, the manager would need to recreate the channel.
906- let connection = Connection :: new (
969+ let connection = Connection :: with_switchover_channel (
907970 shard_id,
908971 handler,
909972 connection_config,
910973 backoff_config,
911974 health_config,
912975 metrics. clone ( ) ,
913976 command_rx,
977+ switchover_tx,
914978 ) ;
915979
916980 match AssertUnwindSafe ( connection. run ( ) )
0 commit comments