@@ -19,12 +19,15 @@ use {
1919 pyth_sdk:: PriceIdentifier ,
2020 serde:: Deserialize ,
2121 serde_qs:: axum:: QsQuery ,
22- std:: convert:: Infallible ,
23- tokio:: sync:: broadcast,
22+ std:: { convert:: Infallible , time :: Duration } ,
23+ tokio:: { sync:: broadcast, time :: Instant } ,
2424 tokio_stream:: { wrappers:: BroadcastStream , StreamExt as _} ,
2525 utoipa:: IntoParams ,
2626} ;
2727
28+ // Constants
29+ const MAX_CONNECTION_DURATION : Duration = Duration :: from_secs ( 10 ) ; // 24 hours
30+
2831#[ derive( Debug , Deserialize , IntoParams ) ]
2932#[ into_params( parameter_in = Query ) ]
3033pub struct StreamPriceUpdatesQueryParams {
@@ -93,10 +96,17 @@ where
9396 // Convert the broadcast receiver into a Stream
9497 let stream = BroadcastStream :: new ( update_rx) ;
9598
99+ // Set connection deadline
100+ let connection_deadline = Instant :: now ( ) + MAX_CONNECTION_DURATION ;
101+
96102 let sse_stream = stream
103+ . take_while ( move |_| {
104+ let now = Instant :: now ( ) ;
105+ now < connection_deadline
106+ } )
97107 . then ( move |message| {
98- let state_clone = state. clone ( ) ; // Clone again to use inside the async block
99- let price_ids_clone = price_ids. clone ( ) ; // Clone again for use inside the async block
108+ let state_clone = state. clone ( ) ;
109+ let price_ids_clone = price_ids. clone ( ) ;
100110 async move {
101111 match message {
102112 Ok ( event) => {
@@ -122,7 +132,12 @@ where
122132 }
123133 }
124134 } )
125- . filter_map ( |x| x) ;
135+ . filter_map ( |x| x)
136+ . chain ( futures:: stream:: once ( async {
137+ Ok ( Event :: default ( )
138+ . event ( "error" )
139+ . data ( "Connection timeout reached (24h)" ) )
140+ } ) ) ;
126141
127142 Ok ( Sse :: new ( sse_stream) . keep_alive ( KeepAlive :: default ( ) ) )
128143}
0 commit comments