26
26
http:: HeaderMap ,
27
27
response:: IntoResponse ,
28
28
} ,
29
- dashmap:: DashMap ,
30
29
futures:: {
31
- future:: join_all,
32
30
stream:: {
33
31
SplitSink ,
34
32
SplitStream ,
@@ -71,11 +69,10 @@ use {
71
69
} ,
72
70
time:: Duration ,
73
71
} ,
74
- tokio:: sync:: mpsc ,
72
+ tokio:: sync:: broadcast :: Receiver ,
75
73
} ;
76
74
77
75
const PING_INTERVAL_DURATION : Duration = Duration :: from_secs ( 30 ) ;
78
- const NOTIFICATIONS_CHAN_LEN : usize = 1000 ;
79
76
const MAX_CLIENT_MESSAGE_SIZE : usize = 100 * 1024 ; // 100 KiB
80
77
81
78
/// The maximum number of bytes that can be sent per second per IP address.
@@ -139,7 +136,6 @@ impl Metrics {
139
136
140
137
pub struct WsState {
141
138
pub subscriber_counter : AtomicUsize ,
142
- pub subscribers : DashMap < SubscriberId , mpsc:: Sender < AggregationEvent > > ,
143
139
pub bytes_limit_whitelist : Vec < IpNet > ,
144
140
pub rate_limiter : DefaultKeyedRateLimiter < IpAddr > ,
145
141
pub requester_ip_header_name : String ,
@@ -150,7 +146,6 @@ impl WsState {
150
146
pub fn new ( whitelist : Vec < IpNet > , requester_ip_header_name : String , state : Arc < State > ) -> Self {
151
147
Self {
152
148
subscriber_counter : AtomicUsize :: new ( 0 ) ,
153
- subscribers : DashMap :: new ( ) ,
154
149
rate_limiter : RateLimiter :: dashmap ( Quota :: per_second ( nonzero ! (
155
150
BYTES_LIMIT_PER_IP_PER_SECOND
156
151
) ) ) ,
@@ -220,6 +215,11 @@ async fn websocket_handler(
220
215
subscriber_ip : Option < IpAddr > ,
221
216
) {
222
217
let ws_state = state. ws . clone ( ) ;
218
+
219
+ // Retain the recent rate limit data for the IP addresses to
220
+ // prevent the rate limiter size from growing indefinitely.
221
+ ws_state. rate_limiter . retain_recent ( ) ;
222
+
223
223
let id = ws_state. subscriber_counter . fetch_add ( 1 , Ordering :: SeqCst ) ;
224
224
225
225
tracing:: debug!( id, ?subscriber_ip, "New Websocket Connection" ) ;
@@ -232,7 +232,7 @@ async fn websocket_handler(
232
232
} )
233
233
. inc ( ) ;
234
234
235
- let ( notify_sender , notify_receiver) = mpsc :: channel ( NOTIFICATIONS_CHAN_LEN ) ;
235
+ let notify_receiver = state . update_tx . subscribe ( ) ;
236
236
let ( sender, receiver) = stream. split ( ) ;
237
237
let mut subscriber = Subscriber :: new (
238
238
id,
@@ -244,7 +244,6 @@ async fn websocket_handler(
244
244
sender,
245
245
) ;
246
246
247
- ws_state. subscribers . insert ( id, notify_sender) ;
248
247
subscriber. run ( ) . await ;
249
248
}
250
249
@@ -258,7 +257,7 @@ pub struct Subscriber {
258
257
closed : bool ,
259
258
store : Arc < State > ,
260
259
ws_state : Arc < WsState > ,
261
- notify_receiver : mpsc :: Receiver < AggregationEvent > ,
260
+ notify_receiver : Receiver < AggregationEvent > ,
262
261
receiver : SplitStream < WebSocket > ,
263
262
sender : SplitSink < WebSocket , Message > ,
264
263
price_feeds_with_config : HashMap < PriceIdentifier , PriceFeedClientConfig > ,
@@ -273,7 +272,7 @@ impl Subscriber {
273
272
ip_addr : Option < IpAddr > ,
274
273
store : Arc < State > ,
275
274
ws_state : Arc < WsState > ,
276
- notify_receiver : mpsc :: Receiver < AggregationEvent > ,
275
+ notify_receiver : Receiver < AggregationEvent > ,
277
276
receiver : SplitStream < WebSocket > ,
278
277
sender : SplitSink < WebSocket , Message > ,
279
278
) -> Self {
@@ -307,8 +306,8 @@ impl Subscriber {
307
306
tokio:: select! {
308
307
maybe_update_feeds_event = self . notify_receiver. recv( ) => {
309
308
match maybe_update_feeds_event {
310
- Some ( event) => self . handle_price_feeds_update( event) . await ,
311
- None => Err ( anyhow!( "Update channel closed. This should never happen. Closing connection." ) )
309
+ Ok ( event) => self . handle_price_feeds_update( event) . await ,
310
+ Err ( e ) => Err ( anyhow!( "Failed to receive update from store: {:?}" , e ) ) ,
312
311
}
313
312
} ,
314
313
maybe_message_or_err = self . receiver. next( ) => {
@@ -610,33 +609,3 @@ impl Subscriber {
610
609
Ok ( ( ) )
611
610
}
612
611
}
613
-
614
- pub async fn notify_updates ( ws_state : Arc < WsState > , event : AggregationEvent ) {
615
- let closed_subscribers: Vec < Option < SubscriberId > > =
616
- join_all ( ws_state. subscribers . iter_mut ( ) . map ( |subscriber| {
617
- let event = event. clone ( ) ;
618
- async move {
619
- match subscriber. send ( event) . await {
620
- Ok ( _) => None ,
621
- Err ( _) => {
622
- // An error here indicates the channel is closed (which may happen either when the
623
- // client has sent Message::Close or some other abrupt disconnection). We remove
624
- // subscribers only when send fails so we can handle closure only once when we are
625
- // able to see send() fail.
626
- Some ( * subscriber. key ( ) )
627
- }
628
- }
629
- }
630
- } ) )
631
- . await ;
632
-
633
- // Remove closed_subscribers from ws_state
634
- closed_subscribers. into_iter ( ) . for_each ( |id| {
635
- if let Some ( id) = id {
636
- ws_state. subscribers . remove ( & id) ;
637
- }
638
- } ) ;
639
-
640
- // Clean the bytes limiting dictionary
641
- ws_state. rate_limiter . retain_recent ( ) ;
642
- }
0 commit comments