Skip to content

Commit c84cc55

Browse files
KaboomFox and claude
committed
Fix code review issues: race conditions, API, and docs
Correctness fixes:
- Fix race condition in shard creation by using atomic counter for IDs instead of calculating from vector length
- Fix subscription state inconsistency: rollback subscription from state when send fails, allowing retry without unsubscribe first

API improvements:
- Add is_running() method and prevent double start()
- Rename subscribe_simple to subscribe_all with better docs
- Add config validations for circuit_breaker_threshold > 0 and data_timeout >= ping_interval

Documentation:
- Add thread safety docs to ShardManager and Metrics
- Document all MetricsSnapshot fields
- Add ConfigError::InvalidConnection variant

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 98cb478 commit c84cc55

File tree

5 files changed

+111
-12
lines changed

5 files changed

+111
-12
lines changed

examples/polymarket.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
143143
.map(|i| format!("token-{}", i))
144144
.collect();
145145

146-
manager.subscribe_simple(new_tokens).await?;
146+
manager.subscribe_all(new_tokens).await?;
147147
info!("Subscribed to 10 more tokens, total: {}", manager.total_subscriptions());
148148

149149
// Run for a while

src/config.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ impl ShardManagerConfigBuilder {
114114
));
115115
}
116116

117+
if self.config.health.data_timeout < self.config.health.ping_interval {
118+
return Err(ConfigError::InvalidHealth(
119+
"data_timeout should be >= ping_interval to avoid false positives".to_string(),
120+
));
121+
}
122+
123+
// Validate connection config
124+
if self.config.connection.circuit_breaker_threshold == 0 {
125+
return Err(ConfigError::InvalidConnection(
126+
"circuit_breaker_threshold must be > 0".to_string(),
127+
));
128+
}
129+
117130
// Validate max subscriptions
118131
if let Some(0) = self.config.max_subscriptions_per_shard {
119132
return Err(ConfigError::InvalidSubscriptionLimit(
@@ -134,6 +147,9 @@ pub enum ConfigError {
134147
/// Invalid health configuration
135148
#[error("Invalid health configuration: {0}")]
136149
InvalidHealth(String),
150+
/// Invalid connection configuration
151+
#[error("Invalid connection configuration: {0}")]
152+
InvalidConnection(String),
137153
/// Invalid subscription limit
138154
#[error("Invalid subscription limit: {0}")]
139155
InvalidSubscriptionLimit(String),

src/connection.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ impl<H: WebSocketHandler> Connection<H> {
6363
}
6464

6565
/// Create a new connection with a ready signal for hot switchover
66+
#[allow(clippy::too_many_arguments)]
6667
pub fn with_ready_signal(
6768
shard_id: usize,
6869
handler: Arc<H>,

src/manager.rs

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use futures_util::FutureExt;
88
use parking_lot::RwLock;
99
use std::collections::{HashMap, HashSet};
1010
use std::panic::AssertUnwindSafe;
11+
use std::sync::atomic::{AtomicUsize, Ordering};
1112
use std::sync::Arc;
1213
use tokio::sync::{mpsc, oneshot};
1314
use tokio::task::JoinHandle;
@@ -23,13 +24,21 @@ type ShardUnsubData<S> = (usize, Vec<S>, mpsc::Sender<ConnectionCommand>, usize)
2324
/// Timeout for waiting for new connection during hot switchover
2425
const HOT_SWITCHOVER_TIMEOUT: Duration = Duration::from_secs(10);
2526

26-
/// Manages multiple WebSocket shards with auto-rebalancing and hot switchover
27+
/// Manages multiple WebSocket shards with auto-rebalancing and hot switchover.
28+
///
29+
/// # Thread Safety
30+
///
31+
/// `ShardManager` is `Send + Sync` and all methods can be safely called from
32+
/// multiple tasks concurrently. Internal state is protected by `parking_lot::RwLock`
33+
/// which does not poison on panic.
2734
pub struct ShardManager<H: WebSocketHandler> {
2835
handler: Arc<H>,
2936
config: ShardManagerConfig,
3037
metrics: Arc<Metrics>,
3138
state: Arc<RwLock<ManagerState<H::Subscription>>>,
3239
shard_handles: RwLock<Vec<JoinHandle<()>>>,
40+
/// Monotonically increasing counter for shard IDs to prevent race conditions
41+
next_shard_id: AtomicUsize,
3342
}
3443

3544
struct ManagerState<S: Clone + Eq + std::hash::Hash> {
@@ -65,6 +74,7 @@ impl<H: WebSocketHandler> ShardManager<H> {
6574
metrics: Arc::new(Metrics::new()),
6675
state: Arc::new(RwLock::new(ManagerState::default())),
6776
shard_handles: RwLock::new(Vec::new()),
77+
next_shard_id: AtomicUsize::new(0),
6878
}
6979
}
7080

@@ -73,12 +83,32 @@ impl<H: WebSocketHandler> ShardManager<H> {
7383
self.metrics.clone()
7484
}
7585

86+
/// Check if the manager is currently running
87+
pub fn is_running(&self) -> bool {
88+
self.state.read().is_running
89+
}
90+
7691
/// Start the shard manager
7792
///
7893
/// This will create the initial shards based on existing subscriptions
7994
/// from the handler. If there are no initial subscriptions, no shards
8095
/// are created until the first subscription is added (lazy creation).
96+
///
97+
/// # Errors
98+
///
99+
/// Returns an error if the manager is already running or if configuration
100+
/// is invalid.
81101
pub async fn start(&self) -> Result<(), Error> {
102+
// Check if already running
103+
{
104+
let state = self.state.read();
105+
if state.is_running {
106+
return Err(Error::Handler(
107+
"ShardManager is already running".to_string(),
108+
));
109+
}
110+
}
111+
82112
let subscriptions = self.handler.subscriptions();
83113
let max_per_shard = self.max_per_shard();
84114

@@ -103,8 +133,9 @@ impl<H: WebSocketHandler> ShardManager<H> {
103133
max_per_shard
104134
);
105135

106-
// Create shards
107-
for shard_id in 0..shard_count {
136+
// Create shards using atomic counter for IDs
137+
for _ in 0..shard_count {
138+
let shard_id = self.next_shard_id.fetch_add(1, Ordering::SeqCst);
108139
self.create_shard(shard_id).await?;
109140
}
110141

@@ -199,6 +230,9 @@ impl<H: WebSocketHandler> ShardManager<H> {
199230
state.shards_being_created.clear();
200231
}
201232

233+
// Reset shard ID counter for potential restart
234+
self.next_shard_id.store(0, Ordering::SeqCst);
235+
202236
info!("ShardManager stopped");
203237
Ok(())
204238
}
@@ -270,8 +304,8 @@ impl<H: WebSocketHandler> ShardManager<H> {
270304
Ok((shard_id, tx, sub_count))
271305
}
272306
None if self.config.auto_rebalance => {
273-
// Need to create new shard - reserve the ID to prevent race conditions
274-
let new_id = state.shards.len() + state.shards_being_created.len();
307+
// Need to create new shard - use atomic counter for thread-safe ID assignment
308+
let new_id = self.next_shard_id.fetch_add(1, Ordering::SeqCst);
275309
state.shards_being_created.insert(new_id);
276310
Err(new_id)
277311
}
@@ -328,13 +362,30 @@ impl<H: WebSocketHandler> ShardManager<H> {
328362
.update_shard(shard_id, |s| s.subscription_count = sub_count);
329363

330364
// Send subscribe message outside the lock
331-
if let Some(msg) = self.handler.subscription_message(&[item]) {
365+
if let Some(msg) = self.handler.subscription_message(std::slice::from_ref(&item)) {
332366
if let Err(e) = command_tx.send(ConnectionCommand::Send(msg)).await {
333367
self.metrics.record_subscription_send_failed();
334368
warn!(
335369
"[SHARD-{}] Failed to send subscription message: {}",
336370
shard_id, e
337371
);
372+
373+
// Rollback: remove subscription from state so retry is possible
374+
{
375+
let mut state = self.state.write();
376+
if let Some(shard) = state.shards.get_mut(shard_id) {
377+
shard.remove_subscription(&item);
378+
}
379+
state.subscription_to_shard.remove(&item);
380+
}
381+
382+
// Update metrics to reflect rollback
383+
let new_count = {
384+
let state = self.state.read();
385+
state.shards.get(shard_id).map(|s| s.subscription_count()).unwrap_or(0)
386+
};
387+
self.metrics.update_shard(shard_id, |s| s.subscription_count = new_count);
388+
338389
return SubscribeResult::SendFailed {
339390
shard_id,
340391
error: e.to_string(),
@@ -345,10 +396,16 @@ impl<H: WebSocketHandler> ShardManager<H> {
345396
SubscribeResult::Success { shard_id }
346397
}
347398

348-
/// Convenience method that returns only the shard IDs of successful subscriptions
399+
/// Subscribe to items and return affected shard IDs.
400+
///
401+
/// This is a convenience method that returns the shard IDs of successful
402+
/// subscriptions. Use [`subscribe`] instead if you need per-item error details.
403+
///
404+
/// # Errors
349405
///
350-
/// This is a simpler API for cases where you don't need per-item error handling.
351-
pub async fn subscribe_simple(&self, items: Vec<H::Subscription>) -> Result<Vec<usize>, Error> {
406+
/// Returns an error only if all subscriptions fail. Partial success returns `Ok`
407+
/// with the shard IDs that were affected.
408+
pub async fn subscribe_all(&self, items: Vec<H::Subscription>) -> Result<Vec<usize>, Error> {
352409
let results = self.subscribe(items).await;
353410
let mut affected_shards_set = HashSet::new();
354411
let mut had_failure = false;

src/metrics.rs

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@ use parking_lot::RwLock;
22
use std::sync::atomic::{AtomicU64, Ordering};
33
use std::time::{Duration, Instant};
44

5-
/// Metrics for observability
5+
/// Metrics for observability.
66
///
77
/// This struct provides counters and gauges for monitoring WebSocket health.
88
/// Use `snapshot()` to get a point-in-time view of all metrics, or use
99
/// individual getter methods for specific values.
1010
///
11+
/// # Thread Safety
12+
///
13+
/// `Metrics` is `Send + Sync` and all methods are safe to call from multiple
14+
/// tasks concurrently. Counters use atomic operations, and per-shard metrics
15+
/// are protected by `parking_lot::RwLock`.
16+
///
1117
/// # Example
1218
/// ```ignore
1319
/// let metrics = manager.metrics();
@@ -319,24 +325,43 @@ impl Metrics {
319325
}
320326
}
321327

322-
/// A point-in-time snapshot of all metrics
328+
/// A point-in-time snapshot of all metrics.
329+
///
330+
/// Use [`Metrics::snapshot()`] to get a consistent view of all metrics at once.
331+
/// This is the recommended way to export metrics to monitoring systems.
323332
#[derive(Debug, Clone)]
324333
pub struct MetricsSnapshot {
334+
/// Total number of WebSocket connections established (including reconnections)
325335
pub connections_total: u64,
336+
/// Total number of reconnection attempts after disconnection
326337
pub reconnections_total: u64,
338+
/// Total number of WebSocket messages received across all shards
327339
pub messages_received_total: u64,
340+
/// Total number of WebSocket messages sent across all shards
328341
pub messages_sent_total: u64,
342+
/// Total number of errors encountered (connection failures, panics, etc.)
329343
pub errors_total: u64,
344+
/// Total number of WebSocket ping frames sent for health monitoring
330345
pub pings_sent_total: u64,
346+
/// Total number of WebSocket pong frames received
331347
pub pongs_received_total: u64,
348+
/// Total number of health check failures (pong timeouts, data timeouts)
332349
pub health_failures_total: u64,
350+
/// Total number of shard rebalancing operations (new shard creation)
333351
pub rebalances_total: u64,
352+
/// Total number of hot switchover operations initiated
334353
pub hot_switchovers_total: u64,
354+
/// Total number of hot switchover operations that failed
335355
pub hot_switchover_failures_total: u64,
356+
/// Total number of subscription message send failures
336357
pub subscription_send_failures_total: u64,
358+
/// Total number of times the circuit breaker tripped due to consecutive failures
337359
pub circuit_breaker_trips_total: u64,
360+
/// Current number of active (connected) shards
338361
pub active_connections: usize,
362+
/// Current total number of subscriptions across all shards
339363
pub total_subscriptions: usize,
364+
/// Per-shard metrics with detailed connection state
340365
pub shards: Vec<ShardMetrics>,
341366
}
342367

0 commit comments

Comments
 (0)