Skip to content

Commit d9220fe

Browse files
committed
feat: simplify connection store
1 parent 4e2610d commit d9220fe

File tree

3 files changed

+53
-54
lines changed

3 files changed

+53
-54
lines changed

Cargo.lock

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,5 @@ open = "5"
4444
chrono = {version = "0.4", features = ["serde"]}
4545

4646
[workspace.package]
47-
version = "0.1.1"
47+
version = "0.1.2"
4848
edition = "2024"

server/src/grpc.rs

Lines changed: 46 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ pub struct PendingRequest {
1919
pub response_tx: oneshot::Sender<HttpResponse>,
2020
}
2121

22-
/// Connected client with its request channel and pending requests
23-
pub struct ConnectedClient {
24-
pub request_tx: mpsc::Sender<HttpRequest>,
25-
pub pending_requests: Arc<RwLock<HashMap<String, PendingRequest>>>,
22+
pub struct Envelope {
23+
pub request: HttpRequest,
24+
pub response_tx: oneshot::Sender<HttpResponse>,
2625
}
2726

2827
/// Active client connections indexed by route
29-
pub type ClientConnections = Arc<RwLock<HashMap<String, Arc<ConnectedClient>>>>;
28+
pub type ClientConnections = Arc<RwLock<HashMap<String, mpsc::Sender<Envelope>>>>;
3029

3130
pub struct RelayServiceImpl {
3231
connections: ClientConnections,
@@ -35,18 +34,26 @@ pub struct RelayServiceImpl {
3534
}
3635

3736
impl RelayServiceImpl {
38-
pub fn new(connections: ClientConnections, jwks_cache: Arc<JwksCache>, config: &Config) -> Self {
37+
pub fn new(
38+
connections: ClientConnections,
39+
jwks_cache: Arc<JwksCache>,
40+
config: &Config,
41+
) -> Self {
3942
Self {
4043
connections,
4144
jwks_cache,
42-
external_url: config.external_url.clone().trim_end_matches('/').to_string(),
45+
external_url: config
46+
.external_url
47+
.clone()
48+
.trim_end_matches('/')
49+
.to_string(),
4350
}
4451
}
4552
}
4653

4754
/// Generate a route from user_id and session_id using a hash
4855
fn session_to_route(user_id: &str, session_id: &str) -> String {
49-
use sha2::{Sha256, Digest};
56+
use sha2::{Digest, Sha256};
5057
let mut hasher = Sha256::new();
5158
hasher.update(user_id.as_bytes());
5259
hasher.update(session_id.as_bytes());
@@ -111,7 +118,9 @@ impl RelayService for RelayServiceImpl {
111118
// Always return session_id in response metadata
112119
response.metadata_mut().insert(
113120
SESSION_ID_HEADER,
114-
session_id.parse().map_err(|_| Status::internal("Failed to set session header"))?,
121+
session_id
122+
.parse()
123+
.map_err(|_| Status::internal("Failed to set session header"))?,
115124
);
116125

117126
Ok(response)
@@ -128,18 +137,15 @@ impl RelayService for RelayServiceImpl {
128137
let session_id = get_session_id(&request)?;
129138
let route = session_to_route(&user_id, &session_id);
130139

131-
let (request_tx, request_rx) = mpsc::channel::<HttpRequest>(32);
132-
let pending_requests = Arc::new(RwLock::new(HashMap::new()));
133-
134-
let client = Arc::new(ConnectedClient {
135-
request_tx,
136-
pending_requests: pending_requests.clone(),
137-
});
140+
let (request_tx, request_rx) = mpsc::channel::<Envelope>(32);
141+
let pending_requests: Arc<RwLock<HashMap<String, oneshot::Sender<HttpResponse>>>> =
142+
Arc::new(RwLock::new(HashMap::new()));
143+
let pending_requests_ref = pending_requests.clone();
138144

139145
// Register client connection
140146
{
141147
let mut connections = self.connections.write().await;
142-
connections.insert(route.clone(), client);
148+
connections.insert(route.clone(), request_tx);
143149
}
144150

145151
tracing::info!(
@@ -161,7 +167,7 @@ impl RelayService for RelayServiceImpl {
161167
let request_id = response.request_id.clone();
162168
let mut pending = pending_requests.write().await;
163169
if let Some(pending_req) = pending.remove(&request_id) {
164-
let _ = pending_req.response_tx.send(response);
170+
let _ = pending_req.send(response);
165171
} else {
166172
tracing::warn!(
167173
request_id = %request_id,
@@ -182,59 +188,52 @@ impl RelayService for RelayServiceImpl {
182188
tracing::info!(route = %route_clone, "Client stream ended, unregistered");
183189
});
184190

185-
let output_stream = ReceiverStream::new(request_rx).map(Ok);
191+
let output_stream = ReceiverStream::new(request_rx).then(move |ev| {
192+
let pending_requests_ref = pending_requests_ref.clone();
193+
async move {
194+
let mut pending = pending_requests_ref.write().await;
195+
pending.insert(ev.request.request_id.clone(), ev.response_tx);
196+
Ok(ev.request)
197+
}
198+
});
186199

187200
let mut response = Response::new(Box::pin(output_stream) as Self::DoWebhookStream);
188-
201+
189202
// Return session_id in response metadata
190203
response.metadata_mut().insert(
191204
SESSION_ID_HEADER,
192-
session_id.parse().map_err(|_| Status::internal("Failed to set session header"))?,
205+
session_id
206+
.parse()
207+
.map_err(|_| Status::internal("Failed to set session header"))?,
193208
);
194209

195210
Ok(response)
196211
}
197212
}
198213

199214
pub async fn send_request_to_client(
200-
client: &Arc<ConnectedClient>,
215+
request_tx: &mpsc::Sender<Envelope>,
201216
request: HttpRequest,
202217
timeout: std::time::Duration,
203218
) -> Result<HttpResponse, Status> {
204-
let request_id = request.request_id.clone();
205-
206219
// Create oneshot channel for response
207220
let (response_tx, response_rx) = oneshot::channel();
208221

209-
// Register pending request
210-
{
211-
let mut pending = client.pending_requests.write().await;
212-
pending.insert(request_id.clone(), PendingRequest { response_tx });
213-
}
214-
215222
// Send request to client
216-
client
217-
.request_tx
218-
.send(request)
223+
request_tx
224+
.send(Envelope {
225+
request,
226+
response_tx,
227+
})
219228
.await
220229
.map_err(|_| Status::unavailable("Client disconnected"))?;
221230

222231
// Wait for response with timeout
223232
match tokio::time::timeout(timeout, response_rx).await {
224233
Ok(Ok(response)) => Ok(response),
225-
Ok(Err(_)) => {
226-
// Channel closed, remove pending request
227-
let mut pending = client.pending_requests.write().await;
228-
pending.remove(&request_id);
229-
Err(Status::unavailable(
230-
"Client disconnected while waiting for response",
231-
))
232-
}
233-
Err(_) => {
234-
// Timeout, remove pending request
235-
let mut pending = client.pending_requests.write().await;
236-
pending.remove(&request_id);
237-
Err(Status::deadline_exceeded("Request timed out"))
238-
}
234+
Ok(Err(_)) => Err(Status::unavailable(
235+
"Client disconnected while waiting for response",
236+
)),
237+
Err(_) => Err(Status::deadline_exceeded("Request timed out")),
239238
}
240239
}

0 commit comments

Comments
 (0)