Skip to content

Commit 26c3d08

Browse files
authored
refactor(hermes): use broadcast channel for api notifications (#1388)
This change removes the manual broadcast implementation that was used to send API notifications to WS subscribers and replaces it with a tokio broadcast channel.
1 parent f31ef9e commit 26c3d08

File tree

5 files changed

+34
-100
lines changed

5 files changed

+34
-100
lines changed

hermes/src/aggregate.rs

Lines changed: 6 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -268,24 +268,17 @@ pub async fn store_update(state: &State, update: Update) -> Result<()> {
268268
match aggregate_state.latest_completed_slot {
269269
None => {
270270
aggregate_state.latest_completed_slot.replace(slot);
271-
state
272-
.api_update_tx
273-
.send(AggregationEvent::New { slot })
274-
.await?;
271+
state.api_update_tx.send(AggregationEvent::New { slot })?;
275272
}
276273
Some(latest) if slot > latest => {
277274
state.prune_removed_keys(message_state_keys).await;
278275
aggregate_state.latest_completed_slot.replace(slot);
279-
state
280-
.api_update_tx
281-
.send(AggregationEvent::New { slot })
282-
.await?;
276+
state.api_update_tx.send(AggregationEvent::New { slot })?;
283277
}
284278
_ => {
285279
state
286280
.api_update_tx
287-
.send(AggregationEvent::OutOfOrder { slot })
288-
.await?;
281+
.send(AggregationEvent::OutOfOrder { slot })?;
289282
}
290283
}
291284

@@ -583,7 +576,7 @@ mod test {
583576
// Check that the update_rx channel has received a message
584577
assert_eq!(
585578
update_rx.recv().await,
586-
Some(AggregationEvent::New { slot: 10 })
579+
Ok(AggregationEvent::New { slot: 10 })
587580
);
588581

589582
// Check the price ids are stored correctly
@@ -708,7 +701,7 @@ mod test {
708701
// Check that the update_rx channel has received a message
709702
assert_eq!(
710703
update_rx.recv().await,
711-
Some(AggregationEvent::New { slot: 10 })
704+
Ok(AggregationEvent::New { slot: 10 })
712705
);
713706

714707
// Check the price ids are stored correctly
@@ -745,7 +738,7 @@ mod test {
745738
// Check that the update_rx channel has received a message
746739
assert_eq!(
747740
update_rx.recv().await,
748-
Some(AggregationEvent::New { slot: 15 })
741+
Ok(AggregationEvent::New { slot: 15 })
749742
);
750743

751744
// Check that price feed 2 does not exist anymore

hermes/src/api.rs

Lines changed: 11 additions & 39 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
11
use {
2-
self::ws::notify_updates,
32
crate::{
43
aggregate::AggregationEvent,
54
config::RunOptions,
@@ -18,7 +17,7 @@ use {
1817
atomic::Ordering,
1918
Arc,
2019
},
21-
tokio::sync::mpsc::Receiver,
20+
tokio::sync::broadcast::Sender,
2221
tower_http::cors::CorsLayer,
2322
utoipa::OpenApi,
2423
utoipa_swagger_ui::SwaggerUi,
@@ -32,16 +31,18 @@ mod ws;
3231

3332
#[derive(Clone)]
3433
pub struct ApiState {
35-
pub state: Arc<State>,
36-
pub ws: Arc<ws::WsState>,
37-
pub metrics: Arc<metrics_middleware::Metrics>,
34+
pub state: Arc<State>,
35+
pub ws: Arc<ws::WsState>,
36+
pub metrics: Arc<metrics_middleware::Metrics>,
37+
pub update_tx: Sender<AggregationEvent>,
3838
}
3939

4040
impl ApiState {
4141
pub fn new(
4242
state: Arc<State>,
4343
ws_whitelist: Vec<IpNet>,
4444
requester_ip_header_name: String,
45+
update_tx: Sender<AggregationEvent>,
4546
) -> Self {
4647
Self {
4748
metrics: Arc::new(metrics_middleware::Metrics::new(state.clone())),
@@ -51,57 +52,28 @@ impl ApiState {
5152
state.clone(),
5253
)),
5354
state,
55+
update_tx,
5456
}
5557
}
5658
}
5759

58-
#[tracing::instrument(skip(opts, state, update_rx))]
60+
#[tracing::instrument(skip(opts, state, update_tx))]
5961
pub async fn spawn(
6062
opts: RunOptions,
6163
state: Arc<State>,
62-
mut update_rx: Receiver<AggregationEvent>,
64+
update_tx: Sender<AggregationEvent>,
6365
) -> Result<()> {
6466
let state = {
6567
let opts = opts.clone();
6668
ApiState::new(
6769
state,
6870
opts.rpc.ws_whitelist,
6971
opts.rpc.requester_ip_header_name,
72+
update_tx,
7073
)
7174
};
7275

73-
let rpc_server = tokio::spawn(run(opts, state.clone()));
74-
75-
let ws_notifier = tokio::spawn(async move {
76-
let mut interval = tokio::time::interval(std::time::Duration::from_secs(5));
77-
78-
while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
79-
tokio::select! {
80-
update = update_rx.recv() => {
81-
match update {
82-
None => {
83-
// When the received message is None it means the channel has been closed. This
84-
// should never happen as the channel is never closed. As we can't recover from
85-
// this we shut down the application.
86-
tracing::error!("Failed to receive update from store.");
87-
crate::SHOULD_EXIT.store(true, Ordering::Release);
88-
break;
89-
}
90-
Some(event) => {
91-
notify_updates(state.ws.clone(), event).await;
92-
},
93-
}
94-
},
95-
_ = interval.tick() => {}
96-
}
97-
}
98-
99-
tracing::info!("Shutting down Websocket notifier...")
100-
});
101-
102-
103-
let _ = tokio::join!(ws_notifier, rpc_server);
104-
Ok(())
76+
run(opts, state.clone()).await
10577
}
10678

10779
/// This method provides a background service that responds to REST requests

hermes/src/api/ws.rs

Lines changed: 11 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -26,9 +26,7 @@ use {
2626
http::HeaderMap,
2727
response::IntoResponse,
2828
},
29-
dashmap::DashMap,
3029
futures::{
31-
future::join_all,
3230
stream::{
3331
SplitSink,
3432
SplitStream,
@@ -71,11 +69,10 @@ use {
7169
},
7270
time::Duration,
7371
},
74-
tokio::sync::mpsc,
72+
tokio::sync::broadcast::Receiver,
7573
};
7674

7775
const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
78-
const NOTIFICATIONS_CHAN_LEN: usize = 1000;
7976
const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
8077

8178
/// The maximum number of bytes that can be sent per second per IP address.
@@ -139,7 +136,6 @@ impl Metrics {
139136

140137
pub struct WsState {
141138
pub subscriber_counter: AtomicUsize,
142-
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
143139
pub bytes_limit_whitelist: Vec<IpNet>,
144140
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
145141
pub requester_ip_header_name: String,
@@ -150,7 +146,6 @@ impl WsState {
150146
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<State>) -> Self {
151147
Self {
152148
subscriber_counter: AtomicUsize::new(0),
153-
subscribers: DashMap::new(),
154149
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
155150
BYTES_LIMIT_PER_IP_PER_SECOND
156151
))),
@@ -220,6 +215,11 @@ async fn websocket_handler(
220215
subscriber_ip: Option<IpAddr>,
221216
) {
222217
let ws_state = state.ws.clone();
218+
219+
// Retain the recent rate limit data for the IP addresses to
220+
// prevent the rate limiter size from growing indefinitely.
221+
ws_state.rate_limiter.retain_recent();
222+
223223
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
224224

225225
tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
@@ -232,7 +232,7 @@ async fn websocket_handler(
232232
})
233233
.inc();
234234

235-
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
235+
let notify_receiver = state.update_tx.subscribe();
236236
let (sender, receiver) = stream.split();
237237
let mut subscriber = Subscriber::new(
238238
id,
@@ -244,7 +244,6 @@ async fn websocket_handler(
244244
sender,
245245
);
246246

247-
ws_state.subscribers.insert(id, notify_sender);
248247
subscriber.run().await;
249248
}
250249

@@ -258,7 +257,7 @@ pub struct Subscriber {
258257
closed: bool,
259258
store: Arc<State>,
260259
ws_state: Arc<WsState>,
261-
notify_receiver: mpsc::Receiver<AggregationEvent>,
260+
notify_receiver: Receiver<AggregationEvent>,
262261
receiver: SplitStream<WebSocket>,
263262
sender: SplitSink<WebSocket, Message>,
264263
price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
@@ -273,7 +272,7 @@ impl Subscriber {
273272
ip_addr: Option<IpAddr>,
274273
store: Arc<State>,
275274
ws_state: Arc<WsState>,
276-
notify_receiver: mpsc::Receiver<AggregationEvent>,
275+
notify_receiver: Receiver<AggregationEvent>,
277276
receiver: SplitStream<WebSocket>,
278277
sender: SplitSink<WebSocket, Message>,
279278
) -> Self {
@@ -307,8 +306,8 @@ impl Subscriber {
307306
tokio::select! {
308307
maybe_update_feeds_event = self.notify_receiver.recv() => {
309308
match maybe_update_feeds_event {
310-
Some(event) => self.handle_price_feeds_update(event).await,
311-
None => Err(anyhow!("Update channel closed. This should never happen. Closing connection."))
309+
Ok(event) => self.handle_price_feeds_update(event).await,
310+
Err(e) => Err(anyhow!("Failed to receive update from store: {:?}", e)),
312311
}
313312
},
314313
maybe_message_or_err = self.receiver.next() => {
@@ -610,33 +609,3 @@ impl Subscriber {
610609
Ok(())
611610
}
612611
}
613-
614-
pub async fn notify_updates(ws_state: Arc<WsState>, event: AggregationEvent) {
615-
let closed_subscribers: Vec<Option<SubscriberId>> =
616-
join_all(ws_state.subscribers.iter_mut().map(|subscriber| {
617-
let event = event.clone();
618-
async move {
619-
match subscriber.send(event).await {
620-
Ok(_) => None,
621-
Err(_) => {
622-
// An error here indicates the channel is closed (which may happen either when the
623-
// client has sent Message::Close or some other abrupt disconnection). We remove
624-
// subscribers only when send fails so we can handle closure only once when we are
625-
// able to see send() fail.
626-
Some(*subscriber.key())
627-
}
628-
}
629-
}
630-
}))
631-
.await;
632-
633-
// Remove closed_subscribers from ws_state
634-
closed_subscribers.into_iter().for_each(|id| {
635-
if let Some(id) = id {
636-
ws_state.subscribers.remove(&id);
637-
}
638-
});
639-
640-
// Clean the bytes limiting dictionary
641-
ws_state.rate_limiter.retain_recent();
642-
}

hermes/src/main.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -44,8 +44,8 @@ async fn init() -> Result<()> {
4444
config::Options::Run(opts) => {
4545
tracing::info!("Starting hermes service...");
4646

47-
// The update channel is used to send store update notifications to the public API.
48-
let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
47+
// The update broadcast channel is used to send store update notifications to the public API.
48+
let (update_tx, _) = tokio::sync::broadcast::channel(1000);
4949

5050
// Initialize a cache store with a 1000 element circular buffer.
5151
let store = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
@@ -64,7 +64,7 @@ async fn init() -> Result<()> {
6464
Box::pin(spawn(network::wormhole::spawn(opts.clone(), store.clone()))),
6565
Box::pin(spawn(network::pythnet::spawn(opts.clone(), store.clone()))),
6666
Box::pin(spawn(metrics_server::run(opts.clone(), store.clone()))),
67-
Box::pin(spawn(api::spawn(opts.clone(), store.clone(), update_rx))),
67+
Box::pin(spawn(api::spawn(opts.clone(), store.clone(), update_tx))),
6868
])
6969
.await;
7070

hermes/src/state.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@ use {
2020
sync::Arc,
2121
},
2222
tokio::sync::{
23-
mpsc::Sender,
23+
broadcast::Sender,
2424
RwLock,
2525
},
2626
};
@@ -81,11 +81,11 @@ pub mod test {
8181
use {
8282
super::*,
8383
crate::network::wormhole::update_guardian_set,
84-
tokio::sync::mpsc::Receiver,
84+
tokio::sync::broadcast::Receiver,
8585
};
8686

8787
pub async fn setup_state(cache_size: u64) -> (Arc<State>, Receiver<AggregationEvent>) {
88-
let (update_tx, update_rx) = tokio::sync::mpsc::channel(1000);
88+
let (update_tx, update_rx) = tokio::sync::broadcast::channel(1000);
8989
let state = State::new(update_tx, cache_size, None);
9090

9191
// Add an initial guardian set with public key 0

0 commit comments

Comments
 (0)