Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/mcp-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-appender = "0.2"
async-trait = "0.1"

3 changes: 2 additions & 1 deletion crates/mcp-server/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,8 @@ where

tracing::info!("Server started");
while let Some(msg_result) = transport.next().await {
let _span = tracing::span!(tracing::Level::INFO, "message_processing").entered();
let _span = tracing::span!(tracing::Level::INFO, "message_processing");
let _enter = _span.enter();
match msg_result {
Ok(msg) => {
match msg {
Expand Down
9 changes: 9 additions & 0 deletions examples/servers/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,15 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-appender = "0.2"
futures = "0.3"

[dev-dependencies]
axum = { version = "0.8", features = ["macros"] }
tokio-util = { version = "0.7", features = ["io", "codec"]}
rand = { version = "0.8" }

[[example]]
name = "counter-server"
path = "src/counter_server.rs"

[[example]]
name = "axum"
path = "src/axum.rs"
155 changes: 155 additions & 0 deletions examples/servers/src/axum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
use axum::{
body::Body,
extract::{Query, State},
http::StatusCode,
response::sse::{Event, Sse},
routing::get,
Router,
};
use futures::{stream::Stream, StreamExt, TryStreamExt};
use mcp_server::{ByteTransport, Server};
use std::collections::HashMap;
use tokio_util::codec::FramedRead;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

use anyhow::Result;
use mcp_server::router::RouterService;
use std::sync::Arc;
use tokio::{
io::{self, AsyncWriteExt},
sync::Mutex,
};
use tracing_subscriber::{self};
mod common;
use common::counter;

type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>;
type SessionId = Arc<str>;

const BIND_ADDRESS: &str = "127.0.0.1:8000";

#[derive(Clone, Default)]
pub struct App {
txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>,
}

impl App {
pub fn new() -> Self {
Self {
txs: Default::default(),
}
}
pub fn router(&self) -> Router {
Router::new()
.route("/sse", get(sse_handler).post(post_event_handler))
.with_state(self.clone())
}
}

fn session_id() -> SessionId {
let id = format!("{:016x}", rand::random::<u128>());
Arc::from(id)
}

#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PostEventQuery {
pub session_id: String,
}

async fn post_event_handler(
State(app): State<App>,
Query(PostEventQuery { session_id }): Query<PostEventQuery>,
body: Body,
) -> Result<StatusCode, StatusCode> {
const BODY_BYTES_LIMIT: usize = 1 << 22;
let write_stream = {
let rg = app.txs.read().await;
rg.get(session_id.as_str())
.ok_or(StatusCode::NOT_FOUND)?
.clone()
};
let mut write_stream = write_stream.lock().await;
let mut body = body.into_data_stream();
if let (_, Some(size)) = body.size_hint() {
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
}
// calculate the body size
let mut size = 0;
while let Some(chunk) = body.next().await {
let Ok(chunk) = chunk else {
return Err(StatusCode::BAD_REQUEST);
};
size += chunk.len();
if size > BODY_BYTES_LIMIT {
return Err(StatusCode::PAYLOAD_TOO_LARGE);
}
write_stream
.write_all(&chunk)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
}
write_stream
.write_u8(b'\n')
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(StatusCode::ACCEPTED)
}

async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> {
// it's 4KB
const BUFFER_SIZE: usize = 1 << 12;
let session = session_id();
tracing::info!(%session, "sse connection");
let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE);
let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE);
app.txs
.write()
.await
.insert(session.clone(), Arc::new(Mutex::new(c2s_write)));
{
let session = session.clone();
tokio::spawn(async move {
let router = RouterService(counter::CounterRouter::new());
let server = Server::new(router);
let bytes_transport = ByteTransport::new(c2s_read, s2c_write);
let _result = server
.run(bytes_transport)
.await
.inspect_err(|e| tracing::error!(?e, "server run error"));
app.txs.write().await.remove(&session);
});
}

let stream = futures::stream::once(futures::future::ok(
Event::default()
.event("endpoint")
.data(format!("?sessionId={session}")),
))
.chain(
FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
.and_then(move |bytes| match std::str::from_utf8(&bytes) {
Ok(message) => futures::future::ok(Event::default().event("message").data(message)),
Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)),
}),
);
Sse::new(stream)
}

#[tokio::main]
async fn main() -> io::Result<()> {
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?;

tracing::debug!("listening on {}", listener.local_addr()?);
axum::serve(listener, App::new().router()).await
}
184 changes: 184 additions & 0 deletions examples/servers/src/common/counter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
use std::{future::Future, pin::Pin, sync::Arc};

