Skip to content

Commit 4f9e239

Browse files
authored
Merge pull request #38 from dimacurrentai/step11
Move the `WebSocket` out of `Arc<Mutex<...>>`, and ... wrapped it into a channel!
2 parents 109cd62 + ba6de6c commit 4f9e239

File tree

1 file changed

+49
-48
lines changed
  • step11_ws_state_machine/code/src

1 file changed

+49
-48
lines changed

step11_ws_state_machine/code/src/main.rs

Lines changed: 49 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ use axum::{
1010
use clap::Parser;
1111
use std::cmp::Ordering;
1212
use std::collections::BinaryHeap;
13-
use std::future::Future;
1413
use std::net::SocketAddr;
15-
use std::pin::Pin;
1614
use std::sync::Arc;
1715
use std::time::Duration;
1816
use tokio::{
@@ -49,29 +47,37 @@ impl Timer for WallTimeTimer {
4947
}
5048

5149
trait Writer: Send + Sync + 'static {
52-
// NOTE(dkorolev): Using `Arc<Self>` is a workaround to avoid the `Send` constraint.
53-
// NOTE(dkorolev): I'd love this to be an `async fn`, but alas, does not play well with recursion and `axum`.
54-
fn write_text(
55-
self: Arc<Self>, text: String, timestamp: Option<LogicalTimeMs>,
56-
) -> Pin<Box<dyn Future<Output = Result<(), axum::Error>> + Send>>;
50+
async fn write_text(&self, text: String, timestamp: Option<LogicalTimeMs>) -> Result<(), Box<dyn std::error::Error>>
51+
where
52+
Self: Send;
5753
}
5854

59-
#[derive(Clone)]
6055
struct WebSocketWriter {
61-
socket: Arc<Mutex<WebSocket>>,
56+
sender: mpsc::Sender<String>,
57+
_task: tokio::task::JoinHandle<()>,
6258
}
6359

6460
impl WebSocketWriter {
6561
fn new(socket: WebSocket) -> Self {
66-
Self { socket: Arc::new(Mutex::new(socket)) }
62+
let (sender, mut receiver) = mpsc::channel::<String>(100);
63+
let mut socket = socket;
64+
65+
let task = tokio::spawn(async move {
66+
while let Some(text) = receiver.recv().await {
67+
let _ = socket.send(Message::Text(text.into())).await;
68+
}
69+
});
70+
71+
Self { sender, _task: task }
6772
}
6873
}
6974

7075
impl Writer for WebSocketWriter {
71-
fn write_text(
72-
self: Arc<Self>, text: String, _timestamp: Option<LogicalTimeMs>,
73-
) -> Pin<Box<dyn Future<Output = Result<(), axum::Error>> + Send>> {
74-
Box::pin(async move { self.socket.lock().await.send(Message::Text(text.into())).await })
76+
async fn write_text(
77+
&self, text: String, _timestamp: Option<LogicalTimeMs>,
78+
) -> Result<(), Box<dyn std::error::Error>> {
79+
self.sender.send(text).await.map_err(Box::new)?;
80+
Ok(())
7581
}
7682
}
7783

@@ -342,33 +348,28 @@ fn ackermann(m: u64, n: u64) -> u64 {
342348
}
343349
}
344350

345-
// NOTE(dkorolev): Even though the socket is "single-threaded", we still use a `Mutex` for now, because
346-
// the `on_upgrade` operation in `axum` for WebSocket-s assumes the execution may span thread boundaries.
347-
fn async_ack<W: Writer>(
348-
w: Arc<W>, m: i64, n: i64, indent: usize,
349-
) -> Pin<Box<dyn Future<Output = Result<i64, axum::Error>> + Send>> {
350-
Box::pin(async move {
351-
let indentation = " ".repeat(indent);
352-
if m == 0 {
353-
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
354-
w.write_text(format!("{indentation}ack({m},{n}) = {n} + 1"), None).await?;
355-
Ok(n + 1)
356-
} else {
357-
Arc::clone(&w).write_text(format!("{}ack({m},{n}) ...", indentation), None).await?;
358-
359-
let r = match (m, n) {
360-
(0, n) => n + 1,
361-
(m, 0) => async_ack(Arc::clone(&w), m - 1, 1, indent + 2).await?,
362-
(m, n) => {
363-
async_ack(Arc::clone(&w), m - 1, async_ack(Arc::clone(&w), m, n - 1, indent + 2).await?, indent + 2).await?
364-
}
365-
};
351+
async fn async_ack<W: Writer>(w: Arc<W>, m: i64, n: i64, indent: usize) -> Result<i64, Box<dyn std::error::Error>> {
352+
let indentation = " ".repeat(indent);
353+
if m == 0 {
354+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
355+
w.write_text(format!("{indentation}ack({m},{n}) = {n} + 1"), None).await?;
356+
Ok(n + 1)
357+
} else {
358+
w.write_text(format!("{}ack({m},{n}) ...", indentation), None).await?;
359+
360+
let r = match (m, n) {
361+
(0, n) => n + 1,
362+
(m, 0) => Box::pin(async_ack(Arc::clone(&w), m - 1, 1, indent + 2)).await?,
363+
(m, n) => {
364+
let inner_result = Box::pin(async_ack(Arc::clone(&w), m, n - 1, indent + 2)).await?;
365+
Box::pin(async_ack(Arc::clone(&w), m - 1, inner_result, indent + 2)).await?
366+
}
367+
};
366368

367-
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
368-
w.write_text(format!("{}ack({m},{n}) = {r}", indentation), None).await?;
369-
Ok(r)
370-
}
371-
})
369+
tokio::time::sleep(std::time::Duration::from_millis(10)).await;
370+
w.write_text(format!("{}ack({m},{n}) = {r}", indentation), None).await?;
371+
Ok(r)
372+
}
372373
}
373374

374375
async fn ackermann_handler_ws<T: Timer>(socket: WebSocket, m: i64, n: i64, _state: Arc<AppState<T, WebSocketWriter>>) {
@@ -610,14 +611,14 @@ mod tests {
610611
}
611612

612613
impl<T: Timer> Writer for MockWriter<T> {
613-
fn write_text(
614-
self: Arc<Self>, text: String, timestamp: Option<LogicalTimeMs>,
615-
) -> Pin<Box<dyn Future<Output = Result<(), axum::Error>> + Send>> {
616-
Box::pin(async move {
617-
let time_to_use = timestamp.unwrap_or_else(|| self.timer.millis_since_start());
618-
self.outputs.lock().unwrap().push(format!("{time_to_use}ms:{text}"));
619-
Ok(())
620-
})
614+
async fn write_text(
615+
&self, text: String, timestamp: Option<LogicalTimeMs>,
616+
) -> Result<(), Box<dyn std::error::Error>> {
617+
let timer = Arc::clone(&self.timer);
618+
let outputs = Arc::clone(&self.outputs);
619+
let time_to_use = timestamp.unwrap_or_else(|| timer.millis_since_start());
620+
outputs.lock().unwrap().push(format!("{time_to_use}ms:{text}"));
621+
Ok(())
621622
}
622623
}
623624

0 commit comments

Comments
 (0)