Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions apps/hermes/server/src/api/rest/v2/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,15 @@ use {
pyth_sdk::PriceIdentifier,
serde::Deserialize,
serde_qs::axum::QsQuery,
std::convert::Infallible,
tokio::sync::broadcast,
std::{convert::Infallible, time::Duration},
tokio::{sync::broadcast, time::Instant},
tokio_stream::{wrappers::BroadcastStream, StreamExt as _},
utoipa::IntoParams,
};

// Constants
const MAX_CONNECTION_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours

#[derive(Debug, Deserialize, IntoParams)]
#[into_params(parameter_in = Query)]
pub struct StreamPriceUpdatesQueryParams {
Expand Down Expand Up @@ -93,7 +96,14 @@ where
// Convert the broadcast receiver into a Stream
let stream = BroadcastStream::new(update_rx);

// Set connection deadline
let connection_deadline = Instant::now() + MAX_CONNECTION_DURATION;

let sse_stream = stream
.take_while(move |_| {
let now = Instant::now();
now < connection_deadline
})
.then(move |message| {
let state_clone = state.clone(); // Clone again to use inside the async block
let price_ids_clone = price_ids.clone(); // Clone again for use inside the async block
Expand Down Expand Up @@ -122,7 +132,12 @@ where
}
}
})
.filter_map(|x| x);
.filter_map(|x| x)
.chain(futures::stream::once(async {
Ok(Event::default()
.event("error")
.data("Connection timeout reached (24h)"))
}));

Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
}
Expand Down
28 changes: 27 additions & 1 deletion apps/hermes/server/src/api/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,15 @@ use {
},
time::Duration,
},
tokio::sync::{broadcast::Receiver, watch},
tokio::{
sync::{broadcast::Receiver, watch},
time::Instant,
},
};

const PING_INTERVAL_DURATION: Duration = Duration::from_secs(30);
const MAX_CLIENT_MESSAGE_SIZE: usize = 100 * 1024; // 100 KiB
const MAX_CONNECTION_DURATION: Duration = Duration::from_secs(24 * 60 * 60); // 24 hours

/// The maximum number of bytes that can be sent per second per IP address.
/// If the limit is exceeded, the connection is closed.
Expand Down Expand Up @@ -252,6 +256,7 @@ pub struct Subscriber<S> {
sender: SplitSink<WebSocket, Message>,
price_feeds_with_config: HashMap<PriceIdentifier, PriceFeedClientConfig>,
ping_interval: tokio::time::Interval,
connection_deadline: Instant,
exit: watch::Receiver<bool>,
responded_to_ping: bool,
}
Expand Down Expand Up @@ -280,6 +285,7 @@ where
sender,
price_feeds_with_config: HashMap::new(),
ping_interval: tokio::time::interval(PING_INTERVAL_DURATION),
connection_deadline: Instant::now() + MAX_CONNECTION_DURATION,
exit: crate::EXIT.subscribe(),
responded_to_ping: true, // We start with true so we don't close the connection immediately
}
Expand Down Expand Up @@ -325,6 +331,26 @@ where
self.sender.send(Message::Ping(vec![])).await?;
Ok(())
},
_ = tokio::time::sleep_until(self.connection_deadline) => {
tracing::info!(
id = self.id,
ip = ?self.ip_addr,
"Connection timeout reached (24h). Closing connection.",
);
self.sender
.send(
serde_json::to_string(&ServerMessage::Response(
ServerResponseMessage::Err {
error: "Connection timeout reached (24h)".to_string(),
},
))?
.into(),
)
.await?;
self.sender.close().await?;
self.closed = true;
Ok(())
},
_ = self.exit.changed() => {
self.sender.close().await?;
self.closed = true;
Expand Down
Loading