
Commit df585e4

refactor(hermes): watch channel to simplify shutdowns
1 parent f18f1c8 commit df585e4

7 files changed: +81 −109 lines

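The change in a nutshell: the global `AtomicBool` flag (`SHOULD_EXIT`), which every task had to poll on a timer, is replaced by a single `tokio::sync::watch` channel whose `Sender` lives in a global and whose `Receiver`s are handed out one per task. Tasks now block on `exit.changed()` inside a `tokio::select!`, so shutdown is observed immediately instead of at the next poll tick. Below is a minimal, self-contained sketch of the pattern that recurs in every file of this diff (illustrative only, not Hermes code; assumes tokio 1.x with the `full` feature set):

use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    // One channel for the whole app; tasks get receivers via subscribe().
    let (exit_tx, _) = watch::channel(false);

    let mut exit = exit_tx.subscribe();
    let worker = tokio::spawn(async move {
        loop {
            tokio::select! {
                // Fires as soon as the value is updated; no polling interval.
                _ = exit.changed() => break,
                _ = tokio::time::sleep(Duration::from_millis(200)) => {
                    println!("working...");
                }
            }
        }
        println!("worker shut down cleanly");
    });

    tokio::time::sleep(Duration::from_millis(700)).await;
    let _ = exit_tx.send(true); // Err only if every receiver is gone.
    worker.await.unwrap();
}

The same shape repeats in each file below: subscribe before spawning, select on `changed()`, break out of the loop.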

hermes/src/api.rs

Lines changed: 2 additions & 8 deletions

@@ -13,10 +13,7 @@ use {
     },
     ipnet::IpNet,
     serde_qs::axum::QsQueryConfig,
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
     tokio::sync::broadcast::Sender,
     tower_http::cors::CorsLayer,
     utoipa::OpenApi,

@@ -159,10 +156,7 @@ pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
     axum::Server::try_bind(&opts.rpc.listen_addr)?
         .serve(app.into_make_service())
         .with_graceful_shutdown(async {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
-            }
-
+            let _ = crate::EXIT.subscribe().changed().await;
            tracing::info!("Shutting down RPC server...");
         })
         .await?;
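The old shutdown future woke every five seconds to re-check the flag, so the RPC server could lag a shutdown signal by up to five seconds; the new future completes the moment the watch value changes. My reading of the tokio API, not part of the commit: `changed()` also returns `Err` when the `Sender` is dropped, which is an equally good reason to stop serving, hence the `let _ =`. A standalone timing check of that behavior (illustrative):

use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, _) = watch::channel(false);
    let mut rx = tx.subscribe();

    let started = std::time::Instant::now();
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(100)).await;
        let _ = tx.send(true);
    });

    let _ = rx.changed().await; // resolves on the send, not on a poll tick
    assert!(started.elapsed() < Duration::from_secs(1));
    println!("observed shutdown after {:?}", started.elapsed());
}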

hermes/src/api/ws.rs

Lines changed: 10 additions & 10 deletions

@@ -69,7 +69,10 @@ use {
         },
         time::Duration,
     },
-    tokio::sync::broadcast::Receiver,
+    tokio::sync::{
+        broadcast::Receiver,
+        watch,
+    },
 };

 const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);

@@ -262,7 +265,7 @@ pub struct Subscriber {
     sender: SplitSink<WebSocket, Message>,
     price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
     ping_interval: tokio::time::Interval,
-    exit_check_interval: tokio::time::Interval,
+    exit: watch::Receiver<bool>,
     responded_to_ping: bool,
 }

@@ -287,7 +290,7 @@ impl Subscriber {
             sender,
             price_feeds_with_config: HashMap::new(),
             ping_interval: tokio::time::interval(PING_INTERVAL_DURATION),
-            exit_check_interval: tokio::time::interval(Duration::from_secs(5)),
+            exit: crate::EXIT.subscribe(),
             responded_to_ping: true, // We start with true so we don't close the connection immediately
         }
     }

@@ -332,13 +335,10 @@ impl Subscriber {
                 self.sender.send(Message::Ping(vec![])).await?;
                 Ok(())
             },
-            _ = self.exit_check_interval.tick() => {
-                if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                    self.sender.close().await?;
-                    self.closed = true;
-                    return Err(anyhow!("Application is shutting down. Closing connection."));
-                }
-                Ok(())
+            _ = self.exit.changed() => {
+                self.sender.close().await?;
+                self.closed = true;
+                return Err(anyhow!("Application is shutting down. Closing connection."));
             }
         }
     }
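Previously every open WebSocket connection woke every five seconds just to poll the flag; now each `Subscriber` carries its own `watch::Receiver` and the select arm stays silent until shutdown, then closes the socket exactly once. A hypothetical, trimmed rendering of the loop's shape (the `next_event` helper is mine, not in the commit; assumes the anyhow and tokio crates):

use anyhow::anyhow;
use tokio::sync::watch;

// One watch Receiver per connection replaces the exit_check_interval poll.
async fn next_event(exit: &mut watch::Receiver<bool>) -> anyhow::Result<()> {
    tokio::select! {
        // the real loop also selects on incoming messages and a ping timer
        _ = exit.changed() => Err(anyhow!("Application is shutting down.")),
    }
}

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(false);
    let _ = tx.send(true); // simulate the shutdown signal
    assert!(next_event(&mut rx).await.is_err());
}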

hermes/src/main.rs

Lines changed: 18 additions & 12 deletions

@@ -8,12 +8,13 @@ use {
         Parser,
     },
     futures::future::join_all,
+    lazy_static::lazy_static,
     state::State,
-    std::{
-        io::IsTerminal,
-        sync::atomic::AtomicBool,
+    std::io::IsTerminal,
+    tokio::{
+        spawn,
+        sync::watch,
     },
-    tokio::spawn,
 };

 mod aggregate;

@@ -25,13 +26,18 @@ mod price_feeds_metadata;
 mod serde;
 mod state;

-// A static exit flag to indicate to running threads that we're shutting down. This is used to
-// gracefully shutdown the application.
-//
-// NOTE: A more idiomatic approach would be to use a tokio::sync::broadcast channel, and to send a
-// shutdown signal to all running tasks. However, this is a bit more complicated to implement and
-// we don't rely on global state for anything else.
-pub(crate) static SHOULD_EXIT: AtomicBool = AtomicBool::new(false);
+lazy_static! {
+    /// A static exit flag to indicate to running threads that we're shutting down. This is
+    /// used to gracefully shut down the application.
+    ///
+    /// We make this global for the following reasons:
+    /// - The `Sender` side does not rely on any async runtime.
+    /// - Exit logic doesn't really require carefully threading this value through the app.
+    /// - A watch channel `Receiver` detects any change made after it subscribed, so all
+    ///   listeners should always be notified correctly.
+    pub static ref EXIT: watch::Sender<bool> = watch::channel(false).0;
+}

 /// Initialize the Application. This can be invoked either by real main, or by the Geyser plugin.
 #[tracing::instrument]

@@ -55,7 +61,7 @@ async fn init() -> Result<()> {
         tracing::info!("Registered shutdown signal handler...");
         tokio::signal::ctrl_c().await.unwrap();
         tracing::info!("Shut down signal received, waiting for tasks...");
-        SHOULD_EXIT.store(true, std::sync::atomic::Ordering::Release);
+        let _ = EXIT.send(true);
     });

     // Spawn all worker tasks, and wait for all to complete (which will happen if a shutdown
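A side note on `watch::channel(false).0`: the initial `Receiver` returned by `watch::channel` is dropped on the spot, which is harmless because `watch::Sender::subscribe()` can mint new receivers later (and reopens the channel if it was closed). It also explains the `let _ = EXIT.send(true)`: in tokio 1.x, `send` reports an error when no receivers exist, and at shutdown time some tasks may already have dropped theirs. A small standalone check (illustrative, not Hermes code):

use tokio::sync::watch;

fn main() {
    // Keeping only the Sender half is fine: receivers can be created later.
    let tx: watch::Sender<bool> = watch::channel(false).0;
    assert!(tx.send(true).is_err()); // no receivers yet, send reports an error
    let _rx = tx.subscribe();        // a task subscribes, reopening the channel
    assert!(tx.send(true).is_ok());  // now the signal is observable
}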

hermes/src/metrics_server.rs

Lines changed: 2 additions & 8 deletions

@@ -16,10 +16,7 @@ use {
         Router,
     },
     prometheus_client::encoding::text::encode,
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
 };

@@ -37,10 +34,7 @@ pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
     axum::Server::try_bind(&opts.metrics.server_listen_addr)?
         .serve(app.into_make_service())
         .with_graceful_shutdown(async {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
-            }
-
+            let _ = crate::EXIT.subscribe().changed().await;
             tracing::info!("Shutting down metrics server...");
         })
         .await?;

hermes/src/network/pythnet.rs

Lines changed: 40 additions & 55 deletions

@@ -58,10 +58,7 @@ use {
     },
     std::{
         collections::BTreeMap,
-        sync::{
-            atomic::Ordering,
-            Arc,
-        },
+        sync::Arc,
         time::Duration,
     },
     tokio::time::Instant,

@@ -160,7 +157,7 @@ pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<()> {
         .program_subscribe(&system_program::id(), Some(config))
         .await?;

-    while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
+    loop {
         match notif.next().await {
             Some(update) => {
                 let account: Account = match update.value.account.decode() {

@@ -213,8 +210,6 @@ pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<()> {
             }
         }
     }
-
-    Ok(())
 }

 /// Fetch existing GuardianSet accounts from Wormhole.

@@ -281,79 +276,69 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
     let task_listener = {
         let store = state.clone();
         let pythnet_ws_endpoint = opts.pythnet.ws_addr.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
+            loop {
                 let current_time = Instant::now();
-
-                if let Err(ref e) = run(store.clone(), pythnet_ws_endpoint.clone()).await {
-                    tracing::error!(error = ?e, "Error in Pythnet network listener.");
-                    if current_time.elapsed() < Duration::from_secs(30) {
-                        tracing::error!("Pythnet listener restarting too quickly. Sleep 1s.");
-                        tokio::time::sleep(Duration::from_secs(1)).await;
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    Err(err) = run(store.clone(), pythnet_ws_endpoint.clone()) => {
+                        tracing::error!(error = ?err, "Error in Pythnet network listener.");
+                        if current_time.elapsed() < Duration::from_secs(30) {
+                            tracing::error!("Pythnet listener restarting too quickly. Sleep 1s.");
+                            tokio::time::sleep(Duration::from_secs(1)).await;
+                        }
                     }
                 }
             }
-
             tracing::info!("Shutting down Pythnet listener...");
         })
     };

     let task_guardian_watcher = {
         let store = state.clone();
         let pythnet_http_endpoint = opts.pythnet.http_addr.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                // Poll for new guardian sets every 60 seconds. We use a short wait time so we can
-                // properly exit if a quit signal was received. This isn't a perfect solution, but
-                // it's good enough for now.
-                for _ in 0..60 {
-                    if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                        break;
-                    }
-                    tokio::time::sleep(Duration::from_secs(1)).await;
-                }
-
-                match fetch_existing_guardian_sets(
-                    store.clone(),
-                    pythnet_http_endpoint.clone(),
-                    opts.wormhole.contract_addr,
-                )
-                .await
-                {
-                    Ok(_) => {}
-                    Err(err) => {
-                        tracing::error!(error = ?err, "Failed to poll for new guardian sets.")
+            loop {
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    _ = tokio::time::sleep(Duration::from_secs(60)) => {
+                        if let Err(err) = fetch_existing_guardian_sets(
+                            store.clone(),
+                            pythnet_http_endpoint.clone(),
+                            opts.wormhole.contract_addr,
+                        )
+                        .await
+                        {
+                            tracing::error!(error = ?err, "Failed to poll for new guardian sets.")
+                        }
                    }
                }
            }
-
             tracing::info!("Shutting down Pythnet guardian set poller...");
         })
     };


     let task_price_feeds_metadata_updater = {
         let price_feeds_state = state.clone();
+        let mut exit = crate::EXIT.subscribe();
         tokio::spawn(async move {
-            while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                if let Err(e) = fetch_and_store_price_feeds_metadata(
-                    price_feeds_state.as_ref(),
-                    &opts.pythnet.mapping_addr,
-                    &rpc_client,
-                )
-                .await
-                {
-                    tracing::error!("Error in fetching and storing price feeds metadata: {}", e);
-                }
-                // This loop with a sleep interval of 1 second allows the task to check for an exit signal at a
-                // fine-grained interval. Instead of sleeping directly for the entire `price_feeds_update_interval`,
-                // which could delay the response to an exit signal, this approach ensures the task can exit promptly
-                // if `crate::SHOULD_EXIT` is set, enhancing the responsiveness of the service to shutdown requests.
-                for _ in 0..DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL {
-                    if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-                        break;
+            loop {
+                tokio::select! {
+                    _ = exit.changed() => break,
+                    _ = tokio::time::sleep(Duration::from_secs(DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL)) => {
+                        if let Err(e) = fetch_and_store_price_feeds_metadata(
+                            price_feeds_state.as_ref(),
+                            &opts.pythnet.mapping_addr,
+                            &rpc_client,
+                        )
+                        .await
+                        {
+                            tracing::error!("Error in fetching and storing price feeds metadata: {}", e);
+                        }
                     }
-                    tokio::time::sleep(Duration::from_secs(1)).await;
                 }
             }
         })
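All three tasks here collapse the "sleep in one-second slices so we notice the flag" workaround into a `select!` between the exit signal and the real work. A hypothetical generic rendering of the listener's supervisor shape (the `supervise` helper and its parameters are mine, not in the commit; assumes the anyhow, tracing, and tokio crates). One `select!` subtlety worth knowing: when the supervised future returns `Ok`, the `Err(err) = ...` pattern fails to match, tokio disables that branch, and the `select!` then waits on `exit` alone rather than restarting.

use std::time::Duration;
use tokio::sync::watch;

// Restart a fallible task until shutdown, throttling restarts that happen
// within 30 seconds of the previous start (library-style sketch).
async fn supervise<F, Fut>(mut exit: watch::Receiver<bool>, mut task: F)
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = anyhow::Result<()>>,
{
    loop {
        let started = tokio::time::Instant::now();
        tokio::select! {
            _ = exit.changed() => break,
            // An Ok return disables this branch for the current select!,
            // leaving only the exit arm pending.
            Err(err) = task() => {
                tracing::error!(error = ?err, "task failed, restarting");
                if started.elapsed() < Duration::from_secs(30) {
                    tokio::time::sleep(Duration::from_secs(1)).await;
                }
            }
        }
    }
}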

hermes/src/network/wormhole.rs

Lines changed: 8 additions & 15 deletions

@@ -43,10 +43,7 @@ use {
         Digest,
         Keccak256,
     },
-    std::sync::{
-        atomic::Ordering,
-        Arc,
-    },
+    std::sync::Arc,
     tonic::Request,
     wormhole_sdk::{
         vaa::{

@@ -153,16 +150,16 @@ mod proto {
 // Launches the Wormhole gRPC service.
 #[tracing::instrument(skip(opts, state))]
 pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
-    while !crate::SHOULD_EXIT.load(Ordering::Acquire) {
-        if let Err(e) = run(opts.clone(), state.clone()).await {
-            tracing::error!(error = ?e, "Wormhole gRPC service failed.");
+    let mut exit = crate::EXIT.subscribe();
+    loop {
+        tokio::select! {
+            _ = exit.changed() => break,
+            Err(err) = run(opts.clone(), state.clone()) => {
+                tracing::error!(error = ?err, "Wormhole gRPC service failed.");
+            }
         }
-
-        tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
     }
-
     tracing::info!("Shutting down Wormhole gRPC service...");
-
     Ok(())
 }

@@ -182,10 +179,6 @@ async fn run(opts: RunOptions, state: Arc<State>) -> Result<()> {
         .into_inner();

     while let Some(Ok(message)) = stream.next().await {
-        if crate::SHOULD_EXIT.load(Ordering::Acquire) {
-            return Ok(());
-        }
-
         if let Err(e) = process_message(state.clone(), message.vaa_bytes).await {
             tracing::debug!(error = ?e, "Skipped VAA.");
         }
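The per-message `SHOULD_EXIT` check inside `run()` is safe to delete because of how `tokio::select!` cancels: when the `exit.changed()` arm completes first, the in-flight `run(...)` future is dropped at its current await point, tearing down the VAA stream without reaching another loop iteration. A standalone demonstration of that cancellation (my gloss on the select semantics, not from the commit):

use std::time::Duration;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(false);
    tokio::spawn(async move {
        tokio::time::sleep(Duration::from_millis(50)).await;
        let _ = tx.send(true);
    });
    tokio::select! {
        // When this arm wins, the losing future below is dropped mid-await
        // and never resumes, which is why run() needs no exit check of its own.
        _ = rx.changed() => println!("exit won; the other future was dropped"),
        _ = tokio::time::sleep(Duration::from_secs(3600)) => unreachable!(),
    }
}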

hermes/src/price_feeds_metadata.rs

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ use {
     anyhow::Result,
 };

-pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u16 = 600;
+pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;

 pub async fn retrieve_price_feeds_metadata(state: &State) -> Result<Vec<PriceFeedMetadata>> {
     let price_feeds_metadata = state.price_feeds_metadata.read().await;
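The type widening is a byproduct of the pythnet.rs change: the constant used to bound a `for _ in 0..N` poll loop, but now feeds `Duration::from_secs`, whose parameter is a `u64`; keeping it as `u16` would force a cast at the call site. A two-line confirmation (standalone, illustrative):

use std::time::Duration;

const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600;

fn main() {
    // Duration::from_secs(secs: u64) accepts the widened constant directly.
    let interval = Duration::from_secs(DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL);
    assert_eq!(interval.as_secs(), 600);
}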
