@@ -19,14 +19,13 @@ pub struct PendingRequest {
1919 pub response_tx : oneshot:: Sender < HttpResponse > ,
2020}
2121
22- /// Connected client with its request channel and pending requests
23- pub struct ConnectedClient {
24- pub request_tx : mpsc:: Sender < HttpRequest > ,
25- pub pending_requests : Arc < RwLock < HashMap < String , PendingRequest > > > ,
22+ pub struct Envelope {
23+ pub request : HttpRequest ,
24+ pub response_tx : oneshot:: Sender < HttpResponse > ,
2625}
2726
2827/// Active client connections indexed by route
29- pub type ClientConnections = Arc < RwLock < HashMap < String , Arc < ConnectedClient > > > > ;
28+ pub type ClientConnections = Arc < RwLock < HashMap < String , mpsc :: Sender < Envelope > > > > ;
3029
3130pub struct RelayServiceImpl {
3231 connections : ClientConnections ,
@@ -35,18 +34,26 @@ pub struct RelayServiceImpl {
3534}
3635
3736impl RelayServiceImpl {
38- pub fn new ( connections : ClientConnections , jwks_cache : Arc < JwksCache > , config : & Config ) -> Self {
37+ pub fn new (
38+ connections : ClientConnections ,
39+ jwks_cache : Arc < JwksCache > ,
40+ config : & Config ,
41+ ) -> Self {
3942 Self {
4043 connections,
4144 jwks_cache,
42- external_url : config. external_url . clone ( ) . trim_end_matches ( '/' ) . to_string ( ) ,
45+ external_url : config
46+ . external_url
47+ . clone ( )
48+ . trim_end_matches ( '/' )
49+ . to_string ( ) ,
4350 }
4451 }
4552}
4653
4754/// Generate a route from user_id and session_id using a hash
4855fn session_to_route ( user_id : & str , session_id : & str ) -> String {
49- use sha2:: { Sha256 , Digest } ;
56+ use sha2:: { Digest , Sha256 } ;
5057 let mut hasher = Sha256 :: new ( ) ;
5158 hasher. update ( user_id. as_bytes ( ) ) ;
5259 hasher. update ( session_id. as_bytes ( ) ) ;
@@ -111,7 +118,9 @@ impl RelayService for RelayServiceImpl {
111118 // Always return session_id in response metadata
112119 response. metadata_mut ( ) . insert (
113120 SESSION_ID_HEADER ,
114- session_id. parse ( ) . map_err ( |_| Status :: internal ( "Failed to set session header" ) ) ?,
121+ session_id
122+ . parse ( )
123+ . map_err ( |_| Status :: internal ( "Failed to set session header" ) ) ?,
115124 ) ;
116125
117126 Ok ( response)
@@ -128,18 +137,15 @@ impl RelayService for RelayServiceImpl {
128137 let session_id = get_session_id ( & request) ?;
129138 let route = session_to_route ( & user_id, & session_id) ;
130139
131- let ( request_tx, request_rx) = mpsc:: channel :: < HttpRequest > ( 32 ) ;
132- let pending_requests = Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
133-
134- let client = Arc :: new ( ConnectedClient {
135- request_tx,
136- pending_requests : pending_requests. clone ( ) ,
137- } ) ;
140+ let ( request_tx, request_rx) = mpsc:: channel :: < Envelope > ( 32 ) ;
141+ let pending_requests: Arc < RwLock < HashMap < String , oneshot:: Sender < HttpResponse > > > > =
142+ Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ;
143+ let pending_requests_ref = pending_requests. clone ( ) ;
138144
139145 // Register client connection
140146 {
141147 let mut connections = self . connections . write ( ) . await ;
142- connections. insert ( route. clone ( ) , client ) ;
148+ connections. insert ( route. clone ( ) , request_tx ) ;
143149 }
144150
145151 tracing:: info!(
@@ -161,7 +167,7 @@ impl RelayService for RelayServiceImpl {
161167 let request_id = response. request_id . clone ( ) ;
162168 let mut pending = pending_requests. write ( ) . await ;
163169 if let Some ( pending_req) = pending. remove ( & request_id) {
164- let _ = pending_req. response_tx . send ( response) ;
170+ let _ = pending_req. send ( response) ;
165171 } else {
166172 tracing:: warn!(
167173 request_id = %request_id,
@@ -182,59 +188,52 @@ impl RelayService for RelayServiceImpl {
182188 tracing:: info!( route = %route_clone, "Client stream ended, unregistered" ) ;
183189 } ) ;
184190
185- let output_stream = ReceiverStream :: new ( request_rx) . map ( Ok ) ;
191+ let output_stream = ReceiverStream :: new ( request_rx) . then ( move |ev| {
192+ let pending_requests_ref = pending_requests_ref. clone ( ) ;
193+ async move {
194+ let mut pending = pending_requests_ref. write ( ) . await ;
195+ pending. insert ( ev. request . request_id . clone ( ) , ev. response_tx ) ;
196+ Ok ( ev. request )
197+ }
198+ } ) ;
186199
187200 let mut response = Response :: new ( Box :: pin ( output_stream) as Self :: DoWebhookStream ) ;
188-
201+
189202 // Return session_id in response metadata
190203 response. metadata_mut ( ) . insert (
191204 SESSION_ID_HEADER ,
192- session_id. parse ( ) . map_err ( |_| Status :: internal ( "Failed to set session header" ) ) ?,
205+ session_id
206+ . parse ( )
207+ . map_err ( |_| Status :: internal ( "Failed to set session header" ) ) ?,
193208 ) ;
194209
195210 Ok ( response)
196211 }
197212}
198213
199214pub async fn send_request_to_client (
200- client : & Arc < ConnectedClient > ,
215+ request_tx : & mpsc :: Sender < Envelope > ,
201216 request : HttpRequest ,
202217 timeout : std:: time:: Duration ,
203218) -> Result < HttpResponse , Status > {
204- let request_id = request. request_id . clone ( ) ;
205-
206219 // Create oneshot channel for response
207220 let ( response_tx, response_rx) = oneshot:: channel ( ) ;
208221
209- // Register pending request
210- {
211- let mut pending = client. pending_requests . write ( ) . await ;
212- pending. insert ( request_id. clone ( ) , PendingRequest { response_tx } ) ;
213- }
214-
215222 // Send request to client
216- client
217- . request_tx
218- . send ( request)
223+ request_tx
224+ . send ( Envelope {
225+ request,
226+ response_tx,
227+ } )
219228 . await
220229 . map_err ( |_| Status :: unavailable ( "Client disconnected" ) ) ?;
221230
222231 // Wait for response with timeout
223232 match tokio:: time:: timeout ( timeout, response_rx) . await {
224233 Ok ( Ok ( response) ) => Ok ( response) ,
225- Ok ( Err ( _) ) => {
226- // Channel closed, remove pending request
227- let mut pending = client. pending_requests . write ( ) . await ;
228- pending. remove ( & request_id) ;
229- Err ( Status :: unavailable (
230- "Client disconnected while waiting for response" ,
231- ) )
232- }
233- Err ( _) => {
234- // Timeout, remove pending request
235- let mut pending = client. pending_requests . write ( ) . await ;
236- pending. remove ( & request_id) ;
237- Err ( Status :: deadline_exceeded ( "Request timed out" ) )
238- }
234+ Ok ( Err ( _) ) => Err ( Status :: unavailable (
235+ "Client disconnected while waiting for response" ,
236+ ) ) ,
237+ Err ( _) => Err ( Status :: deadline_exceeded ( "Request timed out" ) ) ,
239238 }
240239}
0 commit comments