Skip to content

Commit baa8a49

Browse files
committed
add an actix-web server sse example
remove unused import
1 parent 4597d1f commit baa8a49

File tree

2 files changed

+177
-1
lines changed

2 files changed

+177
-1
lines changed

examples/servers/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ rand = { version = "0.8" }
2525
[target.'cfg(not(target_arch = "wasm32"))'.dev-dependencies]
2626
axum = { version = "0.8", features = ["macros"] }
2727
tokio = { version = "1", features = ["full"] }
28+
actix-web = "4"
2829

2930
[target.'cfg(target_arch = "wasm32")'.dev-dependencies]
3031
tokio = { version = "1", features = ["io-util", "rt", "time", "macros"] }
@@ -40,4 +41,8 @@ path = "src/axum.rs"
4041

4142
[[example]]
4243
name = "wasi_std_io"
43-
path = "src/wasi_std_io.rs"
44+
path = "src/wasi_std_io.rs"
45+
46+
[[example]]
47+
name = "actix_web"
48+
path = "src/actix_web.rs"

examples/servers/src/actix_web.rs

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
use actix_web::web::{Bytes, Data, Payload, Query};
2+
use actix_web::{
3+
get, post, App, Error, HttpResponse, HttpServer, Result,
4+
};
5+
use futures::{StreamExt, TryStreamExt};
6+
use mcp_server::{ByteTransport, Server};
7+
use std::collections::HashMap;
8+
use tokio_util::codec::FramedRead;
9+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
10+
11+
use actix_web::middleware::Logger;
12+
use mcp_server::router::RouterService;
13+
use std::sync::Arc;
14+
use tokio::{
15+
io::{self, AsyncWriteExt},
16+
sync::Mutex,
17+
};
18+
use tracing_subscriber;
19+
mod common;
20+
use common::counter;
21+
22+
type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
23+
type SessionId = Arc<str>;
24+
25+
const BIND_ADDRESS: &str = "127.0.0.1:8000";
26+
27+
#[derive(Clone, Default)]
28+
pub struct AppState {
29+
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
30+
}
31+
32+
impl AppState {
33+
pub fn new() -> Self {
34+
Self {
35+
txs: Default::default(),
36+
}
37+
}
38+
}
39+
40+
fn session_id() -> SessionId {
41+
let id = format!("{:016x}", rand::random::<u128>());
42+
Arc::from(id)
43+
}
44+
45+
#[derive(Debug, serde::Deserialize)]
46+
#[serde(rename_all = "camelCase")]
47+
pub struct PostEventQuery {
48+
pub session_id: String,
49+
}
50+
51+
#[post("/sse")]
52+
async fn post_event_handler(
53+
app_state: Data<AppState>,
54+
query: Query<PostEventQuery>,
55+
mut payload: Payload,
56+
) -> Result<HttpResponse, actix_web::Error> {
57+
const BODY_BYTES_LIMIT: usize = 1 << 22;
58+
let session_id = &query.session_id;
59+
60+
let write_stream = {
61+
let rg = app_state.txs.read().await;
62+
match rg.get(session_id.as_str()) {
63+
Some(stream) => stream.clone(),
64+
None => return Ok(HttpResponse::NotFound().finish()),
65+
}
66+
};
67+
68+
let mut write_stream = write_stream.lock().await;
69+
let mut size = 0;
70+
71+
// Process the request body in chunks
72+
while let Some(chunk) = payload.next().await {
73+
let chunk = chunk?;
74+
size += chunk.len();
75+
if size > BODY_BYTES_LIMIT {
76+
return Ok(HttpResponse::PayloadTooLarge().finish());
77+
}
78+
79+
if let Err(_) = write_stream.write_all(&chunk).await {
80+
return Ok(HttpResponse::InternalServerError().finish());
81+
}
82+
}
83+
84+
if let Err(_) = write_stream.write_u8(b'\n').await {
85+
return Ok(HttpResponse::InternalServerError().finish());
86+
}
87+
88+
Ok(HttpResponse::Accepted().finish())
89+
}
90+
91+
#[get("/sse")]
92+
async fn sse_handler(app_state: Data<AppState>) -> Result<HttpResponse, Error> {
93+
// it's 4KB
94+
const BUFFER_SIZE: usize = 1 << 12;
95+
let session = session_id();
96+
tracing::info!(%session, "sse connection");
97+
98+
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
99+
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
100+
101+
app_state
102+
.txs
103+
.write()
104+
.await
105+
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
106+
107+
{
108+
let session = session.clone();
109+
let app_state = app_state.clone();
110+
tokio::spawn(async move {
111+
let router = RouterService(counter::CounterRouter::new());
112+
let server = Server::new(router);
113+
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
114+
let _result = server
115+
.run(bytes_transport)
116+
.await
117+
.inspect_err(|e| tracing::error!(?e, "server run error"));
118+
tracing::info!(%session, "connection closed, removing session");
119+
app_state.txs.write().await.remove(&session);
120+
});
121+
}
122+
123+
// Create SSE stream with correct types
124+
let stream = futures::stream::once(futures::future::ready(Ok::<_, io::Error>(Bytes::from(
125+
format!("event: endpoint\ndata: ?sessionId={}\n\n", session),
126+
))))
127+
.chain(
128+
FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec)
129+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
130+
.map_ok(move |bytes| {
131+
let message = match std::str::from_utf8(&bytes) {
132+
Ok(message) => format!("event: message\ndata: {}\n\n", message),
133+
Err(_) => format!("event: error\ndata: Invalid UTF-8 data\n\n"),
134+
};
135+
Bytes::from(message)
136+
}),
137+
);
138+
139+
// Return SSE response
140+
Ok(HttpResponse::Ok()
141+
.append_header(("Content-Type", "text/event-stream"))
142+
.append_header(("Cache-Control", "no-cache"))
143+
.append_header(("Connection", "keep-alive"))
144+
.streaming(stream))
145+
}
146+
147+
#[actix_web::main]
148+
async fn main() -> io::Result<()> {
149+
tracing_subscriber::registry()
150+
.with(
151+
tracing_subscriber::EnvFilter::try_from_default_env()
152+
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
153+
)
154+
.with(tracing_subscriber::fmt::layer())
155+
.init();
156+
157+
tracing::debug!("starting server at {}", BIND_ADDRESS);
158+
159+
let app_state = Data::new(AppState::new());
160+
161+
HttpServer::new(move || {
162+
App::new()
163+
.wrap(Logger::default())
164+
.app_data(app_state.clone())
165+
.service(sse_handler)
166+
.service(post_event_handler)
167+
})
168+
.bind(BIND_ADDRESS)?
169+
.run()
170+
.await
171+
}

0 commit comments

Comments
 (0)