@@ -19,12 +19,15 @@ use {
19
19
pyth_sdk:: PriceIdentifier ,
20
20
serde:: Deserialize ,
21
21
serde_qs:: axum:: QsQuery ,
22
- std:: convert:: Infallible ,
23
- tokio:: sync:: broadcast,
22
+ std:: { convert:: Infallible , time :: Duration } ,
23
+ tokio:: { sync:: broadcast, time :: Instant } ,
24
24
tokio_stream:: { wrappers:: BroadcastStream , StreamExt as _} ,
25
25
utoipa:: IntoParams ,
26
26
} ;
27
27
28
+ // Constants
29
+ const MAX_CONNECTION_DURATION : Duration = Duration :: from_secs ( 10 ) ; // 24 hours
30
+
28
31
#[ derive( Debug , Deserialize , IntoParams ) ]
29
32
#[ into_params( parameter_in = Query ) ]
30
33
pub struct StreamPriceUpdatesQueryParams {
@@ -93,10 +96,17 @@ where
93
96
// Convert the broadcast receiver into a Stream
94
97
let stream = BroadcastStream :: new ( update_rx) ;
95
98
99
+ // Set connection deadline
100
+ let connection_deadline = Instant :: now ( ) + MAX_CONNECTION_DURATION ;
101
+
96
102
let sse_stream = stream
103
+ . take_while ( move |_| {
104
+ let now = Instant :: now ( ) ;
105
+ now < connection_deadline
106
+ } )
97
107
. then ( move |message| {
98
- let state_clone = state. clone ( ) ; // Clone again to use inside the async block
99
- let price_ids_clone = price_ids. clone ( ) ; // Clone again for use inside the async block
108
+ let state_clone = state. clone ( ) ;
109
+ let price_ids_clone = price_ids. clone ( ) ;
100
110
async move {
101
111
match message {
102
112
Ok ( event) => {
@@ -122,7 +132,12 @@ where
122
132
}
123
133
}
124
134
} )
125
- . filter_map ( |x| x) ;
135
+ . filter_map ( |x| x)
136
+ . chain ( futures:: stream:: once ( async {
137
+ Ok ( Event :: default ( )
138
+ . event ( "error" )
139
+ . data ( "Connection timeout reached (24h)" ) )
140
+ } ) ) ;
126
141
127
142
Ok ( Sse :: new ( sse_stream) . keep_alive ( KeepAlive :: default ( ) ) )
128
143
}
0 commit comments