Skip to content

Commit 56d794b

Browse files
committed
refactor: replace websocket with server-sent events for queue updates
1 parent 104c6c0 commit 56d794b

File tree

4 files changed

+37
-67
lines changed

4 files changed

+37
-67
lines changed

Cargo.lock

Lines changed: 5 additions & 34 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ edition = "2024"
66
[dependencies]
77
anyhow = "1.0.101"
88
async-trait = "0.1"
9-
axum = { version = "0.8.8", features = ["ws"] }
9+
axum = { version = "0.8.8" }
10+
futures = "0.3"
11+
tokio-stream = { version = "0.1", features = ["sync"] }
1012
chrono = { version = "0.4", features = ["serde"] }
1113
dotenv = "0.15.0"
1214
poise = "0.6.1"

justfile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,6 @@ test:
1414

1515
migrate:
1616
sqlx migrate run --source migrations --database-url "$DATABASE_URL"
17+
18+
offline:
19+
cargo sqlx prepare

src/adapters/http/mod.rs

Lines changed: 26 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,16 @@
11
use axum::{
2-
extract::{
3-
ws::{Message, WebSocket, WebSocketUpgrade},
4-
Path, State,
5-
},
6-
response::Response,
7-
routing::{any, get},
82
Json, Router,
3+
extract::{Path, State},
4+
response::sse::{Event, KeepAlive, Sse},
5+
routing::get,
96
};
7+
use futures::stream::{self, Stream, StreamExt};
8+
use tokio_stream::wrappers::BroadcastStream;
109
use tower::ServiceBuilder;
1110
use tower_http::trace::TraceLayer;
1211
use tracing::info;
1312

14-
use std::{io, sync::Arc};
13+
use std::{convert::Infallible, io, sync::Arc};
1514

1615
use crate::domain::{OrderRepository, QueueEntry, QueueEvent, QueueRepository};
1716

@@ -41,7 +40,7 @@ impl HttpAdapter {
4140
let app = Router::new()
4241
.route("/{guild_id}/status", get(queue_status))
4342
.route("/{guild_id}/queue", get(list_queue))
44-
.route("/{guild_id}/queue/ws", any(list_queue_ws))
43+
.route("/{guild_id}/queue/sse", get(list_queue_sse))
4544
.layer(ServiceBuilder::new().layer(TraceLayer::new_for_http()))
4645
.with_state(state);
4746

@@ -66,37 +65,32 @@ async fn list_queue(
6665
Json(queue)
6766
}
6867

69-
async fn list_queue_ws(
68+
async fn list_queue_sse(
7069
State(state): State<Arc<AppState>>,
7170
Path(guild_id): Path<String>,
72-
ws: WebSocketUpgrade,
73-
) -> Response {
74-
ws.on_upgrade(move |socket| list_queue_ws_handler(state, guild_id, socket))
75-
}
76-
77-
async fn list_queue_ws_handler(state: Arc<AppState>, guild_id: String, mut socket: WebSocket) {
78-
info!("new websocket connection for guild_id: {}", guild_id);
71+
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
72+
let rx = state.queue.subscribe();
7973

80-
let mut rx = state.queue.subscribe();
81-
82-
// Initial state
8374
let queue = state.queue.list(&guild_id).await;
84-
let msg = serde_json::to_string(&queue).unwrap();
85-
if socket.send(Message::Text(msg.into())).await.is_err() {
86-
return;
87-
}
75+
let initial = serde_json::to_string(&queue).unwrap();
8876

89-
while let Ok(event) = rx.recv().await {
90-
match event {
91-
QueueEvent::Updated { guild_id: gid } => {
92-
if gid == guild_id {
77+
let first =
78+
stream::once(async move { Ok::<Event, Infallible>(Event::default().data(initial)) });
79+
80+
let updates = BroadcastStream::new(rx).filter_map(move |event| {
81+
let guild_id = guild_id.clone();
82+
let state = state.clone();
83+
async move {
84+
match event {
85+
Ok(QueueEvent::Updated { guild_id: gid }) if gid == guild_id => {
9386
let queue = state.queue.list(&guild_id).await;
94-
let msg = serde_json::to_string(&queue).unwrap();
95-
if socket.send(Message::Text(msg.into())).await.is_err() {
96-
break;
97-
}
87+
let data = serde_json::to_string(&queue).unwrap();
88+
Some(Ok(Event::default().data(data)))
9889
}
90+
_ => None,
9991
}
10092
}
101-
}
93+
});
94+
95+
Sse::new(first.chain(updates)).keep_alive(KeepAlive::default())
10296
}

0 commit comments

Comments
 (0)