Skip to content

Commit d11216f

Browse files
committed
feat(hermes): use ip from request headers for ratelimiting
1 parent b158f28 commit d11216f

File tree

5 files changed

+91
-66
lines changed

5 files changed

+91
-66
lines changed

hermes/Cargo.lock

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hermes/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "hermes"
3-
version = "0.3.2"
3+
version = "0.3.3"
44
description = "Hermes is an agent that provides Verified Prices from the Pythnet Pyth Oracle."
55
edition = "2021"
66

hermes/src/api.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,9 @@ use {
1313
},
1414
ipnet::IpNet,
1515
serde_qs::axum::QsQueryConfig,
16-
std::{
17-
net::SocketAddr,
18-
sync::{
19-
atomic::Ordering,
20-
Arc,
21-
},
16+
std::sync::{
17+
atomic::Ordering,
18+
Arc,
2219
},
2320
tokio::{
2421
signal,
@@ -40,10 +37,14 @@ pub struct ApiState {
4037
}
4138

4239
impl ApiState {
43-
pub fn new(state: Arc<State>, ws_whitelist: Vec<IpNet>) -> Self {
40+
pub fn new(
41+
state: Arc<State>,
42+
ws_whitelist: Vec<IpNet>,
43+
requester_ip_header_name: String,
44+
) -> Self {
4445
Self {
4546
state,
46-
ws: Arc::new(ws::WsState::new(ws_whitelist)),
47+
ws: Arc::new(ws::WsState::new(ws_whitelist, requester_ip_header_name)),
4748
}
4849
}
4950
}
@@ -88,7 +89,11 @@ pub async fn run(
8889
)]
8990
struct ApiDoc;
9091

91-
let state = ApiState::new(state, opts.rpc.ws_whitelist);
92+
let state = ApiState::new(
93+
state,
94+
opts.rpc.ws_whitelist,
95+
opts.rpc.requester_ip_header_name,
96+
);
9297

9398
// Initialize Axum Router. Note the type here is a `Router<State>` due to the use of the
9499
// `with_state` method which replaces `Body` with `State` in the type signature.
@@ -135,7 +140,7 @@ pub async fn run(
135140
// Binds the axum's server to the configured address and port. This is a blocking call and will
136141
// not return until the server is shutdown.
137142
axum::Server::try_bind(&opts.rpc.addr)?
138-
.serve(app.into_make_service_with_connect_info::<SocketAddr>())
143+
.serve(app.into_make_service())
139144
.with_graceful_shutdown(async {
140145
// Ignore Ctrl+C errors, either way we need to shut down. The main Ctrl+C handler
141146
// should also have triggered so we will let that one print the shutdown warning.

hermes/src/api/ws.rs

Lines changed: 60 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ use {
2121
WebSocket,
2222
WebSocketUpgrade,
2323
},
24-
ConnectInfo,
2524
State as AxumState,
2625
},
26+
http::HeaderMap,
2727
response::IntoResponse,
2828
},
2929
dashmap::DashMap,
@@ -50,10 +50,7 @@ use {
5050
},
5151
std::{
5252
collections::HashMap,
53-
net::{
54-
IpAddr,
55-
SocketAddr,
56-
},
53+
net::IpAddr,
5754
num::NonZeroU32,
5855
sync::{
5956
atomic::{
@@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
8380
}
8481

8582
pub struct WsState {
86-
pub subscriber_counter: AtomicUsize,
87-
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
88-
pub bytes_limit_whitelist: Vec<IpNet>,
89-
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
83+
pub subscriber_counter: AtomicUsize,
84+
pub subscribers: DashMap<SubscriberId, mpsc::Sender<AggregationEvent>>,
85+
pub bytes_limit_whitelist: Vec<IpNet>,
86+
pub rate_limiter: DefaultKeyedRateLimiter<IpAddr>,
87+
pub requester_ip_header_name: String,
9088
}
9189

9290
impl WsState {
93-
pub fn new(whitelist: Vec<IpNet>) -> Self {
91+
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self {
9492
Self {
95-
subscriber_counter: AtomicUsize::new(0),
96-
subscribers: DashMap::new(),
97-
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
93+
subscriber_counter: AtomicUsize::new(0),
94+
subscribers: DashMap::new(),
95+
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
9896
BYTES_LIMIT_PER_IP_PER_SECOND
9997
))),
10098
bytes_limit_whitelist: whitelist,
99+
requester_ip_header_name,
101100
}
102101
}
103102
}
@@ -142,23 +141,33 @@ enum ServerResponseMessage {
142141
pub async fn ws_route_handler(
143142
ws: WebSocketUpgrade,
144143
AxumState(state): AxumState<super::ApiState>,
145-
ConnectInfo(addr): ConnectInfo<SocketAddr>,
144+
headers: HeaderMap,
146145
) -> impl IntoResponse {
146+
let requester_ip = headers
147+
.get(state.ws.requester_ip_header_name.as_str())
148+
.and_then(|value| value.to_str().ok())
149+
.and_then(|value| value.split(',').next()) // Only take the first ip if there are multiple
150+
.and_then(|value| value.parse().ok());
151+
147152
ws.max_message_size(MAX_CLIENT_MESSAGE_SIZE)
148-
.on_upgrade(move |socket| websocket_handler(socket, state, addr))
153+
.on_upgrade(move |socket| websocket_handler(socket, state, requester_ip))
149154
}
150155

151-
#[tracing::instrument(skip(stream, state, addr))]
152-
async fn websocket_handler(stream: WebSocket, state: super::ApiState, addr: SocketAddr) {
156+
#[tracing::instrument(skip(stream, state, subscriber_ip))]
157+
async fn websocket_handler(
158+
stream: WebSocket,
159+
state: super::ApiState,
160+
subscriber_ip: Option<IpAddr>,
161+
) {
153162
let ws_state = state.ws.clone();
154163
let id = ws_state.subscriber_counter.fetch_add(1, Ordering::SeqCst);
155-
tracing::debug!(id, %addr, "New Websocket Connection");
164+
tracing::debug!(id, ?subscriber_ip, "New Websocket Connection");
156165

157166
let (notify_sender, notify_receiver) = mpsc::channel(NOTIFICATIONS_CHAN_LEN);
158167
let (sender, receiver) = stream.split();
159168
let mut subscriber = Subscriber::new(
160169
id,
161-
addr.ip(),
170+
subscriber_ip,
162171
state.state.clone(),
163172
state.ws.clone(),
164173
notify_receiver,
@@ -176,7 +185,7 @@ pub type SubscriberId = usize;
176185
/// It listens to the store for updates and sends them to the client.
177186
pub struct Subscriber {
178187
id: SubscriberId,
179-
ip_addr: IpAddr,
188+
ip_addr: Option<IpAddr>,
180189
closed: bool,
181190
store: Arc<State>,
182191
ws_state: Arc<WsState>,
@@ -191,7 +200,7 @@ pub struct Subscriber {
191200
impl Subscriber {
192201
pub fn new(
193202
id: SubscriberId,
194-
ip_addr: IpAddr,
203+
ip_addr: Option<IpAddr>,
195204
store: Arc<State>,
196205
ws_state: Arc<WsState>,
197206
notify_receiver: mpsc::Receiver<AggregationEvent>,
@@ -291,32 +300,36 @@ impl Subscriber {
291300
})?;
292301

293302
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
294-
if !self
295-
.ws_state
296-
.bytes_limit_whitelist
297-
.iter()
298-
.any(|ip_net| ip_net.contains(&self.ip_addr))
299-
&& self.ws_state.rate_limiter.check_key_n(
300-
&self.ip_addr,
301-
NonZeroU32::new(message.len().try_into()?).ok_or(anyhow!("Empty message"))?,
302-
) != Ok(Ok(()))
303-
{
304-
tracing::info!(
305-
self.id,
306-
ip = %self.ip_addr,
307-
"Rate limit exceeded. Closing connection.",
308-
);
309-
self.sender
310-
.send(
311-
serde_json::to_string(&ServerResponseMessage::Err {
312-
error: "Rate limit exceeded".to_string(),
313-
})?
314-
.into(),
315-
)
316-
.await?;
317-
self.sender.close().await?;
318-
self.closed = true;
319-
return Ok(());
303+
// If the ip address is None no rate limiting is applied.
304+
if let Some(ip_addr) = self.ip_addr {
305+
if !self
306+
.ws_state
307+
.bytes_limit_whitelist
308+
.iter()
309+
.any(|ip_net| ip_net.contains(&ip_addr))
310+
&& self.ws_state.rate_limiter.check_key_n(
311+
&ip_addr,
312+
NonZeroU32::new(message.len().try_into()?)
313+
.ok_or(anyhow!("Empty message"))?,
314+
) != Ok(Ok(()))
315+
{
316+
tracing::info!(
317+
self.id,
318+
ip = %ip_addr,
319+
"Rate limit exceeded. Closing connection.",
320+
);
321+
self.sender
322+
.send(
323+
serde_json::to_string(&ServerResponseMessage::Err {
324+
error: "Rate limit exceeded".to_string(),
325+
})?
326+
.into(),
327+
)
328+
.await?;
329+
self.sender.close().await?;
330+
self.closed = true;
331+
return Ok(());
332+
}
320333
}
321334

322335
// `sender.feed` buffers a message to the client but does not flush it, so we can send

hermes/src/config/rpc.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use {
55
};
66

77
const DEFAULT_RPC_ADDR: &str = "127.0.0.1:33999";
8+
const DEFAULT_RPC_REQUESTER_IP_HEADER_NAME: &str = "X-Forwarded-For";
89

910
#[derive(Args, Clone, Debug)]
1011
#[command(next_help_heading = "RPC Options")]
@@ -21,4 +22,10 @@ pub struct Options {
2122
#[arg(value_delimiter = ',')]
2223
#[arg(env = "RPC_WS_WHITELIST")]
2324
pub ws_whitelist: Vec<IpNet>,
25+
26+
/// Header name (case insensitive) to fetch requester IP from.
27+
#[arg(long = "rpc-requester-ip-header-name")]
28+
#[arg(default_value = DEFAULT_RPC_REQUESTER_IP_HEADER_NAME)]
29+
#[arg(env = "RPC_REQUESTER_IP_HEADER_NAME")]
30+
pub requester_ip_header_name: String,
2431
}

0 commit comments

Comments
 (0)