diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs index 97594523..7588617d 100644 --- a/crates/mcp-server/src/lib.rs +++ b/crates/mcp-server/src/lib.rs @@ -142,7 +142,8 @@ where tracing::info!("Server started"); while let Some(msg_result) = transport.next().await { - let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered(); + let _span = tracing::span!(tracing::Level::INFO, "message_processing"); + let _enter = _span.enter(); match msg_result { Ok(msg) => { match msg { diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index 5a64c19c..c8d47783 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-appender = "0.2" futures = "0.3" +[dev-dependencies] +axum = { version = "0.8", features = ["macros"] } +tokio-util = { version = "0.7", features = ["io", "codec"]} +rand = { version = "0.8" } + [[example]] name = "counter-server" path = "src/counter_server.rs" + +[[example]] +name = "axum" +path = "src/axum.rs" \ No newline at end of file diff --git a/examples/servers/src/axum.rs b/examples/servers/src/axum.rs new file mode 100644 index 00000000..7c3e4b93 --- /dev/null +++ b/examples/servers/src/axum.rs @@ -0,0 +1,155 @@ +use axum::{ + body::Body, + extract::{Query, State}, + http::StatusCode, + response::sse::{Event, Sse}, + routing::get, + Router, +}; +use futures::{stream::Stream, StreamExt, TryStreamExt}; +use mcp_server::{ByteTransport, Server}; +use std::collections::HashMap; +use tokio_util::codec::FramedRead; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use anyhow::Result; +use mcp_server::router::RouterService; +use std::sync::Arc; +use tokio::{ + io::{self, AsyncWriteExt}, + sync::Mutex, +}; +use tracing_subscriber::{self}; +mod common; +use common::counter; + +type C2SWriter = Arc>>; +type SessionId = Arc; + +const BIND_ADDRESS: &str = "127.0.0.1:8000"; + +#[derive(Clone, Default)] +pub struct App { + txs: Arc>>, +} + +impl App { + pub fn new() -> Self { + Self { + txs: Default::default(), + } + } + pub fn router(&self) -> Router { + Router::new() + .route("/sse", get(sse_handler).post(post_event_handler)) + .with_state(self.clone()) + } +} + +fn session_id() -> SessionId { + let id = format!("{:016x}", rand::random::()); + Arc::from(id) +} + +#[derive(Debug, serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct PostEventQuery { + pub session_id: String, +} + +async fn post_event_handler( + State(app): State, + Query(PostEventQuery { session_id }): Query, + body: Body, +) -> Result { + const BODY_BYTES_LIMIT: usize = 1 << 22; + let write_stream = { + let rg = app.txs.read().await; + rg.get(session_id.as_str()) + .ok_or(StatusCode::NOT_FOUND)? + .clone() + }; + let mut write_stream = write_stream.lock().await; + let mut body = body.into_data_stream(); + if let (_, Some(size)) = body.size_hint() { + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + } + // calculate the body size + let mut size = 0; + while let Some(chunk) = body.next().await { + let Ok(chunk) = chunk else { + return Err(StatusCode::BAD_REQUEST); + }; + size += chunk.len(); + if size > BODY_BYTES_LIMIT { + return Err(StatusCode::PAYLOAD_TOO_LARGE); + } + write_stream + .write_all(&chunk) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + } + write_stream + .write_u8(b'\n') + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(StatusCode::ACCEPTED) +} + +async fn sse_handler(State(app): State) -> Sse>> { + // it's 4KB + const BUFFER_SIZE: usize = 1 << 12; + let session = session_id(); + tracing::info!(%session, "sse connection"); + let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); + let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); + app.txs + .write() + .await + .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); + { + let session = session.clone(); + tokio::spawn(async move { + let router = RouterService(counter::CounterRouter::new()); + let server = Server::new(router); + let bytes_transport = ByteTransport::new(c2s_read, s2c_write); + let _result = server + .run(bytes_transport) + .await + .inspect_err(|e| tracing::error!(?e, "server run error")); + app.txs.write().await.remove(&session); + }); + } + + let stream = futures::stream::once(futures::future::ok( + Event::default() + .event("endpoint") + .data(format!("?sessionId={session}")), + )) + .chain( + FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) + .and_then(move |bytes| match std::str::from_utf8(&bytes) { + Ok(message) => futures::future::ok(Event::default().event("message").data(message)), + Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), + }), + ); + Sse::new(stream) +} + +#[tokio::main] +async fn main() -> io::Result<()> { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; + + tracing::debug!("listening on {}", listener.local_addr()?); + axum::serve(listener, App::new().router()).await +} diff --git a/examples/servers/src/common/counter.rs b/examples/servers/src/common/counter.rs new file mode 100644 index 00000000..4936a612 --- /dev/null +++ b/examples/servers/src/common/counter.rs @@ -0,0 +1,184 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + +use mcp_core::{ + handler::{PromptError, ResourceError}, + prompt::{Prompt, PromptArgument}, + protocol::ServerCapabilities, + Content, Resource, Tool, ToolError, +}; +use mcp_server::router::CapabilitiesBuilder; +use serde_json::Value; +use tokio::sync::Mutex; + +#[derive(Clone)] +pub struct CounterRouter { + counter: Arc>, +} + +impl CounterRouter { + pub fn new() -> Self { + Self { + counter: Arc::new(Mutex::new(0)), + } + } + + async fn increment(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter += 1; + Ok(*counter) + } + + async fn decrement(&self) -> Result { + let mut counter = self.counter.lock().await; + *counter -= 1; + Ok(*counter) + } + + async fn get_value(&self) -> Result { + let counter = self.counter.lock().await; + Ok(*counter) + } + + fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { + Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() + } +} + +impl mcp_server::Router for CounterRouter { + fn name(&self) -> String { + "counter".to_string() + } + + fn instructions(&self) -> String { + "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() + } + + fn capabilities(&self) -> ServerCapabilities { + CapabilitiesBuilder::new() + .with_tools(false) + .with_resources(false, false) + .with_prompts(false) + .build() + } + + fn list_tools(&self) -> Vec { + vec![ + Tool::new( + "increment".to_string(), + "Increment the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "decrement".to_string(), + "Decrement the counter by 1".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + Tool::new( + "get_value".to_string(), + "Get the current counter value".to_string(), + serde_json::json!({ + "type": "object", + "properties": {}, + "required": [] + }), + ), + ] + } + + fn call_tool( + &self, + tool_name: &str, + _arguments: Value, + ) -> Pin, ToolError>> + Send + 'static>> { + let this = self.clone(); + let tool_name = tool_name.to_string(); + + Box::pin(async move { + match tool_name.as_str() { + "increment" => { + let value = this.increment().await?; + Ok(vec![Content::text(value.to_string())]) + } + "decrement" => { + let value = this.decrement().await?; + Ok(vec![Content::text(value.to_string())]) + } + "get_value" => { + let value = this.get_value().await?; + Ok(vec![Content::text(value.to_string())]) + } + _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), + } + }) + } + + fn list_resources(&self) -> Vec { + vec![ + self._create_resource_text("str:////Users/to/some/path/", "cwd"), + self._create_resource_text("memo://insights", "memo-name"), + ] + } + + fn read_resource( + &self, + uri: &str, + ) -> Pin> + Send + 'static>> { + let uri = uri.to_string(); + Box::pin(async move { + match uri.as_str() { + "str:////Users/to/some/path/" => { + let cwd = "/Users/to/some/path/"; + Ok(cwd.to_string()) + } + "memo://insights" => { + let memo = + "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; + Ok(memo.to_string()) + } + _ => Err(ResourceError::NotFound(format!( + "Resource {} not found", + uri + ))), + } + }) + } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } +} diff --git a/examples/servers/src/common/jsonrpc_frame_codec.rs b/examples/servers/src/common/jsonrpc_frame_codec.rs new file mode 100644 index 00000000..4368118c --- /dev/null +++ b/examples/servers/src/common/jsonrpc_frame_codec.rs @@ -0,0 +1,24 @@ +use tokio_util::codec::Decoder; + +#[derive(Default)] +pub struct JsonRpcFrameCodec; +impl Decoder for JsonRpcFrameCodec { + type Item = tokio_util::bytes::Bytes; + type Error = tokio::io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + if let Some(end) = src + .iter() + .enumerate() + .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) + { + let line = src.split_to(end); + let _char_next_line = src.split_to(1); + Ok(Some(line.freeze())) + } else { + Ok(None) + } + } +} diff --git a/examples/servers/src/common/mod.rs b/examples/servers/src/common/mod.rs new file mode 100644 index 00000000..2a456e37 --- /dev/null +++ b/examples/servers/src/common/mod.rs @@ -0,0 +1,2 @@ +pub mod counter; +pub mod jsonrpc_frame_codec; diff --git a/examples/servers/src/counter_server.rs b/examples/servers/src/counter_server.rs index 907cc1b1..d9fdf614 100644 --- a/examples/servers/src/counter_server.rs +++ b/examples/servers/src/counter_server.rs @@ -1,192 +1,11 @@ use anyhow::Result; -use mcp_core::content::Content; -use mcp_core::handler::{PromptError, ResourceError}; -use mcp_core::prompt::{Prompt, PromptArgument}; -use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; -use mcp_server::router::{CapabilitiesBuilder, RouterService}; -use mcp_server::{ByteTransport, Router, Server}; -use serde_json::Value; -use std::{future::Future, pin::Pin, sync::Arc}; -use tokio::{ - io::{stdin, stdout}, - sync::Mutex, -}; +use mcp_server::router::RouterService; +use mcp_server::{ByteTransport, Server}; +use tokio::io::{stdin, stdout}; use tracing_appender::rolling::{RollingFileAppender, Rotation}; use tracing_subscriber::{self, EnvFilter}; -// A simple counter service that demonstrates the Router trait -#[derive(Clone)] -struct CounterRouter { - counter: Arc>, -} - -impl CounterRouter { - fn new() -> Self { - Self { - counter: Arc::new(Mutex::new(0)), - } - } - - async fn increment(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter += 1; - Ok(*counter) - } - - async fn decrement(&self) -> Result { - let mut counter = self.counter.lock().await; - *counter -= 1; - Ok(*counter) - } - - async fn get_value(&self) -> Result { - let counter = self.counter.lock().await; - Ok(*counter) - } - - fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { - Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() - } -} - -impl Router for CounterRouter { - fn name(&self) -> String { - "counter".to_string() - } - - fn instructions(&self) -> String { - "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() - } - - fn capabilities(&self) -> ServerCapabilities { - CapabilitiesBuilder::new() - .with_tools(false) - .with_resources(false, false) - .with_prompts(false) - .build() - } - - fn list_tools(&self) -> Vec { - vec![ - Tool::new( - "increment".to_string(), - "Increment the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "decrement".to_string(), - "Decrement the counter by 1".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - Tool::new( - "get_value".to_string(), - "Get the current counter value".to_string(), - serde_json::json!({ - "type": "object", - "properties": {}, - "required": [] - }), - ), - ] - } - - fn call_tool( - &self, - tool_name: &str, - _arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>> { - let this = self.clone(); - let tool_name = tool_name.to_string(); - - Box::pin(async move { - match tool_name.as_str() { - "increment" => { - let value = this.increment().await?; - Ok(vec![Content::text(value.to_string())]) - } - "decrement" => { - let value = this.decrement().await?; - Ok(vec![Content::text(value.to_string())]) - } - "get_value" => { - let value = this.get_value().await?; - Ok(vec![Content::text(value.to_string())]) - } - _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), - } - }) - } - - fn list_resources(&self) -> Vec { - vec![ - self._create_resource_text("str:////Users/to/some/path/", "cwd"), - self._create_resource_text("memo://insights", "memo-name"), - ] - } - - fn read_resource( - &self, - uri: &str, - ) -> Pin> + Send + 'static>> { - let uri = uri.to_string(); - Box::pin(async move { - match uri.as_str() { - "str:////Users/to/some/path/" => { - let cwd = "/Users/to/some/path/"; - Ok(cwd.to_string()) - } - "memo://insights" => { - let memo = - "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; - Ok(memo.to_string()) - } - _ => Err(ResourceError::NotFound(format!( - "Resource {} not found", - uri - ))), - } - }) - } - - fn list_prompts(&self) -> Vec { - vec![Prompt::new( - "example_prompt", - Some("This is an example prompt that takes one required agrument, message"), - Some(vec![PromptArgument { - name: "message".to_string(), - description: Some("A message to put in the prompt".to_string()), - required: Some(true), - }]), - )] - } - - fn get_prompt( - &self, - prompt_name: &str, - ) -> Pin> + Send + 'static>> { - let prompt_name = prompt_name.to_string(); - Box::pin(async move { - match prompt_name.as_str() { - "example_prompt" => { - let prompt = "This is an example prompt with your message here: '{message}'"; - Ok(prompt.to_string()) - } - _ => Err(PromptError::NotFound(format!( - "Prompt {} not found", - prompt_name - ))), - } - }) - } -} +mod common; #[tokio::main] async fn main() -> Result<()> { @@ -206,7 +25,7 @@ async fn main() -> Result<()> { tracing::info!("Starting MCP server"); // Create an instance of our counter router - let router = RouterService(CounterRouter::new()); + let router = RouterService(common::counter::CounterRouter::new()); // Create and run the server let server = Server::new(router);