21
21
WebSocket ,
22
22
WebSocketUpgrade ,
23
23
} ,
24
- ConnectInfo ,
25
24
State as AxumState ,
26
25
} ,
26
+ http:: HeaderMap ,
27
27
response:: IntoResponse ,
28
28
} ,
29
29
dashmap:: DashMap ,
50
50
} ,
51
51
std:: {
52
52
collections:: HashMap ,
53
- net:: {
54
- IpAddr ,
55
- SocketAddr ,
56
- } ,
53
+ net:: IpAddr ,
57
54
num:: NonZeroU32 ,
58
55
sync:: {
59
56
atomic:: {
@@ -83,21 +80,23 @@ pub struct PriceFeedClientConfig {
83
80
}
84
81
85
82
pub struct WsState {
86
- pub subscriber_counter : AtomicUsize ,
87
- pub subscribers : DashMap < SubscriberId , mpsc:: Sender < AggregationEvent > > ,
88
- pub bytes_limit_whitelist : Vec < IpNet > ,
89
- pub rate_limiter : DefaultKeyedRateLimiter < IpAddr > ,
83
+ pub subscriber_counter : AtomicUsize ,
84
+ pub subscribers : DashMap < SubscriberId , mpsc:: Sender < AggregationEvent > > ,
85
+ pub bytes_limit_whitelist : Vec < IpNet > ,
86
+ pub rate_limiter : DefaultKeyedRateLimiter < IpAddr > ,
87
+ pub requester_ip_header_name : String ,
90
88
}
91
89
92
90
impl WsState {
93
- pub fn new ( whitelist : Vec < IpNet > ) -> Self {
91
+ pub fn new ( whitelist : Vec < IpNet > , requester_ip_header_name : String ) -> Self {
94
92
Self {
95
- subscriber_counter : AtomicUsize :: new ( 0 ) ,
96
- subscribers : DashMap :: new ( ) ,
97
- rate_limiter : RateLimiter :: dashmap ( Quota :: per_second ( nonzero ! (
93
+ subscriber_counter : AtomicUsize :: new ( 0 ) ,
94
+ subscribers : DashMap :: new ( ) ,
95
+ rate_limiter : RateLimiter :: dashmap ( Quota :: per_second ( nonzero ! (
98
96
BYTES_LIMIT_PER_IP_PER_SECOND
99
97
) ) ) ,
100
98
bytes_limit_whitelist : whitelist,
99
+ requester_ip_header_name,
101
100
}
102
101
}
103
102
}
@@ -142,23 +141,33 @@ enum ServerResponseMessage {
142
141
pub async fn ws_route_handler (
143
142
ws : WebSocketUpgrade ,
144
143
AxumState ( state) : AxumState < super :: ApiState > ,
145
- ConnectInfo ( addr ) : ConnectInfo < SocketAddr > ,
144
+ headers : HeaderMap ,
146
145
) -> impl IntoResponse {
146
+ let requester_ip = headers
147
+ . get ( state. ws . requester_ip_header_name . as_str ( ) )
148
+ . and_then ( |value| value. to_str ( ) . ok ( ) )
149
+ . and_then ( |value| value. split ( ',' ) . next ( ) ) // Only take the first ip if there are multiple
150
+ . and_then ( |value| value. parse ( ) . ok ( ) ) ;
151
+
147
152
ws. max_message_size ( MAX_CLIENT_MESSAGE_SIZE )
148
- . on_upgrade ( move |socket| websocket_handler ( socket, state, addr ) )
153
+ . on_upgrade ( move |socket| websocket_handler ( socket, state, requester_ip ) )
149
154
}
150
155
151
- #[ tracing:: instrument( skip( stream, state, addr) ) ]
152
- async fn websocket_handler ( stream : WebSocket , state : super :: ApiState , addr : SocketAddr ) {
156
+ #[ tracing:: instrument( skip( stream, state, subscriber_ip) ) ]
157
+ async fn websocket_handler (
158
+ stream : WebSocket ,
159
+ state : super :: ApiState ,
160
+ subscriber_ip : Option < IpAddr > ,
161
+ ) {
153
162
let ws_state = state. ws . clone ( ) ;
154
163
let id = ws_state. subscriber_counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
155
- tracing:: debug!( id, %addr , "New Websocket Connection" ) ;
164
+ tracing:: debug!( id, ?subscriber_ip , "New Websocket Connection" ) ;
156
165
157
166
let ( notify_sender, notify_receiver) = mpsc:: channel ( NOTIFICATIONS_CHAN_LEN ) ;
158
167
let ( sender, receiver) = stream. split ( ) ;
159
168
let mut subscriber = Subscriber :: new (
160
169
id,
161
- addr . ip ( ) ,
170
+ subscriber_ip ,
162
171
state. state . clone ( ) ,
163
172
state. ws . clone ( ) ,
164
173
notify_receiver,
@@ -176,7 +185,7 @@ pub type SubscriberId = usize;
176
185
/// It listens to the store for updates and sends them to the client.
177
186
pub struct Subscriber {
178
187
id : SubscriberId ,
179
- ip_addr : IpAddr ,
188
+ ip_addr : Option < IpAddr > ,
180
189
closed : bool ,
181
190
store : Arc < State > ,
182
191
ws_state : Arc < WsState > ,
@@ -191,7 +200,7 @@ pub struct Subscriber {
191
200
impl Subscriber {
192
201
pub fn new (
193
202
id : SubscriberId ,
194
- ip_addr : IpAddr ,
203
+ ip_addr : Option < IpAddr > ,
195
204
store : Arc < State > ,
196
205
ws_state : Arc < WsState > ,
197
206
notify_receiver : mpsc:: Receiver < AggregationEvent > ,
@@ -291,32 +300,36 @@ impl Subscriber {
291
300
} ) ?;
292
301
293
302
// Close the connection if rate limit is exceeded and the ip is not whitelisted.
294
- if !self
295
- . ws_state
296
- . bytes_limit_whitelist
297
- . iter ( )
298
- . any ( |ip_net| ip_net. contains ( & self . ip_addr ) )
299
- && self . ws_state . rate_limiter . check_key_n (
300
- & self . ip_addr ,
301
- NonZeroU32 :: new ( message. len ( ) . try_into ( ) ?) . ok_or ( anyhow ! ( "Empty message" ) ) ?,
302
- ) != Ok ( Ok ( ( ) ) )
303
- {
304
- tracing:: info!(
305
- self . id,
306
- ip = %self . ip_addr,
307
- "Rate limit exceeded. Closing connection." ,
308
- ) ;
309
- self . sender
310
- . send (
311
- serde_json:: to_string ( & ServerResponseMessage :: Err {
312
- error : "Rate limit exceeded" . to_string ( ) ,
313
- } ) ?
314
- . into ( ) ,
315
- )
316
- . await ?;
317
- self . sender . close ( ) . await ?;
318
- self . closed = true ;
319
- return Ok ( ( ) ) ;
303
+ // If the ip address is None no rate limiting is applied.
304
+ if let Some ( ip_addr) = self . ip_addr {
305
+ if !self
306
+ . ws_state
307
+ . bytes_limit_whitelist
308
+ . iter ( )
309
+ . any ( |ip_net| ip_net. contains ( & ip_addr) )
310
+ && self . ws_state . rate_limiter . check_key_n (
311
+ & ip_addr,
312
+ NonZeroU32 :: new ( message. len ( ) . try_into ( ) ?)
313
+ . ok_or ( anyhow ! ( "Empty message" ) ) ?,
314
+ ) != Ok ( Ok ( ( ) ) )
315
+ {
316
+ tracing:: info!(
317
+ self . id,
318
+ ip = %ip_addr,
319
+ "Rate limit exceeded. Closing connection." ,
320
+ ) ;
321
+ self . sender
322
+ . send (
323
+ serde_json:: to_string ( & ServerResponseMessage :: Err {
324
+ error : "Rate limit exceeded" . to_string ( ) ,
325
+ } ) ?
326
+ . into ( ) ,
327
+ )
328
+ . await ?;
329
+ self . sender . close ( ) . await ?;
330
+ self . closed = true ;
331
+ return Ok ( ( ) ) ;
332
+ }
320
333
}
321
334
322
335
// `sender.feed` buffers a message to the client but does not flush it, so we can send
0 commit comments