@@ -8,6 +8,7 @@ use futures_util::FutureExt;
88use parking_lot:: RwLock ;
99use std:: collections:: { HashMap , HashSet } ;
1010use std:: panic:: AssertUnwindSafe ;
11+ use std:: sync:: atomic:: { AtomicUsize , Ordering } ;
1112use std:: sync:: Arc ;
1213use tokio:: sync:: { mpsc, oneshot} ;
1314use tokio:: task:: JoinHandle ;
@@ -23,13 +24,21 @@ type ShardUnsubData<S> = (usize, Vec<S>, mpsc::Sender<ConnectionCommand>, usize)
2324/// Timeout for waiting for new connection during hot switchover
2425const HOT_SWITCHOVER_TIMEOUT : Duration = Duration :: from_secs ( 10 ) ;
2526
26- /// Manages multiple WebSocket shards with auto-rebalancing and hot switchover
27+ /// Manages multiple WebSocket shards with auto-rebalancing and hot switchover.
28+ ///
29+ /// # Thread Safety
30+ ///
31+ /// `ShardManager` is `Send + Sync` and all methods can be safely called from
32+ /// multiple tasks concurrently. Internal state is protected by `parking_lot::RwLock`
33+ /// which does not poison on panic.
2734pub struct ShardManager < H : WebSocketHandler > {
2835 handler : Arc < H > ,
2936 config : ShardManagerConfig ,
3037 metrics : Arc < Metrics > ,
3138 state : Arc < RwLock < ManagerState < H :: Subscription > > > ,
3239 shard_handles : RwLock < Vec < JoinHandle < ( ) > > > ,
40+ /// Monotonically increasing counter for shard IDs to prevent race conditions
41+ next_shard_id : AtomicUsize ,
3342}
3443
3544struct ManagerState < S : Clone + Eq + std:: hash:: Hash > {
@@ -65,6 +74,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
6574 metrics : Arc :: new ( Metrics :: new ( ) ) ,
6675 state : Arc :: new ( RwLock :: new ( ManagerState :: default ( ) ) ) ,
6776 shard_handles : RwLock :: new ( Vec :: new ( ) ) ,
77+ next_shard_id : AtomicUsize :: new ( 0 ) ,
6878 }
6979 }
7080
@@ -73,12 +83,32 @@ impl<H: WebSocketHandler> ShardManager<H> {
7383 self . metrics . clone ( )
7484 }
7585
86+ /// Check if the manager is currently running
87+ pub fn is_running ( & self ) -> bool {
88+ self . state . read ( ) . is_running
89+ }
90+
7691 /// Start the shard manager
7792 ///
7893 /// This will create the initial shards based on existing subscriptions
7994 /// from the handler. If there are no initial subscriptions, no shards
8095 /// are created until the first subscription is added (lazy creation).
96+ ///
97+ /// # Errors
98+ ///
99+ /// Returns an error if the manager is already running or if configuration
100+ /// is invalid.
81101 pub async fn start ( & self ) -> Result < ( ) , Error > {
102+ // Check if already running
103+ {
104+ let state = self . state . read ( ) ;
105+ if state. is_running {
106+ return Err ( Error :: Handler (
107+ "ShardManager is already running" . to_string ( ) ,
108+ ) ) ;
109+ }
110+ }
111+
82112 let subscriptions = self . handler . subscriptions ( ) ;
83113 let max_per_shard = self . max_per_shard ( ) ;
84114
@@ -103,8 +133,9 @@ impl<H: WebSocketHandler> ShardManager<H> {
103133 max_per_shard
104134 ) ;
105135
106- // Create shards
107- for shard_id in 0 ..shard_count {
136+ // Create shards using atomic counter for IDs
137+ for _ in 0 ..shard_count {
138+ let shard_id = self . next_shard_id . fetch_add ( 1 , Ordering :: SeqCst ) ;
108139 self . create_shard ( shard_id) . await ?;
109140 }
110141
@@ -199,6 +230,9 @@ impl<H: WebSocketHandler> ShardManager<H> {
199230 state. shards_being_created . clear ( ) ;
200231 }
201232
233+ // Reset shard ID counter for potential restart
234+ self . next_shard_id . store ( 0 , Ordering :: SeqCst ) ;
235+
202236 info ! ( "ShardManager stopped" ) ;
203237 Ok ( ( ) )
204238 }
@@ -270,8 +304,8 @@ impl<H: WebSocketHandler> ShardManager<H> {
270304 Ok ( ( shard_id, tx, sub_count) )
271305 }
272306 None if self . config . auto_rebalance => {
273- // Need to create new shard - reserve the ID to prevent race conditions
274- let new_id = state . shards . len ( ) + state . shards_being_created . len ( ) ;
307+ // Need to create new shard - use atomic counter for thread-safe ID assignment
308+ let new_id = self . next_shard_id . fetch_add ( 1 , Ordering :: SeqCst ) ;
275309 state. shards_being_created . insert ( new_id) ;
276310 Err ( new_id)
277311 }
@@ -328,13 +362,30 @@ impl<H: WebSocketHandler> ShardManager<H> {
328362 . update_shard ( shard_id, |s| s. subscription_count = sub_count) ;
329363
330364 // Send subscribe message outside the lock
331- if let Some ( msg) = self . handler . subscription_message ( & [ item] ) {
365+ if let Some ( msg) = self . handler . subscription_message ( std :: slice :: from_ref ( & item) ) {
332366 if let Err ( e) = command_tx. send ( ConnectionCommand :: Send ( msg) ) . await {
333367 self . metrics . record_subscription_send_failed ( ) ;
334368 warn ! (
335369 "[SHARD-{}] Failed to send subscription message: {}" ,
336370 shard_id, e
337371 ) ;
372+
373+ // Rollback: remove subscription from state so retry is possible
374+ {
375+ let mut state = self . state . write ( ) ;
376+ if let Some ( shard) = state. shards . get_mut ( shard_id) {
377+ shard. remove_subscription ( & item) ;
378+ }
379+ state. subscription_to_shard . remove ( & item) ;
380+ }
381+
382+ // Update metrics to reflect rollback
383+ let new_count = {
384+ let state = self . state . read ( ) ;
385+ state. shards . get ( shard_id) . map ( |s| s. subscription_count ( ) ) . unwrap_or ( 0 )
386+ } ;
387+ self . metrics . update_shard ( shard_id, |s| s. subscription_count = new_count) ;
388+
338389 return SubscribeResult :: SendFailed {
339390 shard_id,
340391 error : e. to_string ( ) ,
@@ -345,10 +396,16 @@ impl<H: WebSocketHandler> ShardManager<H> {
345396 SubscribeResult :: Success { shard_id }
346397 }
347398
348- /// Convenience method that returns only the shard IDs of successful subscriptions
399+ /// Subscribe to items and return affected shard IDs.
400+ ///
401+ /// This is a convenience method that returns the shard IDs of successful
402+ /// subscriptions. Use [`subscribe`] instead if you need per-item error details.
403+ ///
404+ /// # Errors
349405 ///
350- /// This is a simpler API for cases where you don't need per-item error handling.
351- pub async fn subscribe_simple ( & self , items : Vec < H :: Subscription > ) -> Result < Vec < usize > , Error > {
406+ /// Returns an error only if all subscriptions fail. Partial success returns `Ok`
407+ /// with the shard IDs that were affected.
408+ pub async fn subscribe_all ( & self , items : Vec < H :: Subscription > ) -> Result < Vec < usize > , Error > {
352409 let results = self . subscribe ( items) . await ;
353410 let mut affected_shards_set = HashSet :: new ( ) ;
354411 let mut had_failure = false ;
0 commit comments