11use axum:: {
2- extract:: {
3- ws:: { Message , WebSocket , WebSocketUpgrade } ,
4- Path , State ,
5- } ,
6- response:: Response ,
7- routing:: { any, get} ,
82 Json , Router ,
3+ extract:: { Path , State } ,
4+ response:: sse:: { Event , KeepAlive , Sse } ,
5+ routing:: get,
96} ;
7+ use futures:: stream:: { self , Stream , StreamExt } ;
8+ use tokio_stream:: wrappers:: BroadcastStream ;
109use tower:: ServiceBuilder ;
1110use tower_http:: trace:: TraceLayer ;
1211use tracing:: info;
1312
14- use std:: { io, sync:: Arc } ;
13+ use std:: { convert :: Infallible , io, sync:: Arc } ;
1514
1615use crate :: domain:: { OrderRepository , QueueEntry , QueueEvent , QueueRepository } ;
1716
@@ -41,7 +40,7 @@ impl HttpAdapter {
4140 let app = Router :: new ( )
4241 . route ( "/{guild_id}/status" , get ( queue_status) )
4342 . route ( "/{guild_id}/queue" , get ( list_queue) )
44- . route ( "/{guild_id}/queue/ws " , any ( list_queue_ws ) )
43+ . route ( "/{guild_id}/queue/sse " , get ( list_queue_sse ) )
4544 . layer ( ServiceBuilder :: new ( ) . layer ( TraceLayer :: new_for_http ( ) ) )
4645 . with_state ( state) ;
4746
@@ -66,37 +65,32 @@ async fn list_queue(
6665 Json ( queue)
6766}
6867
69- async fn list_queue_ws (
68+ async fn list_queue_sse (
7069 State ( state) : State < Arc < AppState > > ,
7170 Path ( guild_id) : Path < String > ,
72- ws : WebSocketUpgrade ,
73- ) -> Response {
74- ws. on_upgrade ( move |socket| list_queue_ws_handler ( state, guild_id, socket) )
75- }
76-
77- async fn list_queue_ws_handler ( state : Arc < AppState > , guild_id : String , mut socket : WebSocket ) {
78- info ! ( "new websocket connection for guild_id: {}" , guild_id) ;
71+ ) -> Sse < impl Stream < Item = Result < Event , Infallible > > > {
72+ let rx = state. queue . subscribe ( ) ;
7973
80- let mut rx = state. queue . subscribe ( ) ;
81-
82- // Initial state
8374 let queue = state. queue . list ( & guild_id) . await ;
84- let msg = serde_json:: to_string ( & queue) . unwrap ( ) ;
85- if socket. send ( Message :: Text ( msg. into ( ) ) ) . await . is_err ( ) {
86- return ;
87- }
75+ let initial = serde_json:: to_string ( & queue) . unwrap ( ) ;
8876
89- while let Ok ( event) = rx. recv ( ) . await {
90- match event {
91- QueueEvent :: Updated { guild_id : gid } => {
92- if gid == guild_id {
77+ let first =
78+ stream:: once ( async move { Ok :: < Event , Infallible > ( Event :: default ( ) . data ( initial) ) } ) ;
79+
80+ let updates = BroadcastStream :: new ( rx) . filter_map ( move |event| {
81+ let guild_id = guild_id. clone ( ) ;
82+ let state = state. clone ( ) ;
83+ async move {
84+ match event {
85+ Ok ( QueueEvent :: Updated { guild_id : gid } ) if gid == guild_id => {
9386 let queue = state. queue . list ( & guild_id) . await ;
94- let msg = serde_json:: to_string ( & queue) . unwrap ( ) ;
95- if socket. send ( Message :: Text ( msg. into ( ) ) ) . await . is_err ( ) {
96- break ;
97- }
87+ let data = serde_json:: to_string ( & queue) . unwrap ( ) ;
88+ Some ( Ok ( Event :: default ( ) . data ( data) ) )
9889 }
90+ _ => None ,
9991 }
10092 }
101- }
93+ } ) ;
94+
95+ Sse :: new ( first. chain ( updates) ) . keep_alive ( KeepAlive :: default ( ) )
10296}
0 commit comments