use mcp_core::{
handler::{PromptError, ResourceError},
prompt::{Prompt, PromptArgument},
protocol::ServerCapabilities,
Content, Resource, Tool, ToolError,
};
use mcp_server::router::CapabilitiesBuilder;
use serde_json::Value;
use tokio::sync::Mutex;

#[derive(Clone)]
pub struct CounterRouter {
counter: Arc<Mutex<i32>>,
}

impl CounterRouter {
pub fn new() -> Self {
Self {
counter: Arc::new(Mutex::new(0)),
}
}

async fn increment(&self) -> Result<i32, ToolError> {
let mut counter = self.counter.lock().await;
*counter += 1;
Ok(*counter)
}

async fn decrement(&self) -> Result<i32, ToolError> {
let mut counter = self.counter.lock().await;
*counter -= 1;
Ok(*counter)
}

async fn get_value(&self) -> Result<i32, ToolError> {
let counter = self.counter.lock().await;
Ok(*counter)
}

fn _create_resource_text(&self, uri: &str, name: &str) -> Resource {
Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap()
}
}

impl mcp_server::Router for CounterRouter {
fn name(&self) -> String {
"counter".to_string()
}

fn instructions(&self) -> String {
"This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string()
}

fn capabilities(&self) -> ServerCapabilities {
CapabilitiesBuilder::new()
.with_tools(false)
.with_resources(false, false)
.with_prompts(false)
.build()
}

fn list_tools(&self) -> Vec<Tool> {
vec![
Tool::new(
"increment".to_string(),
"Increment the counter by 1".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
Tool::new(
"decrement".to_string(),
"Decrement the counter by 1".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
Tool::new(
"get_value".to_string(),
"Get the current counter value".to_string(),
serde_json::json!({
"type": "object",
"properties": {},
"required": []
}),
),
]
}

fn call_tool(
&self,
tool_name: &str,
_arguments: Value,
) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> {
let this = self.clone();
let tool_name = tool_name.to_string();

Box::pin(async move {
match tool_name.as_str() {
"increment" => {
let value = this.increment().await?;
Ok(vec![Content::text(value.to_string())])
}
"decrement" => {
let value = this.decrement().await?;
Ok(vec![Content::text(value.to_string())])
}
"get_value" => {
let value = this.get_value().await?;
Ok(vec![Content::text(value.to_string())])
}
_ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))),
}
})
}

fn list_resources(&self) -> Vec<Resource> {
vec![
self._create_resource_text("str:////Users/to/some/path/", "cwd"),
self._create_resource_text("memo://insights", "memo-name"),
]
}

fn read_resource(
&self,
uri: &str,
) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> {
let uri = uri.to_string();
Box::pin(async move {
match uri.as_str() {
"str:////Users/to/some/path/" => {
let cwd = "/Users/to/some/path/";
Ok(cwd.to_string())
}
"memo://insights" => {
let memo =
"Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ...";
Ok(memo.to_string())
}
_ => Err(ResourceError::NotFound(format!(
"Resource {} not found",
uri
))),
}
})
}

fn list_prompts(&self) -> Vec<Prompt> {
vec![Prompt::new(
"example_prompt",
Some("This is an example prompt that takes one required agrument, message"),
Some(vec![PromptArgument {
name: "message".to_string(),
description: Some("A message to put in the prompt".to_string()),
required: Some(true),
}]),
)]
}

fn get_prompt(
&self,
prompt_name: &str,
) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> {
let prompt_name = prompt_name.to_string();
Box::pin(async move {
match prompt_name.as_str() {
"example_prompt" => {
let prompt = "This is an example prompt with your message here: '{message}'";
Ok(prompt.to_string())
}
_ => Err(PromptError::NotFound(format!(
"Prompt {} not found",
prompt_name
))),
}
})
}
}
24 changes: 24 additions & 0 deletions examples/servers/src/common/jsonrpc_frame_codec.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use tokio_util::codec::Decoder;

#[derive(Default)]
pub struct JsonRpcFrameCodec;
impl Decoder for JsonRpcFrameCodec {
type Item = tokio_util::bytes::Bytes;
type Error = tokio::io::Error;
fn decode(
&mut self,
src: &mut tokio_util::bytes::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
if let Some(end) = src
.iter()
.enumerate()
.find_map(|(idx, &b)| (b == b'\n').then_some(idx))
{
let line = src.split_to(end);
let _char_next_line = src.split_to(1);
Ok(Some(line.freeze()))
} else {
Ok(None)
}
}
}
Loading