From e82eed4f12c7cc716584c1fac16e4d16749ffd31 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Sun, 2 Mar 2025 00:17:08 +0800 Subject: [PATCH 1/6] add axum sse example --- crates/mcp-server/Cargo.toml | 10 + .../mcp-server/examples/sse-axim/counter.rs | 180 ++++++++++++++++++ crates/mcp-server/examples/sse-axim/main.rs | 177 +++++++++++++++++ crates/mcp-server/src/lib.rs | 3 +- 4 files changed, 369 insertions(+), 1 deletion(-) create mode 100644 crates/mcp-server/examples/sse-axim/counter.rs create mode 100644 crates/mcp-server/examples/sse-axim/main.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index 6baa7fe8..075d2c19 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -23,3 +23,13 @@ tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" async-trait = "0.1" + + +[dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio-util = { version = "0.7", features = ["io", "codec"]} +rand = { version = "0.8" } + +[[example]] +name = "example-sse-axum" +path = "examples/sse-axum/main.rs" diff --git a/crates/mcp-server/examples/sse-axim/counter.rs b/crates/mcp-server/examples/sse-axim/counter.rs new file mode 100644 index 00000000..557f3a0b --- /dev/null +++ b/crates/mcp-server/examples/sse-axim/counter.rs @@ -0,0 +1,180 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + +use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_server::router::CapabilitiesBuilder; +use serde_json::Value; +use tokio::sync::Mutex; + +#[derive(Clone)] +pub struct CounterRouter { + counter: Arc>, +} + +impl CounterRouter { + pub fn new() -> Self { + Self { + counter: Arc::new(Mutex::new(0)), + } + } + + async fn increment(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter += 1; + Ok(*counter) + } + + async fn decrement(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter -= 1; + Ok(*counter) + } + + async fn get_value(&self) -> Result { + let counter = self.counter.lock().await; + Ok(*counter) + } + + fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { + Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() + } +} + +impl mcp_server::Router for CounterRouter { + fn name(&self) -> String { + "counter".to_string() + } + + fn instructions(&self) -> String { + "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() + } + + fn capabilities(&self) -> ServerCapabilities { + CapabilitiesBuilder::new() + .with_tools(false) + .with_resources(false, false) + .with_prompts(false) + .build() + } + + fn list_tools(&self) -> Vec { + vec![ + Tool::new( + "increment".to_string(), + "Increment the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "decrement".to_string(), + "Decrement the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "get_value".to_string(), + "Get the current counter value".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + ] + } + + fn call_tool( + &self, + tool_name: &str, + _arguments: Value, + ) -> Pin, ToolError>> + Send + 'static>> { + let this = self.clone(); + let tool_name = tool_name.to_string(); + + Box::pin(async move { + match tool_name.as_str() { + "increment" => { + let value = this.increment().await?; + Ok(vec![Content::text(value.to_string())]) + } + "decrement" => { + let value = this.decrement().await?; + Ok(vec![Content::text(value.to_string())]) + } + "get_value" => { + let value = this.get_value().await?; + Ok(vec![Content::text(value.to_string())]) + } + _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), + } + }) + } + + fn list_resources(&self) -> Vec { + vec![ + self._create_resource_text("str:////Users/to/some/path/", "cwd"), + self._create_resource_text("memo://insights", "memo-name"), + ] + } + + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>> { + let uri = uri.to_string(); + Box::pin(async move { + match uri.as_str() { + "str:////Users/to/some/path/" => { + let cwd = "/Users/to/some/path/"; + Ok(cwd.to_string()) + } + "memo://insights" => { + let memo = + "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; + Ok(memo.to_string()) + } + _ => Err(ResourceError::NotFound(format!( + "Resource {} not found", + uri + ))), + } + }) + } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } +} + diff --git a/crates/mcp-server/examples/sse-axim/main.rs b/crates/mcp-server/examples/sse-axim/main.rs new file mode 100644 index 00000000..8106618d --- /dev/null +++ b/crates/mcp-server/examples/sse-axim/main.rs @@ -0,0 +1,177 @@ +use axum::{ + body::Body, + extract::{Query, State}, + http::StatusCode, + response::sse::{Event, Sse}, + routing::get, + Router, +}; +use futures::{stream::Stream, StreamExt, TryStreamExt}; +use mcp_server::{ByteTransport, Server}; +use std::collections::HashMap; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use anyhow::Result; +use mcp_server::router::RouterService; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; +use tracing_subscriber::{self}; +mod counter; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + pub fn router(&self) -> Router { + Router::new() + .route("/sse", get(sse_handler).post(post_event_handler)) + .with_state(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + body: Body, +) -> Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let mut body = body.into_data_stream(); + if let (_, Some(size)) = body.size_hint() { + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + } + // calculate the body size + let mut size = 0; + while let Some(chunk) = body.next().await { + let Ok(chunk) = chunk else { + return Err(StatusCode::BAD_REQUEST); + }; + size += chunk.len(); + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + write_stream + .write_all(&chunk) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + write_stream + .write_u8(b'\n') + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler(State(app): State) -> Sse>> { + // it's 4KB + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + { + let session = session.clone(); + tokio::spawn(async move { + let router = RouterService(counter::CounterRouter::new()); + let server = Server::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app.txs.write().await.remove(&session); + }); + } + + use tokio_util::codec::{Decoder, FramedRead}; + #[derive(Default)] + pub struct LinesBytesCodec; + impl Decoder for LinesBytesCodec { + type Item = tokio_util::bytes::Bytes; + type Error = io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + if let Some(end) = src + .iter() + .enumerate() + .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) + { + let line = src.split_to(end); + let _char_next_line = src.split_to(1); + Ok(Some(line.freeze())) + } else { + Ok(None) + } + } + } + + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("?sessionId={session}")), + )) + .chain( + FramedRead::new(s2c_read, LinesBytesCodec) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + .and_then(move |bytes| match std::str::from_utf8(&bytes) { + Ok(message) => futures::future::ok(Event::default().event("message").data(message)), + Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), + }), + ); + Sse::new(stream) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; + + tracing::debug!("listening on {}", listener.local_addr()?); + axum::serve(listener, App::new().router()).await +} diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 97594523..7588617d 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -142,7 +142,8 @@ where tracing::info!("Server started"); while let Some(msg_result) = transport.next().await { - let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered(); + let _span = tracing::span!(tracing::Level::INFO, "message_processing"); + let _enter = _span.enter(); match msg_result { Ok(msg) => { match msg { From 4663da3a12977125974a9bcaf773690fe58a6607 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 09:07:04 +0800 Subject: [PATCH 2/6] fix: make the sse axum folder name correct --- crates/mcp-server/examples/{sse-axim => sse-axum}/counter.rs | 0 crates/mcp-server/examples/{sse-axim => sse-axum}/main.rs | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename crates/mcp-server/examples/{sse-axim => sse-axum}/counter.rs (100%) rename crates/mcp-server/examples/{sse-axim => sse-axum}/main.rs (100%) diff --git a/crates/mcp-server/examples/sse-axim/counter.rs b/crates/mcp-server/examples/sse-axum/counter.rs similarity index 100% rename from crates/mcp-server/examples/sse-axim/counter.rs rename to crates/mcp-server/examples/sse-axum/counter.rs diff --git a/crates/mcp-server/examples/sse-axim/main.rs b/crates/mcp-server/examples/sse-axum/main.rs similarity index 100% rename from crates/mcp-server/examples/sse-axim/main.rs rename to crates/mcp-server/examples/sse-axum/main.rs From 830f2a6b3541ca01b421d67c1979e2945a2b77f0 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Mon, 3 Mar 2025 09:58:17 +0800 Subject: [PATCH 3/6] feat: add common modules for Axum SSE example This commit adds supporting modules for the Axum Server-Sent Events (SSE) example: - Added `counter.rs` with a simple counter router implementation - Added `jsonrpc_frame_codec.rs` for decoding JSON-RPC frames - Created a `mod.rs` to expose these modules - Removed the previous SSE example configuration from Cargo.toml --- crates/mcp-server/Cargo.toml | 6 +--- .../examples/{sse-axum/main.rs => axum.rs} | 30 +++---------------- .../examples/{sse-axum => common}/counter.rs | 3 +- .../examples/common/jsonrpc_frame_codec.rs | 24 +++++++++++++++ crates/mcp-server/examples/common/mod.rs | 2 ++ 5 files changed, 32 insertions(+), 33 deletions(-) rename crates/mcp-server/examples/{sse-axum/main.rs => axum.rs} (84%) rename crates/mcp-server/examples/{sse-axum => common}/counter.rs (99%) create mode 100644 crates/mcp-server/examples/common/jsonrpc_frame_codec.rs create mode 100644 crates/mcp-server/examples/common/mod.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index 075d2c19..f06bfeda 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -28,8 +28,4 @@ async-trait = "0.1" [dev-dependencies] axum = { version = "0.8", features = ["macros"] } tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } - -[[example]] -name = "example-sse-axum" -path = "examples/sse-axum/main.rs" +rand = { version = "0.8" } \ No newline at end of file diff --git a/crates/mcp-server/examples/sse-axum/main.rs b/crates/mcp-server/examples/axum.rs similarity index 84% rename from crates/mcp-server/examples/sse-axum/main.rs rename to crates/mcp-server/examples/axum.rs index 8106618d..7c3e4b93 100644 --- a/crates/mcp-server/examples/sse-axum/main.rs +++ b/crates/mcp-server/examples/axum.rs @@ -9,6 +9,7 @@ use axum::{ use futures::{stream::Stream, StreamExt, TryStreamExt}; use mcp_server::{ByteTransport, Server}; use std::collections::HashMap; +use tokio_util::codec::FramedRead; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use anyhow::Result; @@ -19,7 +20,8 @@ use tokio::{ sync::Mutex, }; use tracing_subscriber::{self}; -mod counter; +mod common; +use common::counter; type C2SWriter = Arc>>; type SessionId = Arc; @@ -121,37 +123,13 @@ async fn sse_handler(State(app): State) -> Sse Result, Self::Error> { - if let Some(end) = src - .iter() - .enumerate() - .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) - { - let line = src.split_to(end); - let _char_next_line = src.split_to(1); - Ok(Some(line.freeze())) - } else { - Ok(None) - } - } - } - let stream = futures::stream::once(futures::future::ok( Event::default() .event("endpoint") .data(format!("?sessionId={session}")), )) .chain( - FramedRead::new(s2c_read, LinesBytesCodec) + FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec) .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) .and_then(move |bytes| match std::str::from_utf8(&bytes) { Ok(message) => futures::future::ok(Event::default().event("message").data(message)), diff --git a/crates/mcp-server/examples/sse-axum/counter.rs b/crates/mcp-server/examples/common/counter.rs similarity index 99% rename from crates/mcp-server/examples/sse-axum/counter.rs rename to crates/mcp-server/examples/common/counter.rs index 557f3a0b..a1081984 100644 --- a/crates/mcp-server/examples/sse-axum/counter.rs +++ b/crates/mcp-server/examples/common/counter.rs @@ -176,5 +176,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} - +} \ No newline at end of file diff --git a/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs b/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs new file mode 100644 index 00000000..4368118c --- /dev/null +++ b/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs @@ -0,0 +1,24 @@ +use tokio_util::codec::Decoder; + +#[derive(Default)] +pub struct JsonRpcFrameCodec; +impl Decoder for JsonRpcFrameCodec { + type Item = tokio_util::bytes::Bytes; + type Error = tokio::io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + if let Some(end) = src + .iter() + .enumerate() + .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) + { + let line = src.split_to(end); + let _char_next_line = src.split_to(1); + Ok(Some(line.freeze())) + } else { + Ok(None) + } + } +} diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs new file mode 100644 index 00000000..02b1a761 --- /dev/null +++ b/crates/mcp-server/examples/common/mod.rs @@ -0,0 +1,2 @@ +pub mod counter; +pub mod jsonrpc_frame_codec; \ No newline at end of file From f541447266c94e51c03e503ede26b3f737169bcb Mon Sep 17 00:00:00 2001 From: RWDai <27391645+RWDai@users.noreply.github.com> Date: Mon, 3 Mar 2025 11:02:17 +0800 Subject: [PATCH 4/6] doc: add Poem SSE example with MCP server integration --- crates/mcp-server/Cargo.toml | 5 +- crates/mcp-server/examples/common/counter.rs | 9 +- crates/mcp-server/examples/common/mod.rs | 2 +- crates/mcp-server/examples/poem.rs | 156 +++++++++++++++++++ 4 files changed, 167 insertions(+), 5 deletions(-) create mode 100644 crates/mcp-server/examples/poem.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index f06bfeda..cc7c7214 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -26,6 +26,7 @@ async-trait = "0.1" [dev-dependencies] +poem = { version = "3.1.7", features = ["sse"] } axum = { version = "0.8", features = ["macros"] } -tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } \ No newline at end of file +tokio-util = { version = "0.7", features = ["io", "codec"] } +rand = { version = "0.8" } diff --git a/crates/mcp-server/examples/common/counter.rs b/crates/mcp-server/examples/common/counter.rs index a1081984..4936a612 100644 --- a/crates/mcp-server/examples/common/counter.rs +++ b/crates/mcp-server/examples/common/counter.rs @@ -1,6 +1,11 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_core::{ + handler::{PromptError, ResourceError}, + prompt::{Prompt, PromptArgument}, + protocol::ServerCapabilities, + Content, Resource, Tool, ToolError, +}; use mcp_server::router::CapabilitiesBuilder; use serde_json::Value; use tokio::sync::Mutex; @@ -176,4 +181,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} \ No newline at end of file +} diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs index 02b1a761..2a456e37 100644 --- a/crates/mcp-server/examples/common/mod.rs +++ b/crates/mcp-server/examples/common/mod.rs @@ -1,2 +1,2 @@ pub mod counter; -pub mod jsonrpc_frame_codec; \ No newline at end of file +pub mod jsonrpc_frame_codec; diff --git a/crates/mcp-server/examples/poem.rs b/crates/mcp-server/examples/poem.rs new file mode 100644 index 00000000..d0ae6845 --- /dev/null +++ b/crates/mcp-server/examples/poem.rs @@ -0,0 +1,156 @@ +use futures::StreamExt; +use mcp_server::{ByteTransport, Server as McpServer}; +use poem::{ + handler, + http::StatusCode, + listener::TcpListener, + web::{ + sse::{Event, SSE}, + Data, Query, + }, + Body, EndpointExt, Error, IntoResponse, Route, Server, +}; +use std::collections::HashMap; +use tokio_util::codec::FramedRead; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use mcp_server::router::RouterService; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; +use tracing_subscriber::{self}; + +mod common; +use common::counter; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + + pub fn route(&self) -> impl poem::Endpoint { + Route::new() + .at("/sse", poem::get(sse_handler).post(post_event_handler)) + .data(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +#[handler] +async fn post_event_handler( + app: Data<&App>, + Query(query): Query, + body: Body, +) -> poem::Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(query.session_id.as_str()) + .ok_or_else(|| Error::from_string("Session not found", StatusCode::NOT_FOUND))? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let bytes = body.into_bytes().await?; + if bytes.len() > BODY_BYTES_LIMIT { + return Err(Error::from_string( + "Payload too large", + StatusCode::PAYLOAD_TOO_LARGE, + )); + } + write_stream + .write_all(&bytes) + .await + .map_err(|e| Error::from_string(e.to_string(), StatusCode::INTERNAL_SERVER_ERROR))?; + write_stream + .write_u8(b'\n') + .await + .map_err(|e| Error::from_string(e.to_string(), StatusCode::INTERNAL_SERVER_ERROR))?; + Ok(StatusCode::ACCEPTED) +} + +#[handler] +async fn sse_handler(app: Data<&App>) -> impl IntoResponse { + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + + { + let session = session.clone(); + let app = app.clone(); + tokio::spawn(async move { + let router = RouterService(counter::CounterRouter::new()); + let server = McpServer::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app.txs.write().await.remove(&session); + }); + } + + let stream = futures::stream::once(futures::future::ready( + Event::message(format!("?sessionId={session}")).event_type("endpoint"), + )) + .chain( + FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec).map(|result| { + match result { + Ok(bytes) => match std::str::from_utf8(&bytes) { + Ok(message) => Event::message(message), + Err(e) => Event::message(format!("Error: {}", e)), + }, + Err(e) => Event::message(format!("Error: {}", e)), + } + }), + ); + + SSE::new(stream) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let app = App::new(); + let listener = TcpListener::bind(BIND_ADDRESS); + + tracing::debug!("listening on {}", BIND_ADDRESS); + Server::new(listener).run(app.route()).await?; + Ok(()) +} From 957f183b3fa37ffce72f91fc47684484f8a51934 Mon Sep 17 00:00:00 2001 From: 4t145 Date: Tue, 4 Mar 2025 09:42:38 +0800 Subject: [PATCH 5/6] move examples to root --- crates/mcp-server/Cargo.toml | 5 - crates/mcp-server/examples/common/mod.rs | 2 - examples/servers/Cargo.toml | 9 + .../examples => examples/servers/src}/axum.rs | 0 .../servers/src}/common/counter.rs | 9 +- .../src}/common/jsonrpc_frame_codec.rs | 0 examples/servers/src/common/mod.rs | 2 + examples/servers/src/counter_server.rs | 191 +----------------- 8 files changed, 23 insertions(+), 195 deletions(-) delete mode 100644 crates/mcp-server/examples/common/mod.rs rename {crates/mcp-server/examples => examples/servers/src}/axum.rs (100%) rename {crates/mcp-server/examples => examples/servers/src}/common/counter.rs (96%) rename {crates/mcp-server/examples => examples/servers/src}/common/jsonrpc_frame_codec.rs (100%) create mode 100644 examples/servers/src/common/mod.rs diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml index f06bfeda..048f18ff 100644 --- a/crates/mcp-server/Cargo.toml +++ b/crates/mcp-server/Cargo.toml @@ -24,8 +24,3 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" async-trait = "0.1" - -[dev-dependencies] -axum = { version = "0.8", features = ["macros"] } -tokio-util = { version = "0.7", features = ["io", "codec"]} -rand = { version = "0.8" } \ No newline at end of file diff --git a/crates/mcp-server/examples/common/mod.rs b/crates/mcp-server/examples/common/mod.rs deleted file mode 100644 index 02b1a761..00000000 --- a/crates/mcp-server/examples/common/mod.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub mod counter; -pub mod jsonrpc_frame_codec; \ No newline at end of file diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 5a64c19c..c8d47783 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" futures = "0.3" +[dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio-util = { version = "0.7", features = ["io", "codec"]} +rand = { version = "0.8" } + [[example]] name = "counter-server" path = "src/counter_server.rs" + +[[example]] +name = "axum" +path = "src/axum.rs" \ No newline at end of file diff --git a/crates/mcp-server/examples/axum.rs b/examples/servers/src/axum.rs similarity index 100% rename from crates/mcp-server/examples/axum.rs rename to examples/servers/src/axum.rs diff --git a/crates/mcp-server/examples/common/counter.rs b/examples/servers/src/common/counter.rs similarity index 96% rename from crates/mcp-server/examples/common/counter.rs rename to examples/servers/src/common/counter.rs index a1081984..4936a612 100644 --- a/crates/mcp-server/examples/common/counter.rs +++ b/examples/servers/src/common/counter.rs @@ -1,6 +1,11 @@ use std::{future::Future, pin::Pin, sync::Arc}; -use mcp_core::{handler::{PromptError, ResourceError}, prompt::{Prompt, PromptArgument}, protocol::ServerCapabilities, Content, Resource, Tool, ToolError}; +use mcp_core::{ + handler::{PromptError, ResourceError}, + prompt::{Prompt, PromptArgument}, + protocol::ServerCapabilities, + Content, Resource, Tool, ToolError, +}; use mcp_server::router::CapabilitiesBuilder; use serde_json::Value; use tokio::sync::Mutex; @@ -176,4 +181,4 @@ impl mcp_server::Router for CounterRouter { } }) } -} \ No newline at end of file +} diff --git a/crates/mcp-server/examples/common/jsonrpc_frame_codec.rs b/examples/servers/src/common/jsonrpc_frame_codec.rs similarity index 100% rename from crates/mcp-server/examples/common/jsonrpc_frame_codec.rs rename to examples/servers/src/common/jsonrpc_frame_codec.rs diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs new file mode 100644 index 00000000..2a456e37 --- /dev/null +++ b/examples/servers/src/common/mod.rs @@ -0,0 +1,2 @@ +pub mod counter; +pub mod jsonrpc_frame_codec; diff --git a/examples/servers/src/counter_server.rs b/examples/servers/src/counter_server.rs index 907cc1b1..d9fdf614 100644 --- a/examples/servers/src/counter_server.rs +++ b/examples/servers/src/counter_server.rs @@ -1,192 +1,11 @@ use anyhow::Result; -use mcp_core::content::Content; -use mcp_core::handler::{PromptError, ResourceError}; -use mcp_core::prompt::{Prompt, PromptArgument}; -use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; -use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use mcp_server::{ByteTransport, Router, Server}; -use serde_json::Value; -use std::{future::Future, pin::Pin, sync::Arc}; -use tokio::{ - io::{stdin, stdout}, - sync::Mutex, -}; +use mcp_server::router::RouterService; +use mcp_server::{ByteTransport, Server}; +use tokio::io::{stdin, stdout}; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; -// A simple counter service that demonstrates the Router trait -#[derive(Clone)] -struct CounterRouter { - counter: Arc>, -} - -impl CounterRouter { - fn new() -> Self { - Self { - counter: Arc::new(Mutex::new(0)), - } - } - - async fn increment(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter += 1; - Ok(*counter) - } - - async fn decrement(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter -= 1; - Ok(*counter) - } - - async fn get_value(&self) -> Result { - let counter = self.counter.lock().await; - Ok(*counter) - } - - fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { - Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() - } -} - -impl Router for CounterRouter { - fn name(&self) -> String { - "counter".to_string() - } - - fn instructions(&self) -> String { - "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new() - .with_tools(false) - .with_resources(false, false) - .with_prompts(false) - .build() - } - - fn list_tools(&self) -> Vec { - vec![ - Tool::new( - "increment".to_string(), - "Increment the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "decrement".to_string(), - "Decrement the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "get_value".to_string(), - "Get the current counter value".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - ] - } - - fn call_tool( - &self, - tool_name: &str, - _arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - - Box::pin(async move { - match tool_name.as_str() { - "increment" => { - let value = this.increment().await?; - Ok(vec![Content::text(value.to_string())]) - } - "decrement" => { - let value = this.decrement().await?; - Ok(vec![Content::text(value.to_string())]) - } - "get_value" => { - let value = this.get_value().await?; - Ok(vec![Content::text(value.to_string())]) - } - _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), - } - }) - } - - fn list_resources(&self) -> Vec { - vec![ - self._create_resource_text("str:////Users/to/some/path/", "cwd"), - self._create_resource_text("memo://insights", "memo-name"), - ] - } - - fn read_resource( - &self, - uri: &str, - ) -> Pin> + Send + 'static>> { - let uri = uri.to_string(); - Box::pin(async move { - match uri.as_str() { - "str:////Users/to/some/path/" => { - let cwd = "/Users/to/some/path/"; - Ok(cwd.to_string()) - } - "memo://insights" => { - let memo = - "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(memo.to_string()) - } - _ => Err(ResourceError::NotFound(format!( - "Resource {} not found", - uri - ))), - } - }) - } - - fn list_prompts(&self) -> Vec { - vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required agrument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - match prompt_name.as_str() { - "example_prompt" => { - let prompt = "This is an example prompt with your message here: '{message}'"; - Ok(prompt.to_string()) - } - _ => Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))), - } - }) - } -} +mod common; #[tokio::main] async fn main() -> Result<()> { @@ -206,7 +25,7 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); // Create an instance of our counter router - let router = RouterService(CounterRouter::new()); + let router = RouterService(common::counter::CounterRouter::new()); // Create and run the server let server = Server::new(router); From c725e4b4d757a85e42ebd4f241ed6c3dd53def6e Mon Sep 17 00:00:00 2001 From: RWDai <27391645+RWDai@users.noreply.github.com> Date: Tue, 4 Mar 2025 10:08:12 +0800 Subject: [PATCH 6/6] merge upstream code --- examples/servers/Cargo.toml | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index c8d47783..2e8b0073 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -19,7 +19,8 @@ futures = "0.3" [dev-dependencies] axum = { version = "0.8", features = ["macros"] } -tokio-util = { version = "0.7", features = ["io", "codec"]} +poem = { version = "3.1.7", features = ["sse"] } +tokio-util = { version = "0.7", features = ["io", "codec"] } rand = { version = "0.8" } [[example]] @@ -28,4 +29,8 @@ path = "src/counter_server.rs" [[example]] name = "axum" -path = "src/axum.rs" \ No newline at end of file +path = "src/axum.rs" + +[[example]] +name = "poem" +path = "src/poem.rs"