Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion crates/mcp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ where

tracing::info!("Server started");
while let Some(msg_result) = transport.next().await {
let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered();
let _span = tracing::span!(tracing::Level::INFO, "message_processing");
let _enter = _span.enter();
match msg_result {
Ok(msg) => {
match msg {
Expand Down
9 changes: 9 additions & 0 deletions examples/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-appender = "0.2"
futures = "0.3"

[dev-dependencies]
axum = { version = "0.8", features = ["macros"] }
tokio-util = { version = "0.7", features = ["io", "codec"]}
rand = { version = "0.8" }

[[example]]
name = "counter-server"
path = "src/counter_server.rs"

[[example]]
name = "axum"
path = "src/axum.rs"
155 changes: 155 additions & 0 deletions examples/servers/src/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use axum::{
body::Body,
extract::{Query, State},
http::StatusCode,
response::sse::{Event, Sse},
routing::get,
Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use anyhow::Result;
use mcp_server::router::RouterService;
use std::sync::Arc;
use tokio::{
io::{self, AsyncWriteExt},
sync::Mutex,
};
use tracing_subscriber::{self};
mod common;
use common::counter;

type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;

const BIND_ADDRESS: &str = "127.0.0.1:8000";

#[derive(Clone, Default)]
pub struct App {
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}

impl App {
pub fn new() -> Self {
Self {
txs: Default::default(),
}
}
pub fn router(&self) -> Router {
Router::new()
.route("/sse", get(sse_handler).post(post_event_handler))
.with_state(self.clone())
}
}

fn session_id() -> SessionId {
let id = format!("{:016x}", rand::random::<u128>());
Arc::from(id)
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
pub session_id: String,
}

async fn post_event_handler(
State(app): State<App>,
Query(PostEventQuery { session_id }): Query<PostEventQuery>,
body: Body,
) -> Result<StatusCode, StatusCode> {
const BODY_BYTES_LIMIT: usize = 1 << 22;
let write_stream = {
let rg = app.txs.read().await;
rg.get(session_id.as_str())
.ok_or(StatusCode::NOT_FOUND)?
.clone()
};
let mut write_stream = write_stream.lock().await;
let mut body = body.into_data_stream();
if let (_, Some(size)) = body.size_hint() {
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
}
// calculate the body size
let mut size = 0;
while let Some(chunk) = body.next().await {
let Ok(chunk) = chunk else {
return Err(StatusCode::BAD_REQUEST);
};
size += chunk.len();
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
write_stream
.write_all(&chunk)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
write_stream
.write_u8(b'\n')
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::ACCEPTED)
}

async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
// it's 4KB
const BUFFER_SIZE: usize = 1 << 12;
let session = session_id();
tracing::info!(%session, "sse connection");
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
app.txs
.write()
.await
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
{
let session = session.clone();
tokio::spawn(async move {
let router = RouterService(counter::CounterRouter::new());
let server = Server::new(router);
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
let _result = server
.run(bytes_transport)
.await
.inspect_err(|e| tracing::error!(?e, "server run error"));
app.txs.write().await.remove(&session);
});
}

let stream = futures::stream::once(futures::future::ok(
Event::default()
.event("endpoint")
.data(format!("?sessionId={session}")),
))
.chain(
FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
.and_then(move |bytes| match std::str::from_utf8(&bytes) {
Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
}),
);
Sse::new(stream)
}

#[tokio::main]
async fn main() -> io::Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;

tracing::debug!("listening on {}", listener.local_addr()?);
axum::serve(listener, App::new().router()).await
}
184 changes: 184 additions & 0 deletions examples/servers/src/common/counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use std::{future::Future, pin::Pin, sync::Arc};

use mcp_core::{
handler::{PromptError, ResourceError},
prompt::{Prompt, PromptArgument},
protocol::ServerCapabilities,
Content, Resource, Tool, ToolError,
};
use mcp_server::router::CapabilitiesBuilder;
use serde_json::Value;
use tokio::sync::Mutex;

#[derive(Clone)]
pub struct CounterRouter {
counter: Arc<Mutex<i32>>,
}

impl CounterRouter {
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
}
}

async fn increment(&self) -> Result<i32, ToolError> {
let mut counter = self.counter.lock().await;
*counter += 1;
Ok(*counter)
}

async fn decrement(&self) -> Result<i32, ToolError> {
let mut counter = self.counter.lock().await;
*counter -= 1;
Ok(*counter)
}

async fn get_value(&self) -> Result<i32, ToolError> {
let counter = self.counter.lock().await;
Ok(*counter)
}

fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap()
}
}

impl mcp_server::Router for CounterRouter {
fn name(&self) -> String {
"counter".to_string()
}

fn instructions(&self) -> String {
"This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.with_prompts(false)
.build()
}

fn list_tools(&self) -> Vec<Tool> {
vec![
Tool::new(
"increment".to_string(),
"Increment the counter by 1".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
Tool::new(
"decrement".to_string(),
"Decrement the counter by 1".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
Tool::new(
"get_value".to_string(),
"Get the current counter value".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
]
}

fn call_tool(
&self,
tool_name: &str,
_arguments: Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

Box::pin(async move {
match tool_name.as_str() {
"increment" => {
let value = this.increment().await?;
Ok(vec![Content::text(value.to_string())])
}
"decrement" => {
let value = this.decrement().await?;
Ok(vec![Content::text(value.to_string())])
}
"get_value" => {
let value = this.get_value().await?;
Ok(vec![Content::text(value.to_string())])
}
_ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))),
}
})
}

fn list_resources(&self) -> Vec<Resource> {
vec![
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
self._create_resource_text("memo://insights", "memo-name"),
]
}

fn read_resource(
&self,
uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
let uri = uri.to_string();
Box::pin(async move {
match uri.as_str() {
"str:////Users/to/some/path/" => {
let cwd = "/Users/to/some/path/";
Ok(cwd.to_string())
}
"memo://insights" => {
let memo =
"Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
Ok(memo.to_string())
}
_ => Err(ResourceError::NotFound(format!(
"Resource {} not found",
uri
))),
}
})
}

fn list_prompts(&self) -> Vec<Prompt> {
vec![Prompt::new(
"example_prompt",
Some("This is an example prompt that takes one required agrument, message"),
Some(vec![PromptArgument {
name: "message".to_string(),
description: Some("A message to put in the prompt".to_string()),
required: Some(true),
}]),
)]
}

fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
match prompt_name.as_str() {
"example_prompt" => {
let prompt = "This is an example prompt with your message here: '{message}'";
Ok(prompt.to_string())
}
_ => Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
))),
}
})
}
}
24 changes: 24 additions & 0 deletions examples/servers/src/common/jsonrpc_frame_codec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use tokio_util::codec::Decoder;

#[derive(Default)]
pub struct JsonRpcFrameCodec;
impl Decoder for JsonRpcFrameCodec {
type Item = tokio_util::bytes::Bytes;
type Error = tokio::io::Error;
fn decode(
&mut self,
src: &mut tokio_util::bytes::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
if let Some(end) = src
.iter()
.enumerate()
.find_map(|(idx, &b)| (b == b'\n').then_some(idx))
{
let line = src.split_to(end);
let _char_next_line = src.split_to(1);
Ok(Some(line.freeze()))
} else {
Ok(None)
}
}
}
2 changes: 2 additions & 0 deletions examples/servers/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod counter;
pub mod jsonrpc_frame_codec;
Loading