Skip to content

Commit c0bd94d

Browse files
authored
Add an axum sse example (#12)
* add axum sse example * fix: make the sse axum folder name correct * feat: add common modules for Axum SSE example This commit adds supporting modules for the Axum Server-Sent Events (SSE) example: - Added `counter.rs` with a simple counter router implementation - Added `jsonrpc_frame_codec.rs` for decoding JSON-RPC frames - Created a `mod.rs` to expose these modules - Removed the previous SSE example configuration from Cargo.toml * move examples to root * nit: remove extra newline here
1 parent 6609aa6 commit c0bd94d

File tree

7 files changed

+381
-187
lines changed

7 files changed

+381
-187
lines changed

crates/mcp-server/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ where
142142

143143
tracing::info!("Server started");
144144
while let Some(msg_result) = transport.next().await {
145-
let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered();
145+
let _span = tracing::span!(tracing::Level::INFO, "message_processing");
146+
let _enter = _span.enter();
146147
match msg_result {
147148
Ok(msg) => {
148149
match msg {

examples/servers/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
1717
tracing-appender = "0.2"
1818
futures = "0.3"
1919

20+
[dev-dependencies]
21+
axum = { version = "0.8", features = ["macros"] }
22+
tokio-util = { version = "0.7", features = ["io", "codec"]}
23+
rand = { version = "0.8" }
24+
2025
[[example]]
2126
name = "counter-server"
2227
path = "src/counter_server.rs"
28+
29+
[[example]]
30+
name = "axum"
31+
path = "src/axum.rs"

examples/servers/src/axum.rs

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
use axum::{
2+
body::Body,
3+
extract::{Query, State},
4+
http::StatusCode,
5+
response::sse::{Event, Sse},
6+
routing::get,
7+
Router,
8+
};
9+
use futures::{stream::Stream, StreamExt, TryStreamExt};
10+
use mcp_server::{ByteTransport, Server};
11+
use std::collections::HashMap;
12+
use tokio_util::codec::FramedRead;
13+
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
14+
15+
use anyhow::Result;
16+
use mcp_server::router::RouterService;
17+
use std::sync::Arc;
18+
use tokio::{
19+
io::{self, AsyncWriteExt},
20+
sync::Mutex,
21+
};
22+
use tracing_subscriber::{self};
23+
mod common;
24+
use common::counter;
25+
26+
type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
27+
type SessionId = Arc<str>;
28+
29+
const BIND_ADDRESS: &str = "127.0.0.1:8000";
30+
31+
#[derive(Clone, Default)]
32+
pub struct App {
33+
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
34+
}
35+
36+
impl App {
37+
pub fn new() -> Self {
38+
Self {
39+
txs: Default::default(),
40+
}
41+
}
42+
pub fn router(&self) -> Router {
43+
Router::new()
44+
.route("/sse", get(sse_handler).post(post_event_handler))
45+
.with_state(self.clone())
46+
}
47+
}
48+
49+
fn session_id() -> SessionId {
50+
let id = format!("{:016x}", rand::random::<u128>());
51+
Arc::from(id)
52+
}
53+
54+
#[derive(Debug, serde::Deserialize)]
55+
#[serde(rename_all = "camelCase")]
56+
pub struct PostEventQuery {
57+
pub session_id: String,
58+
}
59+
60+
async fn post_event_handler(
61+
State(app): State<App>,
62+
Query(PostEventQuery { session_id }): Query<PostEventQuery>,
63+
body: Body,
64+
) -> Result<StatusCode, StatusCode> {
65+
const BODY_BYTES_LIMIT: usize = 1 << 22;
66+
let write_stream = {
67+
let rg = app.txs.read().await;
68+
rg.get(session_id.as_str())
69+
.ok_or(StatusCode::NOT_FOUND)?
70+
.clone()
71+
};
72+
let mut write_stream = write_stream.lock().await;
73+
let mut body = body.into_data_stream();
74+
if let (_, Some(size)) = body.size_hint() {
75+
if size > BODY_BYTES_LIMIT {
76+
return Err(StatusCode::PAYLOAD_TOO_LARGE);
77+
}
78+
}
79+
// calculate the body size
80+
let mut size = 0;
81+
while let Some(chunk) = body.next().await {
82+
let Ok(chunk) = chunk else {
83+
return Err(StatusCode::BAD_REQUEST);
84+
};
85+
size += chunk.len();
86+
if size > BODY_BYTES_LIMIT {
87+
return Err(StatusCode::PAYLOAD_TOO_LARGE);
88+
}
89+
write_stream
90+
.write_all(&chunk)
91+
.await
92+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
93+
}
94+
write_stream
95+
.write_u8(b'\n')
96+
.await
97+
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
98+
Ok(StatusCode::ACCEPTED)
99+
}
100+
101+
async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
102+
// it's 4KB
103+
const BUFFER_SIZE: usize = 1 << 12;
104+
let session = session_id();
105+
tracing::info!(%session, "sse connection");
106+
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
107+
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
108+
app.txs
109+
.write()
110+
.await
111+
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
112+
{
113+
let session = session.clone();
114+
tokio::spawn(async move {
115+
let router = RouterService(counter::CounterRouter::new());
116+
let server = Server::new(router);
117+
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
118+
let _result = server
119+
.run(bytes_transport)
120+
.await
121+
.inspect_err(|e| tracing::error!(?e, "server run error"));
122+
app.txs.write().await.remove(&session);
123+
});
124+
}
125+
126+
let stream = futures::stream::once(futures::future::ok(
127+
Event::default()
128+
.event("endpoint")
129+
.data(format!("?sessionId={session}")),
130+
))
131+
.chain(
132+
FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec)
133+
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
134+
.and_then(move |bytes| match std::str::from_utf8(&bytes) {
135+
Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
136+
Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
137+
}),
138+
);
139+
Sse::new(stream)
140+
}
141+
142+
#[tokio::main]
143+
async fn main() -> io::Result<()> {
144+
tracing_subscriber::registry()
145+
.with(
146+
tracing_subscriber::EnvFilter::try_from_default_env()
147+
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
148+
)
149+
.with(tracing_subscriber::fmt::layer())
150+
.init();
151+
let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;
152+
153+
tracing::debug!("listening on {}", listener.local_addr()?);
154+
axum::serve(listener, App::new().router()).await
155+
}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
use std::{future::Future, pin::Pin, sync::Arc};
2+
3+
use mcp_core::{
4+
handler::{PromptError, ResourceError},
5+
prompt::{Prompt, PromptArgument},
6+
protocol::ServerCapabilities,
7+
Content, Resource, Tool, ToolError,
8+
};
9+
use mcp_server::router::CapabilitiesBuilder;
10+
use serde_json::Value;
11+
use tokio::sync::Mutex;
12+
13+
#[derive(Clone)]
14+
pub struct CounterRouter {
15+
counter: Arc<Mutex<i32>>,
16+
}
17+
18+
impl CounterRouter {
19+
pub fn new() -> Self {
20+
Self {
21+
counter: Arc::new(Mutex::new(0)),
22+
}
23+
}
24+
25+
async fn increment(&self) -> Result<i32, ToolError> {
26+
let mut counter = self.counter.lock().await;
27+
*counter += 1;
28+
Ok(*counter)
29+
}
30+
31+
async fn decrement(&self) -> Result<i32, ToolError> {
32+
let mut counter = self.counter.lock().await;
33+
*counter -= 1;
34+
Ok(*counter)
35+
}
36+
37+
async fn get_value(&self) -> Result<i32, ToolError> {
38+
let counter = self.counter.lock().await;
39+
Ok(*counter)
40+
}
41+
42+
fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
43+
Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap()
44+
}
45+
}
46+
47+
impl mcp_server::Router for CounterRouter {
48+
fn name(&self) -> String {
49+
"counter".to_string()
50+
}
51+
52+
fn instructions(&self) -> String {
53+
"This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()
54+
}
55+
56+
fn capabilities(&self) -> ServerCapabilities {
57+
CapabilitiesBuilder::new()
58+
.with_tools(false)
59+
.with_resources(false, false)
60+
.with_prompts(false)
61+
.build()
62+
}
63+
64+
fn list_tools(&self) -> Vec<Tool> {
65+
vec![
66+
Tool::new(
67+
"increment".to_string(),
68+
"Increment the counter by 1".to_string(),
69+
serde_json::json!({
70+
"type": "object",
71+
"properties": {},
72+
"required": []
73+
}),
74+
),
75+
Tool::new(
76+
"decrement".to_string(),
77+
"Decrement the counter by 1".to_string(),
78+
serde_json::json!({
79+
"type": "object",
80+
"properties": {},
81+
"required": []
82+
}),
83+
),
84+
Tool::new(
85+
"get_value".to_string(),
86+
"Get the current counter value".to_string(),
87+
serde_json::json!({
88+
"type": "object",
89+
"properties": {},
90+
"required": []
91+
}),
92+
),
93+
]
94+
}
95+
96+
fn call_tool(
97+
&self,
98+
tool_name: &str,
99+
_arguments: Value,
100+
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
101+
let this = self.clone();
102+
let tool_name = tool_name.to_string();
103+
104+
Box::pin(async move {
105+
match tool_name.as_str() {
106+
"increment" => {
107+
let value = this.increment().await?;
108+
Ok(vec![Content::text(value.to_string())])
109+
}
110+
"decrement" => {
111+
let value = this.decrement().await?;
112+
Ok(vec![Content::text(value.to_string())])
113+
}
114+
"get_value" => {
115+
let value = this.get_value().await?;
116+
Ok(vec![Content::text(value.to_string())])
117+
}
118+
_ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))),
119+
}
120+
})
121+
}
122+
123+
fn list_resources(&self) -> Vec<Resource> {
124+
vec![
125+
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
126+
self._create_resource_text("memo://insights", "memo-name"),
127+
]
128+
}
129+
130+
fn read_resource(
131+
&self,
132+
uri: &str,
133+
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
134+
let uri = uri.to_string();
135+
Box::pin(async move {
136+
match uri.as_str() {
137+
"str:////Users/to/some/path/" => {
138+
let cwd = "/Users/to/some/path/";
139+
Ok(cwd.to_string())
140+
}
141+
"memo://insights" => {
142+
let memo =
143+
"Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
144+
Ok(memo.to_string())
145+
}
146+
_ => Err(ResourceError::NotFound(format!(
147+
"Resource {} not found",
148+
uri
149+
))),
150+
}
151+
})
152+
}
153+
154+
fn list_prompts(&self) -> Vec<Prompt> {
155+
vec![Prompt::new(
156+
"example_prompt",
157+
Some("This is an example prompt that takes one required agrument, message"),
158+
Some(vec![PromptArgument {
159+
name: "message".to_string(),
160+
description: Some("A message to put in the prompt".to_string()),
161+
required: Some(true),
162+
}]),
163+
)]
164+
}
165+
166+
fn get_prompt(
167+
&self,
168+
prompt_name: &str,
169+
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
170+
let prompt_name = prompt_name.to_string();
171+
Box::pin(async move {
172+
match prompt_name.as_str() {
173+
"example_prompt" => {
174+
let prompt = "This is an example prompt with your message here: '{message}'";
175+
Ok(prompt.to_string())
176+
}
177+
_ => Err(PromptError::NotFound(format!(
178+
"Prompt {} not found",
179+
prompt_name
180+
))),
181+
}
182+
})
183+
}
184+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use tokio_util::codec::Decoder;
2+
3+
#[derive(Default)]
4+
pub struct JsonRpcFrameCodec;
5+
impl Decoder for JsonRpcFrameCodec {
6+
type Item = tokio_util::bytes::Bytes;
7+
type Error = tokio::io::Error;
8+
fn decode(
9+
&mut self,
10+
src: &mut tokio_util::bytes::BytesMut,
11+
) -> Result<Option<Self::Item>, Self::Error> {
12+
if let Some(end) = src
13+
.iter()
14+
.enumerate()
15+
.find_map(|(idx, &b)| (b == b'\n').then_some(idx))
16+
{
17+
let line = src.split_to(end);
18+
let _char_next_line = src.split_to(1);
19+
Ok(Some(line.freeze()))
20+
} else {
21+
Ok(None)
22+
}
23+
}
24+
}

examples/servers/src/common/mod.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
pub mod counter;
2+
pub mod jsonrpc_frame_codec;

0 commit comments

Comments
 (0)