-
Notifications
You must be signed in to change notification settings - Fork 442
Add an axum sse example #12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
kalvinnchau
merged 6 commits into
modelcontextprotocol:main
from
4t145:add-example-axum-sse-server
Mar 4, 2025
Merged
Changes from 5 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
e82eed4
add axum sse example
4t145 4663da3
fix: make the sse axum folder name correct
4t145 830f2a6
feat: add common modules for Axum SSE example
4t145 05d13b1
Merge remote-tracking branch 'origin/main' into add-example-axum-sse-…
4t145 957f183
move examples to root
4t145 0ad8b3b
nit: remove extra newline here
4t145 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| use axum::{ | ||
| body::Body, | ||
| extract::{Query, State}, | ||
| http::StatusCode, | ||
| response::sse::{Event, Sse}, | ||
| routing::get, | ||
| Router, | ||
| }; | ||
| use futures::{stream::Stream, StreamExt, TryStreamExt}; | ||
| use mcp_server::{ByteTransport, Server}; | ||
| use std::collections::HashMap; | ||
| use tokio_util::codec::FramedRead; | ||
| use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; | ||
|
|
||
| use anyhow::Result; | ||
| use mcp_server::router::RouterService; | ||
| use std::sync::Arc; | ||
| use tokio::{ | ||
| io::{self, AsyncWriteExt}, | ||
| sync::Mutex, | ||
| }; | ||
| use tracing_subscriber::{self}; | ||
| mod common; | ||
| use common::counter; | ||
|
|
||
| type C2SWriter = Arc<Mutex<io::WriteHalf<io::SimplexStream>>>; | ||
| type SessionId = Arc<str>; | ||
|
|
||
| const BIND_ADDRESS: &str = "127.0.0.1:8000"; | ||
|
|
||
| #[derive(Clone, Default)] | ||
| pub struct App { | ||
| txs: Arc<tokio::sync::RwLock<HashMap<SessionId, C2SWriter>>>, | ||
| } | ||
|
|
||
| impl App { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| txs: Default::default(), | ||
| } | ||
| } | ||
| pub fn router(&self) -> Router { | ||
| Router::new() | ||
| .route("/sse", get(sse_handler).post(post_event_handler)) | ||
| .with_state(self.clone()) | ||
| } | ||
| } | ||
|
|
||
| fn session_id() -> SessionId { | ||
| let id = format!("{:016x}", rand::random::<u128>()); | ||
| Arc::from(id) | ||
| } | ||
|
|
||
| #[derive(Debug, serde::Deserialize)] | ||
| #[serde(rename_all = "camelCase")] | ||
| pub struct PostEventQuery { | ||
| pub session_id: String, | ||
| } | ||
|
|
||
| async fn post_event_handler( | ||
| State(app): State<App>, | ||
| Query(PostEventQuery { session_id }): Query<PostEventQuery>, | ||
| body: Body, | ||
| ) -> Result<StatusCode, StatusCode> { | ||
| const BODY_BYTES_LIMIT: usize = 1 << 22; | ||
| let write_stream = { | ||
| let rg = app.txs.read().await; | ||
| rg.get(session_id.as_str()) | ||
| .ok_or(StatusCode::NOT_FOUND)? | ||
| .clone() | ||
| }; | ||
| let mut write_stream = write_stream.lock().await; | ||
| let mut body = body.into_data_stream(); | ||
| if let (_, Some(size)) = body.size_hint() { | ||
| if size > BODY_BYTES_LIMIT { | ||
| return Err(StatusCode::PAYLOAD_TOO_LARGE); | ||
| } | ||
| } | ||
| // calculate the body size | ||
| let mut size = 0; | ||
| while let Some(chunk) = body.next().await { | ||
| let Ok(chunk) = chunk else { | ||
| return Err(StatusCode::BAD_REQUEST); | ||
| }; | ||
| size += chunk.len(); | ||
| if size > BODY_BYTES_LIMIT { | ||
| return Err(StatusCode::PAYLOAD_TOO_LARGE); | ||
| } | ||
| write_stream | ||
| .write_all(&chunk) | ||
| .await | ||
| .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| } | ||
| write_stream | ||
| .write_u8(b'\n') | ||
| .await | ||
| .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
| Ok(StatusCode::ACCEPTED) | ||
| } | ||
|
|
||
| async fn sse_handler(State(app): State<App>) -> Sse<impl Stream<Item = Result<Event, io::Error>>> { | ||
| // it's 4KB | ||
| const BUFFER_SIZE: usize = 1 << 12; | ||
| let session = session_id(); | ||
| tracing::info!(%session, "sse connection"); | ||
| let (c2s_read, c2s_write) = tokio::io::simplex(BUFFER_SIZE); | ||
| let (s2c_read, s2c_write) = tokio::io::simplex(BUFFER_SIZE); | ||
| app.txs | ||
| .write() | ||
| .await | ||
| .insert(session.clone(), Arc::new(Mutex::new(c2s_write))); | ||
| { | ||
| let session = session.clone(); | ||
| tokio::spawn(async move { | ||
| let router = RouterService(counter::CounterRouter::new()); | ||
| let server = Server::new(router); | ||
| let bytes_transport = ByteTransport::new(c2s_read, s2c_write); | ||
| let _result = server | ||
| .run(bytes_transport) | ||
| .await | ||
| .inspect_err(|e| tracing::error!(?e, "server run error")); | ||
| app.txs.write().await.remove(&session); | ||
| }); | ||
| } | ||
|
|
||
| let stream = futures::stream::once(futures::future::ok( | ||
| Event::default() | ||
| .event("endpoint") | ||
| .data(format!("?sessionId={session}")), | ||
| )) | ||
| .chain( | ||
| FramedRead::new(s2c_read, common::jsonrpc_frame_codec::JsonRpcFrameCodec) | ||
| .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e)) | ||
| .and_then(move |bytes| match std::str::from_utf8(&bytes) { | ||
| Ok(message) => futures::future::ok(Event::default().event("message").data(message)), | ||
| Err(e) => futures::future::err(io::Error::new(io::ErrorKind::InvalidData, e)), | ||
| }), | ||
| ); | ||
| Sse::new(stream) | ||
| } | ||
|
|
||
| #[tokio::main] | ||
| async fn main() -> io::Result<()> { | ||
| tracing_subscriber::registry() | ||
| .with( | ||
| tracing_subscriber::EnvFilter::try_from_default_env() | ||
| .unwrap_or_else(|_| format!("info,{}=debug", env!("CARGO_CRATE_NAME")).into()), | ||
| ) | ||
| .with(tracing_subscriber::fmt::layer()) | ||
| .init(); | ||
| let listener = tokio::net::TcpListener::bind(BIND_ADDRESS).await?; | ||
|
|
||
| tracing::debug!("listening on {}", listener.local_addr()?); | ||
| axum::serve(listener, App::new().router()).await | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,184 @@ | ||
| use std::{future::Future, pin::Pin, sync::Arc}; | ||
|
|
||
| use mcp_core::{ | ||
| handler::{PromptError, ResourceError}, | ||
| prompt::{Prompt, PromptArgument}, | ||
| protocol::ServerCapabilities, | ||
| Content, Resource, Tool, ToolError, | ||
| }; | ||
| use mcp_server::router::CapabilitiesBuilder; | ||
| use serde_json::Value; | ||
| use tokio::sync::Mutex; | ||
|
|
||
| #[derive(Clone)] | ||
| pub struct CounterRouter { | ||
| counter: Arc<Mutex<i32>>, | ||
| } | ||
|
|
||
| impl CounterRouter { | ||
| pub fn new() -> Self { | ||
| Self { | ||
| counter: Arc::new(Mutex::new(0)), | ||
| } | ||
| } | ||
|
|
||
| async fn increment(&self) -> Result<i32, ToolError> { | ||
| let mut counter = self.counter.lock().await; | ||
| *counter += 1; | ||
| Ok(*counter) | ||
| } | ||
|
|
||
| async fn decrement(&self) -> Result<i32, ToolError> { | ||
| let mut counter = self.counter.lock().await; | ||
| *counter -= 1; | ||
| Ok(*counter) | ||
| } | ||
|
|
||
| async fn get_value(&self) -> Result<i32, ToolError> { | ||
| let counter = self.counter.lock().await; | ||
| Ok(*counter) | ||
| } | ||
|
|
||
| fn _create_resource_text(&self, uri: &str, name: &str) -> Resource { | ||
| Resource::new(uri, Some("text/plain".to_string()), Some(name.to_string())).unwrap() | ||
| } | ||
| } | ||
|
|
||
| impl mcp_server::Router for CounterRouter { | ||
| fn name(&self) -> String { | ||
| "counter".to_string() | ||
| } | ||
|
|
||
| fn instructions(&self) -> String { | ||
| "This server provides a counter tool that can increment and decrement values. The counter starts at 0 and can be modified using the 'increment' and 'decrement' tools. Use 'get_value' to check the current count.".to_string() | ||
| } | ||
|
|
||
| fn capabilities(&self) -> ServerCapabilities { | ||
| CapabilitiesBuilder::new() | ||
| .with_tools(false) | ||
| .with_resources(false, false) | ||
| .with_prompts(false) | ||
| .build() | ||
| } | ||
|
|
||
| fn list_tools(&self) -> Vec<Tool> { | ||
| vec![ | ||
| Tool::new( | ||
| "increment".to_string(), | ||
| "Increment the counter by 1".to_string(), | ||
| serde_json::json!({ | ||
| "type": "object", | ||
| "properties": {}, | ||
| "required": [] | ||
| }), | ||
| ), | ||
| Tool::new( | ||
| "decrement".to_string(), | ||
| "Decrement the counter by 1".to_string(), | ||
| serde_json::json!({ | ||
| "type": "object", | ||
| "properties": {}, | ||
| "required": [] | ||
| }), | ||
| ), | ||
| Tool::new( | ||
| "get_value".to_string(), | ||
| "Get the current counter value".to_string(), | ||
| serde_json::json!({ | ||
| "type": "object", | ||
| "properties": {}, | ||
| "required": [] | ||
| }), | ||
| ), | ||
| ] | ||
| } | ||
|
|
||
| fn call_tool( | ||
| &self, | ||
| tool_name: &str, | ||
| _arguments: Value, | ||
| ) -> Pin<Box<dyn Future<Output = Result<Vec<Content>, ToolError>> + Send + 'static>> { | ||
| let this = self.clone(); | ||
| let tool_name = tool_name.to_string(); | ||
|
|
||
| Box::pin(async move { | ||
| match tool_name.as_str() { | ||
| "increment" => { | ||
| let value = this.increment().await?; | ||
| Ok(vec![Content::text(value.to_string())]) | ||
| } | ||
| "decrement" => { | ||
| let value = this.decrement().await?; | ||
| Ok(vec![Content::text(value.to_string())]) | ||
| } | ||
| "get_value" => { | ||
| let value = this.get_value().await?; | ||
| Ok(vec![Content::text(value.to_string())]) | ||
| } | ||
| _ => Err(ToolError::NotFound(format!("Tool {} not found", tool_name))), | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| fn list_resources(&self) -> Vec<Resource> { | ||
| vec![ | ||
| self._create_resource_text("str:////Users/to/some/path/", "cwd"), | ||
| self._create_resource_text("memo://insights", "memo-name"), | ||
| ] | ||
| } | ||
|
|
||
| fn read_resource( | ||
| &self, | ||
| uri: &str, | ||
| ) -> Pin<Box<dyn Future<Output = Result<String, ResourceError>> + Send + 'static>> { | ||
| let uri = uri.to_string(); | ||
| Box::pin(async move { | ||
| match uri.as_str() { | ||
| "str:////Users/to/some/path/" => { | ||
| let cwd = "/Users/to/some/path/"; | ||
| Ok(cwd.to_string()) | ||
| } | ||
| "memo://insights" => { | ||
| let memo = | ||
| "Business Intelligence Memo\n\nAnalysis has revealed 5 key insights ..."; | ||
| Ok(memo.to_string()) | ||
| } | ||
| _ => Err(ResourceError::NotFound(format!( | ||
| "Resource {} not found", | ||
| uri | ||
| ))), | ||
| } | ||
| }) | ||
| } | ||
|
|
||
| fn list_prompts(&self) -> Vec<Prompt> { | ||
| vec![Prompt::new( | ||
| "example_prompt", | ||
| Some("This is an example prompt that takes one required agrument, message"), | ||
| Some(vec![PromptArgument { | ||
| name: "message".to_string(), | ||
| description: Some("A message to put in the prompt".to_string()), | ||
| required: Some(true), | ||
| }]), | ||
| )] | ||
| } | ||
|
|
||
| fn get_prompt( | ||
| &self, | ||
| prompt_name: &str, | ||
| ) -> Pin<Box<dyn Future<Output = Result<String, PromptError>> + Send + 'static>> { | ||
| let prompt_name = prompt_name.to_string(); | ||
| Box::pin(async move { | ||
| match prompt_name.as_str() { | ||
| "example_prompt" => { | ||
| let prompt = "This is an example prompt with your message here: '{message}'"; | ||
| Ok(prompt.to_string()) | ||
| } | ||
| _ => Err(PromptError::NotFound(format!( | ||
| "Prompt {} not found", | ||
| prompt_name | ||
| ))), | ||
| } | ||
| }) | ||
| } | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| use tokio_util::codec::Decoder; | ||
|
|
||
| #[derive(Default)] | ||
| pub struct JsonRpcFrameCodec; | ||
| impl Decoder for JsonRpcFrameCodec { | ||
| type Item = tokio_util::bytes::Bytes; | ||
| type Error = tokio::io::Error; | ||
| fn decode( | ||
| &mut self, | ||
| src: &mut tokio_util::bytes::BytesMut, | ||
| ) -> Result<Option<Self::Item>, Self::Error> { | ||
| if let Some(end) = src | ||
| .iter() | ||
| .enumerate() | ||
| .find_map(|(idx, &b)| (b == b'\n').then_some(idx)) | ||
| { | ||
| let line = src.split_to(end); | ||
| let _char_next_line = src.split_to(1); | ||
| Ok(Some(line.freeze())) | ||
| } else { | ||
| Ok(None) | ||
| } | ||
| } | ||
| } |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.