@@ -10,9 +10,7 @@ use axum::{
1010use clap:: Parser ;
1111use std:: cmp:: Ordering ;
1212use std:: collections:: BinaryHeap ;
13- use std:: future:: Future ;
1413use std:: net:: SocketAddr ;
15- use std:: pin:: Pin ;
1614use std:: sync:: Arc ;
1715use std:: time:: Duration ;
1816use tokio:: {
@@ -49,29 +47,37 @@ impl Timer for WallTimeTimer {
4947}
5048
5149trait Writer : Send + Sync + ' static {
52- // NOTE(dkorolev): Using `Arc<Self>` is a workaround to avoid the `Send` constraint.
53- // NOTE(dkorolev): I'd love this to be an `async fn`, but alas, does not play well with recursion and `axum`.
54- fn write_text (
55- self : Arc < Self > , text : String , timestamp : Option < LogicalTimeMs > ,
56- ) -> Pin < Box < dyn Future < Output = Result < ( ) , axum:: Error > > + Send > > ;
50+ async fn write_text ( & self , text : String , timestamp : Option < LogicalTimeMs > ) -> Result < ( ) , Box < dyn std:: error:: Error > >
51+ where
52+ Self : Send ;
5753}
5854
59- #[ derive( Clone ) ]
6055struct WebSocketWriter {
61- socket : Arc < Mutex < WebSocket > > ,
56+ sender : mpsc:: Sender < String > ,
57+ _task : tokio:: task:: JoinHandle < ( ) > ,
6258}
6359
6460impl WebSocketWriter {
6561 fn new ( socket : WebSocket ) -> Self {
66- Self { socket : Arc :: new ( Mutex :: new ( socket) ) }
62+ let ( sender, mut receiver) = mpsc:: channel :: < String > ( 100 ) ;
63+ let mut socket = socket;
64+
65+ let task = tokio:: spawn ( async move {
66+ while let Some ( text) = receiver. recv ( ) . await {
67+ let _ = socket. send ( Message :: Text ( text. into ( ) ) ) . await ;
68+ }
69+ } ) ;
70+
71+ Self { sender, _task : task }
6772 }
6873}
6974
7075impl Writer for WebSocketWriter {
71- fn write_text (
72- self : Arc < Self > , text : String , _timestamp : Option < LogicalTimeMs > ,
73- ) -> Pin < Box < dyn Future < Output = Result < ( ) , axum:: Error > > + Send > > {
74- Box :: pin ( async move { self . socket . lock ( ) . await . send ( Message :: Text ( text. into ( ) ) ) . await } )
76+ async fn write_text (
77+ & self , text : String , _timestamp : Option < LogicalTimeMs > ,
78+ ) -> Result < ( ) , Box < dyn std:: error:: Error > > {
79+ self . sender . send ( text) . await . map_err ( Box :: new) ?;
80+ Ok ( ( ) )
7581 }
7682}
7783
@@ -342,33 +348,28 @@ fn ackermann(m: u64, n: u64) -> u64 {
342348 }
343349}
344350
345- // NOTE(dkorolev): Even though the socket is "single-threaded", we still use a `Mutex` for now, because
346- // the `on_upgrade` operation in `axum` for WebSocket-s assumes the execution may span thread boundaries.
347- fn async_ack < W : Writer > (
348- w : Arc < W > , m : i64 , n : i64 , indent : usize ,
349- ) -> Pin < Box < dyn Future < Output = Result < i64 , axum:: Error > > + Send > > {
350- Box :: pin ( async move {
351- let indentation = " " . repeat ( indent) ;
352- if m == 0 {
353- tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
354- w. write_text ( format ! ( "{indentation}ack({m},{n}) = {n} + 1" ) , None ) . await ?;
355- Ok ( n + 1 )
356- } else {
357- Arc :: clone ( & w) . write_text ( format ! ( "{}ack({m},{n}) ..." , indentation) , None ) . await ?;
358-
359- let r = match ( m, n) {
360- ( 0 , n) => n + 1 ,
361- ( m, 0 ) => async_ack ( Arc :: clone ( & w) , m - 1 , 1 , indent + 2 ) . await ?,
362- ( m, n) => {
363- async_ack ( Arc :: clone ( & w) , m - 1 , async_ack ( Arc :: clone ( & w) , m, n - 1 , indent + 2 ) . await ?, indent + 2 ) . await ?
364- }
365- } ;
351+ async fn async_ack < W : Writer > ( w : Arc < W > , m : i64 , n : i64 , indent : usize ) -> Result < i64 , Box < dyn std:: error:: Error > > {
352+ let indentation = " " . repeat ( indent) ;
353+ if m == 0 {
354+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
355+ w. write_text ( format ! ( "{indentation}ack({m},{n}) = {n} + 1" ) , None ) . await ?;
356+ Ok ( n + 1 )
357+ } else {
358+ w. write_text ( format ! ( "{}ack({m},{n}) ..." , indentation) , None ) . await ?;
359+
360+ let r = match ( m, n) {
361+ ( 0 , n) => n + 1 ,
362+ ( m, 0 ) => Box :: pin ( async_ack ( Arc :: clone ( & w) , m - 1 , 1 , indent + 2 ) ) . await ?,
363+ ( m, n) => {
364+ let inner_result = Box :: pin ( async_ack ( Arc :: clone ( & w) , m, n - 1 , indent + 2 ) ) . await ?;
365+ Box :: pin ( async_ack ( Arc :: clone ( & w) , m - 1 , inner_result, indent + 2 ) ) . await ?
366+ }
367+ } ;
366368
367- tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
368- w. write_text ( format ! ( "{}ack({m},{n}) = {r}" , indentation) , None ) . await ?;
369- Ok ( r)
370- }
371- } )
369+ tokio:: time:: sleep ( std:: time:: Duration :: from_millis ( 10 ) ) . await ;
370+ w. write_text ( format ! ( "{}ack({m},{n}) = {r}" , indentation) , None ) . await ?;
371+ Ok ( r)
372+ }
372373}
373374
374375async fn ackermann_handler_ws < T : Timer > ( socket : WebSocket , m : i64 , n : i64 , _state : Arc < AppState < T , WebSocketWriter > > ) {
@@ -610,14 +611,14 @@ mod tests {
610611 }
611612
612613 impl < T : Timer > Writer for MockWriter < T > {
613- fn write_text (
614- self : Arc < Self > , text : String , timestamp : Option < LogicalTimeMs > ,
615- ) -> Pin < Box < dyn Future < Output = Result < ( ) , axum :: Error > > + Send > > {
616- Box :: pin ( async move {
617- let time_to_use = timestamp . unwrap_or_else ( || self . timer . millis_since_start ( ) ) ;
618- self . outputs . lock ( ) . unwrap ( ) . push ( format ! ( "{time_to_use}ms:{text}" ) ) ;
619- Ok ( ( ) )
620- } )
614+ async fn write_text (
615+ & self , text : String , timestamp : Option < LogicalTimeMs > ,
616+ ) -> Result < ( ) , Box < dyn std :: error :: Error > > {
617+ let timer = Arc :: clone ( & self . timer ) ;
618+ let outputs = Arc :: clone ( & self . outputs ) ;
619+ let time_to_use = timestamp . unwrap_or_else ( || timer . millis_since_start ( ) ) ;
620+ outputs . lock ( ) . unwrap ( ) . push ( format ! ( "{time_to_use}ms:{text}" ) ) ;
621+ Ok ( ( ) )
621622 }
622623 }
623624
0 commit comments