Skip to content

Commit b3d655d

Browse files
authored
add an actix-web server sse example (#33)
* add an actix-web server sse example remove unused import * fix to apply cargo fmt and cargo clippy to pass the the repo checks
1 parent 9c33eed commit b3d655d

File tree

2 files changed

+174
-1
lines changed

2 files changed

+174
-1
lines changed

examples/servers/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ rand = { version = "0.8" }
2525
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
2626
axum = { version = "0.8", features = ["macros"] }
2727
tokio = { version = "1", features = ["full"] }
28+
actix-web = "4"
2829

2930
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
3031
tokio = { version = "1", features = ["io-util", "rt", "time", "macros"] }
@@ -40,4 +41,8 @@ path = "src/axum.rs"
4041

4142
[[example]]
4243
name = "wasi_std_io"
43-
path = "src/wasi_std_io.rs"
44+
path = "src/wasi_std_io.rs"
45+
46+
[[example]]
47+
name = "actix_web"
48+
path = "src/actix_web.rs"

examples/servers/src/actix_web.rs

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
use actix_web::web::{Bytes, Data, Payload, Query};
2+
use actix_web::{get, post, App, Error, HttpResponse, HttpServer, Result};
3+
use futures::{StreamExt, TryStreamExt};
4+
use mcp_server::{ByteTransport, Server};
5+
use std::collections::HashMap;
6+
use tokio_util::codec::FramedRead;
7+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
8+
9+
use actix_web::middleware::Logger;
10+
use mcp_server::router::RouterService;
11+
use std::sync::Arc;
12+
use tokio::{
13+
io::{self, AsyncWriteExt},
14+
sync::Mutex,
15+
};
16+
mod common;
17+
use common::counter;
18+
19+
type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
20+
type SessionId = Arc<str>;
21+
22+
const BIND_ADDRESS: &str = "127.0.0.1:8000";
23+
24+
#[derive(Clone, Default)]
25+
pub struct AppState {
26+
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
27+
}
28+
29+
impl AppState {
30+
pub fn new() -> Self {
31+
Self {
32+
txs: Default::default(),
33+
}
34+
}
35+
}
36+
37+
fn session_id() -> SessionId {
38+
let id = format!("{:016x}", rand::random::<u128>());
39+
Arc::from(id)
40+
}
41+
42+
#[derive(Debug, serde::Deserialize)]
43+
#[serde(rename_all = "camelCase")]
44+
pub struct PostEventQuery {
45+
pub session_id: String,
46+
}
47+
48+
#[post("/sse")]
49+
async fn post_event_handler(
50+
app_state: Data<AppState>,
51+
query: Query<PostEventQuery>,
52+
mut payload: Payload,
53+
) -> Result<HttpResponse, actix_web::Error> {
54+
const BODY_BYTES_LIMIT: usize = 1 << 22;
55+
let session_id = &query.session_id;
56+
57+
let write_stream = {
58+
let rg = app_state.txs.read().await;
59+
match rg.get(session_id.as_str()) {
60+
Some(stream) => stream.clone(),
61+
None => return Ok(HttpResponse::NotFound().finish()),
62+
}
63+
};
64+
65+
let mut write_stream = write_stream.lock().await;
66+
let mut size = 0;
67+
68+
// Process the request body in chunks
69+
while let Some(chunk) = payload.next().await {
70+
let chunk = chunk?;
71+
size += chunk.len();
72+
if size > BODY_BYTES_LIMIT {
73+
return Ok(HttpResponse::PayloadTooLarge().finish());
74+
}
75+
76+
if (write_stream.write_all(&chunk).await).is_err() {
77+
return Ok(HttpResponse::InternalServerError().finish());
78+
}
79+
}
80+
81+
if (write_stream.write_u8(b'\n').await).is_err() {
82+
return Ok(HttpResponse::InternalServerError().finish());
83+
}
84+
85+
Ok(HttpResponse::Accepted().finish())
86+
}
87+
88+
#[get("/sse")]
89+
async fn sse_handler(app_state: Data<AppState>) -> Result<HttpResponse, Error> {
90+
// it's 4KB
91+
const BUFFER_SIZE: usize = 1 << 12;
92+
let session = session_id();
93+
tracing::info!(%session, "sse connection");
94+
95+
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
96+
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
97+
98+
app_state
99+
.txs
100+
.write()
101+
.await
102+
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
103+
104+
{
105+
let session = session.clone();
106+
let app_state = app_state.clone();
107+
tokio::spawn(async move {
108+
let router = RouterService(counter::CounterRouter::new());
109+
let server = Server::new(router);
110+
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
111+
let _result = server
112+
.run(bytes_transport)
113+
.await
114+
.inspect_err(|e| tracing::error!(?e, "server run error"));
115+
tracing::info!(%session, "connection closed, removing session");
116+
app_state.txs.write().await.remove(&session);
117+
});
118+
}
119+
120+
// Create SSE stream with correct types
121+
let stream = futures::stream::once(futures::future::ready(Ok::<_, io::Error>(Bytes::from(
122+
format!("event: endpoint\ndata: ?sessionId={}\n\n", session),
123+
))))
124+
.chain(
125+
FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec)
126+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
127+
.map_ok(move |bytes| {
128+
let message = match std::str::from_utf8(&bytes) {
129+
Ok(message) => format!("event: message\ndata: {}\n\n", message),
130+
Err(_) => "event: error\ndata: Invalid UTF-8 data\n\n".to_string(),
131+
};
132+
Bytes::from(message)
133+
}),
134+
);
135+
136+
// Return SSE response
137+
Ok(HttpResponse::Ok()
138+
.append_header(("Content-Type", "text/event-stream"))
139+
.append_header(("Cache-Control", "no-cache"))
140+
.append_header(("Connection", "keep-alive"))
141+
.streaming(stream))
142+
}
143+
144+
#[actix_web::main]
145+
async fn main() -> io::Result<()> {
146+
tracing_subscriber::registry()
147+
.with(
148+
tracing_subscriber::EnvFilter::try_from_default_env()
149+
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
150+
)
151+
.with(tracing_subscriber::fmt::layer())
152+
.init();
153+
154+
tracing::debug!("starting server at {}", BIND_ADDRESS);
155+
156+
let app_state = Data::new(AppState::new());
157+
158+
HttpServer::new(move || {
159+
App::new()
160+
.wrap(Logger::default())
161+
.app_data(app_state.clone())
162+
.service(sse_handler)
163+
.service(post_event_handler)
164+
})
165+
.bind(BIND_ADDRESS)?
166+
.run()
167+
.await
168+
}

0 commit comments

Comments
 (0)