diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..bb5151cb --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,54 @@ +name: Release + +on: + push: + branches: [ release ] + tags: + - 'release-*' + pull_request: + branches: [ release ] +env: + CARGO_TERM_COLOR: always + ARTIFACT_DIR: release-artifacts + +jobs: + release: + name: Release crates + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + + - name: Install Rust + uses: dtolnay/rust-toolchain@stable + + - name: Cache dependencies + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: ${{ runner.os }}-cargo- + + - name: Check formatting + run: cargo fmt --all -- --check + - name: Run clippy + run: cargo clippy --all-targets --all-features -- -D warnings + - name: Cargo login + run: cargo login ${{ secrets.CRATES_TOKEN }} + - name: Publish macros dry run + run: cargo publish -p rmcp-macros --dry-run + continue-on-error: true + - name: Publish rmcp dry run + run: cargo publish -p rmcp --dry-run + continue-on-error: true + - name: Publish macro + if: ${{ startsWith(github.ref, 'refs/tags/release') }} + continue-on-error: true + run: cargo publish -p rmcp-macros + - name: Publish rmcp + if: ${{ startsWith(github.ref, 'refs/tags/release') }} + continue-on-error: true + run: cargo publish -p rmcp + diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8408c6c6..3045fc1a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,10 +15,19 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - + # install nodejs + - name: Setup Node.js + uses: actions/setup-node@v2 + with: + node-version: '20' + - name: Install uv + uses: astral-sh/setup-uv@v5 - name: Install Rust uses: dtolnay/rust-toolchain@stable - + - name: Set up Python + run: uv python install + - name: Create venv for python + run: uv venv - name: Cache dependencies uses: actions/cache@v3 with: diff --git a/Cargo.toml b/Cargo.toml index 25a9720c..9209257c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,20 +1,22 @@ [workspace] -members = [ - "crates/*", - "examples/clients", - "examples/servers", - "examples/macros" -] +members = ["crates/rmcp", "crates/rmcp-macros", "examples/*"] resolver = "2" [workspace.dependencies] -mcp-core = { path = "./crates/mcp-core" } -mcp-macros = { path = "./crates/mcp-macros" } +rmcp = { version = "0.1.5", path = "./crates/rmcp" } +rmcp-macros = { version = "0.1.5", path = "./crates/rmcp-macros" } [workspace.package] -edition = "2021" -version = "1.0.7" -authors = ["Block "] -license = "MIT" -repository = "https://github.com/modelcontextprotocol/rust-sdk/" +edition = "2024" +version = "0.1.5" +authors = ["4t145 "] +license = "MIT/Apache-2.0" +repository = "https://github.commodelcontextprotocol/rust-sdk/" description = "Rust SDK for the Model Context Protocol" +keywords = ["mcp", "sdk", "tokio", "modelcontextprotocol"] +homepage = "https://github.com/modelcontextprotocol/rust-sdk" +categories = [ + "network-programming", + "asynchronous", +] +readme = "README.md" diff --git a/README.md b/README.md index 52adb920..9e7283dd 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,169 @@ -# rust-sdk -The official Rust SDK for the Model Context Protocol +
+简体中文 +
+ +# RMCP +[![Crates.io Version](https://img.shields.io/crates/v/rmcp)](https://crates.io/crates/rmcp) +![Release status](https://github.commodelcontextprotocol/rust-sdk/actions/workflows/release.yml/badge.svg) +[![docs.rs](https://img.shields.io/docsrs/rmcp)](https://docs.rs/rmcp/latest/rmcp) + +An official rust Model Context Protocol SDK implementation with tokio async runtime. + +## Usage + +### Import +```toml +rmcp = { version = "0.1", features = ["server"] } +## or dev channel +rmcp = { git = "https://github.com/modelcontextprotocol/rust-sdk", branch = "dev" } +``` + +### Quick start +Start a client in one line: +```rust +use rmcp::{ServiceExt, transport::TokioChildProcess}; +use tokio::process::Command; + +let client = ().serve( + TokioChildProcess::new(Command::new("npx").arg("-y").arg("@modelcontextprotocol/server-everything"))? +).await?; +``` + +#### 1. Build a transport + +```rust, ignore +use tokio::io::{stdin, stdout}; +let transport = (stdin(), stdout()); +``` + +The transport type must implemented [`IntoTransport`](crate::transport::IntoTransport) trait, which allow split into a sink and a stream. + +For client, the sink item is [`ClientJsonRpcMessage`](crate::model::ClientJsonRpcMessage) and stream item is [`ServerJsonRpcMessage`](crate::model::ServerJsonRpcMessage) + +For server, the sink item is [`ServerJsonRpcMessage`](crate::model::ServerJsonRpcMessage) and stream item is [`ClientJsonRpcMessage`](crate::model::ClientJsonRpcMessage) + +##### These types is automatically implemented [`IntoTransport`](crate::transport::IntoTransport) trait +1. The types that already implement both [`Sink`](futures::Sink) and [`Stream`](futures::Stream) trait. +2. A tuple of sink `Tx` and stream `Rx`: `(Tx, Rx)`. +3. The type that implement both [`tokio::io::AsyncRead`] and [`tokio::io::AsyncWrite`] trait. +4. A tuple of [`tokio::io::AsyncRead`] `R `and [`tokio::io::AsyncWrite`] `W`: `(R, W)`. + +For example, you can see how we build a transport through TCP stream or http upgrade so easily. [examples](examples/README.md) + +#### 2. Build a service +You can easily build a service by using [`ServerHandler`](crates/rmcp/src/handler/server.rs) or [`ClientHandler`](crates/rmcp/src/handler/client.rs). + +```rust, ignore +let service = common::counter::Counter::new(); +``` + +#### 3. Serve them together +```rust, ignore +// this call will finish the initialization process +let server = service.serve(transport).await?; +``` + +#### 4. Interact with the server +Once the server is initialized, you can send requests or notifications: + +```rust, ignore +// request +let roots = server.list_roots().await?; + +// or send notification +server.notify_cancelled(...).await?; +``` + +#### 5. Waiting for service shutdown +```rust, ignore +let quit_reason = server.waiting().await?; +// or cancel it +let quit_reason = server.cancel().await?; +``` + +### Use marcos to declaring tool +Use `toolbox` and `tool` macros to create tool quickly. + +Check this [file](examples/servers/src/common/calculator.rs). +```rust, ignore +use rmcp::{ServerHandler, model::ServerInfo, schemars, tool}; + +use super::counter::Counter; + +#[derive(Debug, serde::Deserialize, schemars::JsonSchema)] +pub struct SumRequest { + #[schemars(description = "the left hand side number")] + pub a: i32, + #[schemars(description = "the right hand side number")] + pub b: i32, +} +#[derive(Debug, Clone)] +pub struct Calculator; + +// create a static toolbox to store the tool attributes +#[tool(tool_box)] +impl Calculator { + // async function + #[tool(description = "Calculate the sum of two numbers")] + async fn sum(&self, #[tool(aggr)] SumRequest { a, b }: SumRequest) -> String { + (a + b).to_string() + } + + // sync function + #[tool(description = "Calculate the sum of two numbers")] + fn sub( + &self, + #[tool(param)] + // this macro will transfer the schemars and serde's attributes + #[schemars(description = "the left hand side number")] + a: i32, + #[tool(param)] + #[schemars(description = "the right hand side number")] + b: i32, + ) -> String { + (a - b).to_string() + } +} + +// impl call_tool and list_tool by querying static toolbox +#[tool(tool_box)] +impl ServerHandler for Calculator { + fn get_info(&self) -> ServerInfo { + ServerInfo { + instructions: Some("A simple calculator".into()), + ..Default::default() + } + } +} + +``` +The only thing you should do is to make the function's return type implement `IntoCallToolResult`. + +And you can just implement `IntoContents`, and the return value will be marked as success automatically. + +If you return a type of `Result` where `T` and `E` both implemented `IntoContents`, it's also OK. + +### Manage Multi Services +For many cases you need to manage several service in a collection, you can call `into_dyn` to convert services into the same type. +```rust, ignore +let service = service.into_dyn(); +``` + + +### Examples +See [examples](examples/README.md) + +### Features +- `client`: use client side sdk +- `server`: use server side sdk +- `macros`: macros default +#### Transports +- `transport-io`: Server stdio transport +- `transport-sse-server`: Server SSE transport +- `transport-child-process`: Client stdio transport +- `transport-sse`: Client sse transport + +## Related Resources +- [MCP Specification](https://spec.modelcontextprotocol.io/specification/2024-11-05/) + +- [Schema](https://github.com/modelcontextprotocol/specification/blob/main/schema/2024-11-05/schema.ts) diff --git a/crates/mcp-client/Cargo.toml b/crates/mcp-client/Cargo.toml deleted file mode 100644 index e54836da..00000000 --- a/crates/mcp-client/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "mcp-client" -license.workspace = true -version.workspace = true -edition.workspace = true -repository.workspace = true -description = "Client SDK for the Model Context Protocol" - -[dependencies] -mcp-core = { workspace = true } -tokio = { version = "1", features = ["full"] } -reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "rustls-tls"] } -eventsource-client = "0.12.0" -futures = "0.3" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -async-trait = "0.1.83" -url = "2.5.4" -thiserror = "1.0" -anyhow = "1.0" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tower = { version = "0.4", features = ["timeout", "util"] } -tower-service = "0.3" -rand = "0.8" - -[dev-dependencies] diff --git a/crates/mcp-client/README.md b/crates/mcp-client/README.md deleted file mode 100644 index a43c4c21..00000000 --- a/crates/mcp-client/README.md +++ /dev/null @@ -1,11 +0,0 @@ -## Testing stdio transport - -```bash -cargo run -p mcp-client --example stdio -``` - -## Testing SSE transport - -1. Start the MCP server in one terminal: `fastmcp run -t sse echo.py` -2. Run the client example in new terminal: `cargo run -p mcp-client --example sse` - diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs deleted file mode 100644 index 0d722e55..00000000 --- a/crates/mcp-client/src/client.rs +++ /dev/null @@ -1,391 +0,0 @@ -use mcp_core::protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, - JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, - ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, -}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use std::sync::atomic::{AtomicU64, Ordering}; -use thiserror::Error; -use tokio::sync::Mutex; -use tower::{Service, ServiceExt}; // for Service::ready() - -pub type BoxError = Box; - -/// Error type for MCP client operations. -#[derive(Debug, Error)] -pub enum Error { - #[error("Transport error: {0}")] - Transport(#[from] super::transport::Error), - - #[error("RPC error: code={code}, message={message}")] - RpcError { code: i32, message: String }, - - #[error("Serialization error: {0}")] - Serialization(#[from] serde_json::Error), - - #[error("Unexpected response from server: {0}")] - UnexpectedResponse(String), - - #[error("Not initialized")] - NotInitialized, - - #[error("Timeout or service not ready")] - NotReady, - - #[error("Request timed out")] - Timeout(#[from] tower::timeout::error::Elapsed), - - #[error("Error from mcp-server: {0}")] - ServerBoxError(BoxError), - - #[error("Call to '{server}' failed for '{method}'. {source}")] - McpServerError { - method: String, - server: String, - #[source] - source: BoxError, - }, -} - -// BoxError from mcp-server gets converted to our Error type -impl From for Error { - fn from(err: BoxError) -> Self { - Error::ServerBoxError(err) - } -} - -#[derive(Serialize, Deserialize)] -pub struct ClientInfo { - pub name: String, - pub version: String, -} - -#[derive(Serialize, Deserialize, Default)] -pub struct ClientCapabilities { - // Add fields as needed. For now, empty capabilities are fine. -} - -#[derive(Serialize, Deserialize)] -pub struct InitializeParams { - #[serde(rename = "protocolVersion")] - pub protocol_version: String, - pub capabilities: ClientCapabilities, - #[serde(rename = "clientInfo")] - pub client_info: ClientInfo, -} - -#[async_trait::async_trait] -pub trait McpClientTrait: Send + Sync { - async fn initialize( - &mut self, - info: ClientInfo, - capabilities: ClientCapabilities, - ) -> Result; - - async fn list_resources( - &self, - next_cursor: Option, - ) -> Result; - - async fn read_resource(&self, uri: &str) -> Result; - - async fn list_tools(&self, next_cursor: Option) -> Result; - - async fn call_tool(&self, name: &str, arguments: Value) -> Result; - - async fn list_prompts(&self, next_cursor: Option) -> Result; - - async fn get_prompt(&self, name: &str, arguments: Value) -> Result; -} - -/// The MCP client is the interface for MCP operations. -pub struct McpClient -where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, -{ - service: Mutex, - next_id: AtomicU64, - server_capabilities: Option, - server_info: Option, -} - -impl McpClient -where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, -{ - pub fn new(service: S) -> Self { - Self { - service: Mutex::new(service), - next_id: AtomicU64::new(1), - server_capabilities: None, - server_info: None, - } - } - - /// Send a JSON-RPC request and check we don't get an error response. - async fn send_request(&self, method: &str, params: Value) -> Result - where - R: for<'de> Deserialize<'de>, - { - let mut service = self.service.lock().await; - service.ready().await.map_err(|_| Error::NotReady)?; - - let id = self.next_id.fetch_add(1, Ordering::SeqCst); - let request = JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: "2.0".to_string(), - id: Some(id), - method: method.to_string(), - params: Some(params.clone()), - }); - - let response_msg = service - .call(request) - .await - .map_err(|e| Error::McpServerError { - server: self - .server_info - .as_ref() - .map(|s| s.name.clone()) - .unwrap_or("".to_string()), - method: method.to_string(), - // we don't need include params because it can be really large - source: Box::new(e.into()), - })?; - - match response_msg { - JsonRpcMessage::Response(JsonRpcResponse { - id, result, error, .. - }) => { - // Verify id matches - if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { - return Err(Error::UnexpectedResponse( - "id mismatch for JsonRpcResponse".to_string(), - )); - } - if let Some(err) = error { - Err(Error::RpcError { - code: err.code, - message: err.message, - }) - } else if let Some(r) = result { - Ok(serde_json::from_value(r)?) - } else { - Err(Error::UnexpectedResponse("missing result".to_string())) - } - } - JsonRpcMessage::Error(JsonRpcError { id, error, .. }) => { - if id != Some(self.next_id.load(Ordering::SeqCst) - 1) { - return Err(Error::UnexpectedResponse( - "id mismatch for JsonRpcError".to_string(), - )); - } - Err(Error::RpcError { - code: error.code, - message: error.message, - }) - } - _ => { - // Requests/notifications not expected as a response - Err(Error::UnexpectedResponse( - "unexpected message type".to_string(), - )) - } - } - } - - /// Send a JSON-RPC notification. - async fn send_notification(&self, method: &str, params: Value) -> Result<(), Error> { - let mut service = self.service.lock().await; - service.ready().await.map_err(|_| Error::NotReady)?; - - let notification = JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: "2.0".to_string(), - method: method.to_string(), - params: Some(params.clone()), - }); - - service - .call(notification) - .await - .map_err(|e| Error::McpServerError { - server: self - .server_info - .as_ref() - .map(|s| s.name.clone()) - .unwrap_or("".to_string()), - method: method.to_string(), - // we don't need include params because it can be really large - source: Box::new(e.into()), - })?; - - Ok(()) - } - - // Check if the client has completed initialization - fn completed_initialization(&self) -> bool { - self.server_capabilities.is_some() - } -} - -#[async_trait::async_trait] -impl McpClientTrait for McpClient -where - S: Service + Clone + Send + Sync + 'static, - S::Error: Into, - S::Future: Send, -{ - async fn initialize( - &mut self, - info: ClientInfo, - capabilities: ClientCapabilities, - ) -> Result { - let params = InitializeParams { - protocol_version: "1.0.0".into(), - client_info: info, - capabilities, - }; - let result: InitializeResult = self - .send_request("initialize", serde_json::to_value(params)?) - .await?; - - self.send_notification("notifications/initialized", serde_json::json!({})) - .await?; - - self.server_capabilities = Some(result.capabilities.clone()); - - self.server_info = Some(result.server_info.clone()); - - Ok(result) - } - - async fn list_resources( - &self, - next_cursor: Option, - ) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If resources is not supported, return an empty list - if self - .server_capabilities - .as_ref() - .unwrap() - .resources - .is_none() - { - return Ok(ListResourcesResult { - resources: vec![], - next_cursor: None, - }); - } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("resources/list", payload).await - } - - async fn read_resource(&self, uri: &str) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If resources is not supported, return an error - if self - .server_capabilities - .as_ref() - .unwrap() - .resources - .is_none() - { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'resources' capability".to_string(), - }); - } - - let params = serde_json::json!({ "uri": uri }); - self.send_request("resources/read", params).await - } - - async fn list_tools(&self, next_cursor: Option) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If tools is not supported, return an empty list - if self.server_capabilities.as_ref().unwrap().tools.is_none() { - return Ok(ListToolsResult { - tools: vec![], - next_cursor: None, - }); - } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("tools/list", payload).await - } - - async fn call_tool(&self, name: &str, arguments: Value) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - // If tools is not supported, return an error - if self.server_capabilities.as_ref().unwrap().tools.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'tools' capability".to_string(), - }); - } - - let params = serde_json::json!({ "name": name, "arguments": arguments }); - - // TODO ERROR: check that if there is an error, we send back is_error: true with msg - // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 - self.send_request("tools/call", params).await - } - - async fn list_prompts(&self, next_cursor: Option) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - - // If prompts is not supported, return an error - if self.server_capabilities.as_ref().unwrap().prompts.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'prompts' capability".to_string(), - }); - } - - let payload = next_cursor - .map(|cursor| serde_json::json!({"cursor": cursor})) - .unwrap_or_else(|| serde_json::json!({})); - - self.send_request("prompts/list", payload).await - } - - async fn get_prompt(&self, name: &str, arguments: Value) -> Result { - if !self.completed_initialization() { - return Err(Error::NotInitialized); - } - - // If prompts is not supported, return an error - if self.server_capabilities.as_ref().unwrap().prompts.is_none() { - return Err(Error::RpcError { - code: METHOD_NOT_FOUND, - message: "Server does not support 'prompts' capability".to_string(), - }); - } - - let params = serde_json::json!({ "name": name, "arguments": arguments }); - - self.send_request("prompts/get", params).await - } -} diff --git a/crates/mcp-client/src/lib.rs b/crates/mcp-client/src/lib.rs deleted file mode 100644 index 985d89d1..00000000 --- a/crates/mcp-client/src/lib.rs +++ /dev/null @@ -1,7 +0,0 @@ -pub mod client; -pub mod service; -pub mod transport; - -pub use client::{ClientCapabilities, ClientInfo, Error, McpClient, McpClientTrait}; -pub use service::McpService; -pub use transport::{SseTransport, StdioTransport, Transport, TransportHandle}; diff --git a/crates/mcp-client/src/service.rs b/crates/mcp-client/src/service.rs deleted file mode 100644 index 00aa95be..00000000 --- a/crates/mcp-client/src/service.rs +++ /dev/null @@ -1,52 +0,0 @@ -use futures::future::BoxFuture; -use mcp_core::protocol::JsonRpcMessage; -use std::sync::Arc; -use std::task::{Context, Poll}; -use tower::{timeout::Timeout, Service, ServiceBuilder}; - -use crate::transport::{Error, TransportHandle}; - -/// A wrapper service that implements Tower's Service trait for MCP transport -#[derive(Clone)] -pub struct McpService { - inner: Arc, -} - -impl McpService { - pub fn new(transport: T) -> Self { - Self { - inner: Arc::new(transport), - } - } -} - -impl Service for McpService -where - T: TransportHandle + Send + Sync + 'static, -{ - type Response = JsonRpcMessage; - type Error = Error; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - // Most transports are always ready, but this could be customized if needed - Poll::Ready(Ok(())) - } - - fn call(&mut self, request: JsonRpcMessage) -> Self::Future { - let transport = self.inner.clone(); - Box::pin(async move { transport.send(request).await }) - } -} - -// Add a convenience constructor for creating a service with timeout -impl McpService -where - T: TransportHandle, -{ - pub fn with_timeout(transport: T, timeout: std::time::Duration) -> Timeout> { - ServiceBuilder::new() - .timeout(timeout) - .service(McpService::new(transport)) - } -} diff --git a/crates/mcp-client/src/transport/mod.rs b/crates/mcp-client/src/transport/mod.rs deleted file mode 100644 index 25bcef74..00000000 --- a/crates/mcp-client/src/transport/mod.rs +++ /dev/null @@ -1,127 +0,0 @@ -use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; -use std::collections::HashMap; -use thiserror::Error; -use tokio::sync::{mpsc, oneshot, RwLock}; - -pub type BoxError = Box; -/// A generic error type for transport operations. -#[derive(Debug, Error)] -pub enum Error { - #[error("I/O error: {0}")] - Io(#[from] std::io::Error), - - #[error("Transport was not connected or is already closed")] - NotConnected, - - #[error("Channel closed")] - ChannelClosed, - - #[error("Serialization error: {0}")] - Serialization(#[from] serde_json::Error), - - #[error("Unsupported message type. JsonRpcMessage can only be Request or Notification.")] - UnsupportedMessage, - - #[error("Stdio process error: {0}")] - StdioProcessError(String), - - #[error("SSE connection error: {0}")] - SseConnection(String), - - #[error("HTTP error: {status} - {message}")] - HttpError { status: u16, message: String }, -} - -/// A message that can be sent through the transport -#[derive(Debug)] -pub struct TransportMessage { - /// The JSON-RPC message to send - pub message: JsonRpcMessage, - /// Channel to receive the response on (None for notifications) - pub response_tx: Option>>, -} - -/// A generic asynchronous transport trait with channel-based communication -#[async_trait] -pub trait Transport { - type Handle: TransportHandle; - - /// Start the transport and establish the underlying connection. - /// Returns the transport handle for sending messages. - async fn start(&self) -> Result; - - /// Close the transport and free any resources. - async fn close(&self) -> Result<(), Error>; -} - -#[async_trait] -pub trait TransportHandle: Send + Sync + Clone + 'static { - async fn send(&self, message: JsonRpcMessage) -> Result; -} - -// Helper function that contains the common send implementation -pub async fn send_message( - sender: &mpsc::Sender, - message: JsonRpcMessage, -) -> Result { - match message { - JsonRpcMessage::Request(request) => { - let (respond_to, response) = oneshot::channel(); - let msg = TransportMessage { - message: JsonRpcMessage::Request(request), - response_tx: Some(respond_to), - }; - sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; - Ok(response.await.map_err(|_| Error::ChannelClosed)??) - } - JsonRpcMessage::Notification(notification) => { - let msg = TransportMessage { - message: JsonRpcMessage::Notification(notification), - response_tx: None, - }; - sender.send(msg).await.map_err(|_| Error::ChannelClosed)?; - Ok(JsonRpcMessage::Nil) - } - _ => Err(Error::UnsupportedMessage), - } -} - -// A data structure to store pending requests and their response channels -pub struct PendingRequests { - requests: RwLock>>>, -} - -impl Default for PendingRequests { - fn default() -> Self { - Self::new() - } -} - -impl PendingRequests { - pub fn new() -> Self { - Self { - requests: RwLock::new(HashMap::new()), - } - } - - pub async fn insert(&self, id: String, sender: oneshot::Sender>) { - self.requests.write().await.insert(id, sender); - } - - pub async fn respond(&self, id: &str, response: Result) { - if let Some(tx) = self.requests.write().await.remove(id) { - let _ = tx.send(response); - } - } - - pub async fn clear(&self) { - self.requests.write().await.clear(); - } -} - -pub mod stdio; -pub use stdio::StdioTransport; - -pub mod sse; -pub use sse::SseTransport; diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs deleted file mode 100644 index 90dc5f2f..00000000 --- a/crates/mcp-client/src/transport/sse.rs +++ /dev/null @@ -1,309 +0,0 @@ -use crate::transport::{Error, PendingRequests, TransportMessage}; -use async_trait::async_trait; -use eventsource_client::{Client, SSE}; -use futures::TryStreamExt; -use mcp_core::protocol::{JsonRpcMessage, JsonRpcRequest}; -use reqwest::Client as HttpClient; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::{mpsc, RwLock}; -use tokio::time::{timeout, Duration}; -use tracing::warn; -use url::Url; - -use super::{send_message, Transport, TransportHandle}; - -// Timeout for the endpoint discovery -const ENDPOINT_TIMEOUT_SECS: u64 = 5; - -/// The SSE-based actor that continuously: -/// - Reads incoming events from the SSE stream. -/// - Sends outgoing messages via HTTP POST (once the post endpoint is known). -pub struct SseActor { - /// Receives messages (requests/notifications) from the handle - receiver: mpsc::Receiver, - /// Map of request-id -> oneshot sender - pending_requests: Arc, - /// Base SSE URL - sse_url: String, - /// For sending HTTP POST requests - http_client: HttpClient, - /// The discovered endpoint for POST requests (once "endpoint" SSE event arrives) - post_endpoint: Arc>>, -} - -impl SseActor { - pub fn new( - receiver: mpsc::Receiver, - pending_requests: Arc, - sse_url: String, - post_endpoint: Arc>>, - ) -> Self { - Self { - receiver, - pending_requests, - sse_url, - post_endpoint, - http_client: HttpClient::new(), - } - } - - /// The main entry point for the actor. Spawns two concurrent loops: - /// 1) handle_incoming_messages (SSE events) - /// 2) handle_outgoing_messages (sending messages via POST) - pub async fn run(self) { - tokio::join!( - Self::handle_incoming_messages( - self.sse_url.clone(), - Arc::clone(&self.pending_requests), - Arc::clone(&self.post_endpoint) - ), - Self::handle_outgoing_messages( - self.receiver, - self.http_client.clone(), - Arc::clone(&self.post_endpoint), - Arc::clone(&self.pending_requests), - ) - ); - } - - /// Continuously reads SSE events from `sse_url`. - /// - If an `endpoint` event is received, store it in `post_endpoint`. - /// - If a `message` event is received, parse it as `JsonRpcMessage` - /// and respond to pending requests if it's a `Response`. - async fn handle_incoming_messages( - sse_url: String, - pending_requests: Arc, - post_endpoint: Arc>>, - ) { - let client = match eventsource_client::ClientBuilder::for_url(&sse_url) { - Ok(builder) => builder.build(), - Err(e) => { - pending_requests.clear().await; - warn!("Failed to connect SSE client: {}", e); - return; - } - }; - let mut stream = client.stream(); - - // First, wait for the "endpoint" event - while let Ok(Some(event)) = stream.try_next().await { - match event { - SSE::Event(e) if e.event_type == "endpoint" => { - // SSE server uses the "endpoint" event to tell us the POST URL - let base_url = Url::parse(&sse_url).expect("Invalid base URL"); - let post_url = base_url - .join(&e.data) - .expect("Failed to resolve endpoint URL"); - - tracing::debug!("Discovered SSE POST endpoint: {}", post_url); - *post_endpoint.write().await = Some(post_url.to_string()); - break; - } - _ => continue, - } - } - - // Now handle subsequent events - while let Ok(Some(event)) = stream.try_next().await { - match event { - SSE::Event(e) if e.event_type == "message" => { - // Attempt to parse the SSE data as a JsonRpcMessage - match serde_json::from_str::(&e.data) { - Ok(message) => { - match &message { - JsonRpcMessage::Response(response) => { - if let Some(id) = &response.id { - pending_requests - .respond(&id.to_string(), Ok(message)) - .await; - } - } - JsonRpcMessage::Error(error) => { - if let Some(id) = &error.id { - pending_requests - .respond(&id.to_string(), Ok(message)) - .await; - } - } - _ => {} // TODO: Handle other variants (Request, etc.) - } - } - Err(err) => { - warn!("Failed to parse SSE message: {err}"); - } - } - } - _ => { /* ignore other events */ } - } - } - - // SSE stream ended or errored; signal any pending requests - tracing::error!("SSE stream ended or encountered an error; clearing pending requests."); - pending_requests.clear().await; - } - - /// Continuously receives messages from the `mpsc::Receiver`. - /// - If it's a request, store the oneshot in `pending_requests`. - /// - POST the message to the discovered endpoint (once known). - async fn handle_outgoing_messages( - mut receiver: mpsc::Receiver, - http_client: HttpClient, - post_endpoint: Arc>>, - pending_requests: Arc, - ) { - while let Some(transport_msg) = receiver.recv().await { - let post_url = match post_endpoint.read().await.as_ref() { - Some(url) => url.clone(), - None => { - if let Some(response_tx) = transport_msg.response_tx { - let _ = response_tx.send(Err(Error::NotConnected)); - } - continue; - } - }; - - // Serialize the JSON-RPC message - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - if let Some(tx) = transport_msg.response_tx { - let _ = tx.send(Err(Error::Serialization(e))); - } - continue; - } - }; - - // If it's a request, store the channel so we can respond later - if let Some(response_tx) = transport_msg.response_tx { - if let JsonRpcMessage::Request(JsonRpcRequest { id: Some(id), .. }) = - &transport_msg.message - { - pending_requests.insert(id.to_string(), response_tx).await; - } - } - - // Perform the HTTP POST - match http_client - .post(&post_url) - .header("Content-Type", "application/json") - .body(message_str) - .send() - .await - { - Ok(resp) => { - if !resp.status().is_success() { - let err = Error::HttpError { - status: resp.status().as_u16(), - message: resp.status().to_string(), - }; - warn!("HTTP request returned error: {err}"); - // This doesn't directly fail the request, - // because we rely on SSE to deliver the error response - } - } - Err(e) => { - warn!("HTTP POST failed: {e}"); - // Similarly, SSE might eventually reveal the error - } - } - } - - // mpsc channel closed => no more outgoing messages - tracing::error!("SseActor: outgoing message loop ended. Clearing pending requests."); - pending_requests.clear().await; - } -} - -#[derive(Clone)] -pub struct SseTransportHandle { - sender: mpsc::Sender, -} - -#[async_trait::async_trait] -impl TransportHandle for SseTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result { - send_message(&self.sender, message).await - } -} - -#[derive(Clone)] -pub struct SseTransport { - sse_url: String, - env: HashMap, -} - -/// The SSE transport spawns an `SseActor` on `start()`. -impl SseTransport { - pub fn new>(sse_url: S, env: HashMap) -> Self { - Self { - sse_url: sse_url.into(), - env, - } - } - - /// Waits for the endpoint to be set, up to 10 attempts. - async fn wait_for_endpoint( - post_endpoint: Arc>>, - ) -> Result { - // Check every 100ms for the endpoint, for up to 10 attempts - let check_interval = Duration::from_millis(100); - let mut attempts = 0; - let max_attempts = 10; - - while attempts < max_attempts { - if let Some(url) = post_endpoint.read().await.clone() { - return Ok(url); - } - tokio::time::sleep(check_interval).await; - attempts += 1; - } - Err(Error::SseConnection("No endpoint discovered".to_string())) - } -} - -#[async_trait] -impl Transport for SseTransport { - type Handle = SseTransportHandle; - - async fn start(&self) -> Result { - // Set environment variables - for (key, value) in &self.env { - std::env::set_var(key, value); - } - - // Create a channel for outgoing TransportMessages - let (tx, rx) = mpsc::channel(32); - - let post_endpoint: Arc>> = Arc::new(RwLock::new(None)); - let post_endpoint_clone = Arc::clone(&post_endpoint); - - // Build the actor - let actor = SseActor::new( - rx, - Arc::new(PendingRequests::new()), - self.sse_url.clone(), - post_endpoint, - ); - - // Spawn the actor task - tokio::spawn(actor.run()); - - // Wait for the endpoint to be discovered before returning the handle - match timeout( - Duration::from_secs(ENDPOINT_TIMEOUT_SECS), - Self::wait_for_endpoint(post_endpoint_clone), - ) - .await - { - Ok(_) => Ok(SseTransportHandle { sender: tx }), - Err(e) => Err(Error::SseConnection(e.to_string())), - } - } - - async fn close(&self) -> Result<(), Error> { - // For SSE, you might close the stream or send a shutdown signal to the actor. - // Here, we do nothing special. - Ok(()) - } -} diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs deleted file mode 100644 index 7980816b..00000000 --- a/crates/mcp-client/src/transport/stdio.rs +++ /dev/null @@ -1,278 +0,0 @@ -use std::collections::HashMap; -use std::sync::Arc; -use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command}; - -use async_trait::async_trait; -use mcp_core::protocol::JsonRpcMessage; -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; -use tokio::sync::{mpsc, Mutex}; - -use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage}; - -/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel. -/// -/// It uses channels for message passing and handles responses asynchronously through a background task. -pub struct StdioActor { - receiver: mpsc::Receiver, - pending_requests: Arc, - _process: Child, // we store the process to keep it alive - error_sender: mpsc::Sender, - stdin: ChildStdin, - stdout: ChildStdout, - stderr: ChildStderr, -} - -impl StdioActor { - pub async fn run(mut self) { - use tokio::pin; - - let incoming = Self::handle_incoming_messages(self.stdout, self.pending_requests.clone()); - let outgoing = Self::handle_outgoing_messages( - self.receiver, - self.stdin, - self.pending_requests.clone(), - ); - - // take ownership of futures for tokio::select - pin!(incoming); - pin!(outgoing); - - // Use select! to wait for either I/O completion or process exit - tokio::select! { - result = &mut incoming => { - tracing::debug!("Stdin handler completed: {:?}", result); - } - result = &mut outgoing => { - tracing::debug!("Stdout handler completed: {:?}", result); - } - // capture the status so we don't need to wait for a timeout - status = self._process.wait() => { - tracing::debug!("Process exited with status: {:?}", status); - } - } - - // Then always try to read stderr before cleaning up - let mut stderr_buffer = Vec::new(); - if let Ok(bytes) = self.stderr.read_to_end(&mut stderr_buffer).await { - let err_msg = if bytes > 0 { - String::from_utf8_lossy(&stderr_buffer).to_string() - } else { - "Process ended unexpectedly".to_string() - }; - - tracing::info!("Process stderr: {}", err_msg); - let _ = self - .error_sender - .send(Error::StdioProcessError(err_msg)) - .await; - } - - // Clean up regardless of which path we took - self.pending_requests.clear().await; - } - - async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc) { - let mut reader = BufReader::new(stdout); - let mut line = String::new(); - loop { - match reader.read_line(&mut line).await { - Ok(0) => { - tracing::error!("Child process ended (EOF on stdout)"); - break; - } // EOF - Ok(_) => { - if let Ok(message) = serde_json::from_str::(&line) { - tracing::debug!( - message = ?message, - "Received incoming message" - ); - - match &message { - JsonRpcMessage::Response(response) => { - if let Some(id) = &response.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; - } - } - JsonRpcMessage::Error(error) => { - if let Some(id) = &error.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; - } - } - _ => {} // TODO: Handle other variants (Request, etc.) - } - } - line.clear(); - } - Err(e) => { - tracing::error!(error = ?e, "Error reading line"); - break; - } - } - } - } - - async fn handle_outgoing_messages( - mut receiver: mpsc::Receiver, - mut stdin: ChildStdin, - pending_requests: Arc, - ) { - while let Some(mut transport_msg) = receiver.recv().await { - let message_str = match serde_json::to_string(&transport_msg.message) { - Ok(s) => s, - Err(e) => { - if let Some(tx) = transport_msg.response_tx.take() { - let _ = tx.send(Err(Error::Serialization(e))); - } - continue; - } - }; - - tracing::debug!(message = ?transport_msg.message, "Sending outgoing message"); - - if let Some(response_tx) = transport_msg.response_tx.take() { - if let JsonRpcMessage::Request(request) = &transport_msg.message { - if let Some(id) = &request.id { - pending_requests.insert(id.to_string(), response_tx).await; - } - } - } - - if let Err(e) = stdin - .write_all(format!("{}\n", message_str).as_bytes()) - .await - { - tracing::error!(error = ?e, "Error writing message to child process"); - pending_requests.clear().await; - break; - } - - if let Err(e) = stdin.flush().await { - tracing::error!(error = ?e, "Error flushing message to child process"); - pending_requests.clear().await; - break; - } - } - } -} - -#[derive(Clone)] -pub struct StdioTransportHandle { - sender: mpsc::Sender, - error_receiver: Arc>>, -} - -#[async_trait::async_trait] -impl TransportHandle for StdioTransportHandle { - async fn send(&self, message: JsonRpcMessage) -> Result { - let result = send_message(&self.sender, message).await; - // Check for any pending errors even if send is successful - self.check_for_errors().await?; - result - } -} - -impl StdioTransportHandle { - /// Check if there are any process errors - pub async fn check_for_errors(&self) -> Result<(), Error> { - match self.error_receiver.lock().await.try_recv() { - Ok(error) => { - tracing::debug!("Found error: {:?}", error); - Err(error) - } - Err(_) => Ok(()), - } - } -} - -pub struct StdioTransport { - command: String, - args: Vec, - env: HashMap, -} - -impl StdioTransport { - pub fn new>( - command: S, - args: Vec, - env: HashMap, - ) -> Self { - Self { - command: command.into(), - args, - env, - } - } - - async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> { - let mut command = Command::new(&self.command); - command - .envs(&self.env) - .args(&self.args) - .stdin(std::process::Stdio::piped()) - .stdout(std::process::Stdio::piped()) - .stderr(std::process::Stdio::piped()) - .kill_on_drop(true); - - // Set process group only on Unix systems - #[cfg(unix)] - command.process_group(0); // don't inherit signal handling from parent process - - // Hide console window on Windows - #[cfg(windows)] - command.creation_flags(0x08000000); // CREATE_NO_WINDOW flag - - let mut process = command - .spawn() - .map_err(|e| Error::StdioProcessError(e.to_string()))?; - - let stdin = process - .stdin - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stdin".into()))?; - - let stdout = process - .stdout - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stdout".into()))?; - - let stderr = process - .stderr - .take() - .ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?; - - Ok((process, stdin, stdout, stderr)) - } -} - -#[async_trait] -impl Transport for StdioTransport { - type Handle = StdioTransportHandle; - - async fn start(&self) -> Result { - let (process, stdin, stdout, stderr) = self.spawn_process().await?; - let (message_tx, message_rx) = mpsc::channel(32); - let (error_tx, error_rx) = mpsc::channel(1); - - let actor = StdioActor { - receiver: message_rx, - pending_requests: Arc::new(PendingRequests::new()), - _process: process, - error_sender: error_tx, - stdin, - stdout, - stderr, - }; - - tokio::spawn(actor.run()); - - let handle = StdioTransportHandle { - sender: message_tx, - error_receiver: Arc::new(Mutex::new(error_rx)), - }; - Ok(handle) - } - - async fn close(&self) -> Result<(), Error> { - Ok(()) - } -} diff --git a/crates/mcp-core/Cargo.toml b/crates/mcp-core/Cargo.toml deleted file mode 100644 index 4550598a..00000000 --- a/crates/mcp-core/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "mcp-core" -license.workspace = true -version.workspace = true -edition.workspace = true -repository.workspace = true -description = "Core types for Model Context Protocol" - - -[dependencies] -async-trait = "0.1" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -thiserror = "1.0" -schemars = "0.8" -anyhow = "1.0" -chrono = { version = "0.4.38", features = ["serde"] } -url = "2.5" -base64 = "0.21" - -[dev-dependencies] -tempfile = "3.8" diff --git a/crates/mcp-core/src/content.rs b/crates/mcp-core/src/content.rs deleted file mode 100644 index 01c44be4..00000000 --- a/crates/mcp-core/src/content.rs +++ /dev/null @@ -1,310 +0,0 @@ -/// Content sent around agents, extensions, and LLMs -/// The various content types can be display to humans but also understood by models -/// They include optional annotations used to help inform agent usage -use super::role::Role; -use crate::resource::ResourceContents; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Annotations { - #[serde(skip_serializing_if = "Option::is_none")] - pub audience: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub priority: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub timestamp: Option>, -} - -impl Annotations { - /// Creates a new Annotations instance specifically for resources - /// optional priority, and a timestamp (defaults to now if None) - pub fn for_resource(priority: f32, timestamp: DateTime) -> Self { - assert!( - (0.0..=1.0).contains(&priority), - "Priority {priority} must be between 0.0 and 1.0" - ); - Annotations { - priority: Some(priority), - timestamp: Some(timestamp), - audience: None, - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct TextContent { - pub text: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ImageContent { - pub data: String, - pub mime_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct EmbeddedResource { - pub resource: ResourceContents, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -impl EmbeddedResource { - pub fn get_text(&self) -> String { - match &self.resource { - ResourceContents::TextResourceContents { text, .. } => text.clone(), - _ => String::new(), - } - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(tag = "type", rename_all = "camelCase")] -pub enum Content { - Text(TextContent), - Image(ImageContent), - Resource(EmbeddedResource), -} - -impl Content { - pub fn text>(text: S) -> Self { - Content::Text(TextContent { - text: text.into(), - annotations: None, - }) - } - - pub fn image, T: Into>(data: S, mime_type: T) -> Self { - Content::Image(ImageContent { - data: data.into(), - mime_type: mime_type.into(), - annotations: None, - }) - } - - pub fn resource(resource: ResourceContents) -> Self { - Content::Resource(EmbeddedResource { - resource, - annotations: None, - }) - } - - pub fn embedded_text, T: Into>(uri: S, content: T) -> Self { - Content::Resource(EmbeddedResource { - resource: ResourceContents::TextResourceContents { - uri: uri.into(), - mime_type: Some("text".to_string()), - text: content.into(), - }, - annotations: None, - }) - } - - /// Get the text content if this is a TextContent variant - pub fn as_text(&self) -> Option<&str> { - match self { - Content::Text(text) => Some(&text.text), - _ => None, - } - } - - /// Get the image content if this is an ImageContent variant - pub fn as_image(&self) -> Option<(&str, &str)> { - match self { - Content::Image(image) => Some((&image.data, &image.mime_type)), - _ => None, - } - } - - /// Set the audience for the content - pub fn with_audience(mut self, audience: Vec) -> Self { - let annotations = match &mut self { - Content::Text(text) => &mut text.annotations, - Content::Image(image) => &mut image.annotations, - Content::Resource(resource) => &mut resource.annotations, - }; - *annotations = Some(match annotations.take() { - Some(mut a) => { - a.audience = Some(audience); - a - } - None => Annotations { - audience: Some(audience), - priority: None, - timestamp: None, - }, - }); - self - } - - /// Set the priority for the content - /// # Panics - /// Panics if priority is not between 0.0 and 1.0 inclusive - pub fn with_priority(mut self, priority: f32) -> Self { - if !(0.0..=1.0).contains(&priority) { - panic!("Priority must be between 0.0 and 1.0"); - } - let annotations = match &mut self { - Content::Text(text) => &mut text.annotations, - Content::Image(image) => &mut image.annotations, - Content::Resource(resource) => &mut resource.annotations, - }; - *annotations = Some(match annotations.take() { - Some(mut a) => { - a.priority = Some(priority); - a - } - None => Annotations { - audience: None, - priority: Some(priority), - timestamp: None, - }, - }); - self - } - - /// Get the audience if set - pub fn audience(&self) -> Option<&Vec> { - match self { - Content::Text(text) => text.annotations.as_ref().and_then(|a| a.audience.as_ref()), - Content::Image(image) => image.annotations.as_ref().and_then(|a| a.audience.as_ref()), - Content::Resource(resource) => resource - .annotations - .as_ref() - .and_then(|a| a.audience.as_ref()), - } - } - - /// Get the priority if set - pub fn priority(&self) -> Option { - match self { - Content::Text(text) => text.annotations.as_ref().and_then(|a| a.priority), - Content::Image(image) => image.annotations.as_ref().and_then(|a| a.priority), - Content::Resource(resource) => resource.annotations.as_ref().and_then(|a| a.priority), - } - } - - pub fn unannotated(&self) -> Self { - match self { - Content::Text(text) => Content::text(text.text.clone()), - Content::Image(image) => Content::image(image.data.clone(), image.mime_type.clone()), - Content::Resource(resource) => Content::resource(resource.resource.clone()), - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_content_text() { - let content = Content::text("hello"); - assert_eq!(content.as_text(), Some("hello")); - assert_eq!(content.as_image(), None); - } - - #[test] - fn test_content_image() { - let content = Content::image("data", "image/png"); - assert_eq!(content.as_text(), None); - assert_eq!(content.as_image(), Some(("data", "image/png"))); - } - - #[test] - fn test_content_annotations_basic() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), Some(0.5)); - } - - #[test] - fn test_content_annotations_order_independence() { - let content1 = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - let content2 = Content::text("hello") - .with_priority(0.5) - .with_audience(vec![Role::User]); - - assert_eq!(content1.audience(), content2.audience()); - assert_eq!(content1.priority(), content2.priority()); - } - - #[test] - fn test_content_annotations_overwrite() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5) - .with_audience(vec![Role::Assistant]) - .with_priority(0.8); - - assert_eq!(content.audience(), Some(&vec![Role::Assistant])); - assert_eq!(content.priority(), Some(0.8)); - } - - #[test] - fn test_content_annotations_image() { - let content = Content::image("data", "image/png") - .with_audience(vec![Role::User]) - .with_priority(0.5); - - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), Some(0.5)); - } - - #[test] - fn test_content_annotations_preservation() { - let text_content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - - match &text_content { - Content::Text(TextContent { annotations, .. }) => { - assert!(annotations.is_some()); - let ann = annotations.as_ref().unwrap(); - assert_eq!(ann.audience, Some(vec![Role::User])); - assert_eq!(ann.priority, Some(0.5)); - } - _ => panic!("Expected Text content"), - } - } - - #[test] - #[should_panic(expected = "Priority must be between 0.0 and 1.0")] - fn test_invalid_priority() { - Content::text("hello").with_priority(1.5); - } - - #[test] - fn test_unannotated() { - let content = Content::text("hello") - .with_audience(vec![Role::User]) - .with_priority(0.5); - let unannotated = content.unannotated(); - assert_eq!(unannotated.audience(), None); - assert_eq!(unannotated.priority(), None); - } - - #[test] - fn test_partial_annotations() { - let content = Content::text("hello").with_priority(0.5); - assert_eq!(content.audience(), None); - assert_eq!(content.priority(), Some(0.5)); - - let content = Content::text("hello").with_audience(vec![Role::User]); - assert_eq!(content.audience(), Some(&vec![Role::User])); - assert_eq!(content.priority(), None); - } -} diff --git a/crates/mcp-core/src/handler.rs b/crates/mcp-core/src/handler.rs deleted file mode 100644 index 338fe94e..00000000 --- a/crates/mcp-core/src/handler.rs +++ /dev/null @@ -1,73 +0,0 @@ -use async_trait::async_trait; -use schemars::JsonSchema; -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use thiserror::Error; - -#[non_exhaustive] -#[derive(Error, Debug, Clone, Deserialize, Serialize, PartialEq)] -pub enum ToolError { - #[error("Invalid parameters: {0}")] - InvalidParameters(String), - #[error("Execution failed: {0}")] - ExecutionError(String), - #[error("Schema error: {0}")] - SchemaError(String), - #[error("Tool not found: {0}")] - NotFound(String), -} - -pub type ToolResult = std::result::Result; - -#[derive(Error, Debug)] -pub enum ResourceError { - #[error("Execution failed: {0}")] - ExecutionError(String), - #[error("Resource not found: {0}")] - NotFound(String), -} - -#[derive(Error, Debug)] -pub enum PromptError { - #[error("Invalid parameters: {0}")] - InvalidParameters(String), - #[error("Internal error: {0}")] - InternalError(String), - #[error("Prompt not found: {0}")] - NotFound(String), -} - -/// Trait for implementing MCP tools -#[async_trait] -pub trait ToolHandler: Send + Sync + 'static { - /// The name of the tool - fn name(&self) -> &'static str; - - /// A description of what the tool does - fn description(&self) -> &'static str; - - /// JSON schema describing the tool's parameters - fn schema(&self) -> Value; - - /// Execute the tool with the given parameters - async fn call(&self, params: Value) -> ToolResult; -} - -/// Trait for implementing MCP resources -#[async_trait] -pub trait ResourceTemplateHandler: Send + Sync + 'static { - /// The URL template for this resource - fn template() -> &'static str; - - /// JSON schema describing the resource parameters - fn schema() -> Value; - - /// Get the resource value - async fn get(&self, params: Value) -> ToolResult; -} - -/// Helper function to generate JSON schema for a type -pub fn generate_schema() -> ToolResult { - let schema = schemars::schema_for!(T); - serde_json::to_value(schema).map_err(|e| ToolError::SchemaError(e.to_string())) -} diff --git a/crates/mcp-core/src/lib.rs b/crates/mcp-core/src/lib.rs deleted file mode 100644 index 5a37ceea..00000000 --- a/crates/mcp-core/src/lib.rs +++ /dev/null @@ -1,12 +0,0 @@ -pub mod content; -pub use content::{Annotations, Content, ImageContent, TextContent}; -pub mod handler; -pub mod role; -pub use role::Role; -pub mod tool; -pub use tool::{Tool, ToolCall}; -pub mod resource; -pub use resource::{Resource, ResourceContents}; -pub mod protocol; -pub use handler::{ToolError, ToolResult}; -pub mod prompt; diff --git a/crates/mcp-core/src/protocol.rs b/crates/mcp-core/src/protocol.rs deleted file mode 100644 index 202d514d..00000000 --- a/crates/mcp-core/src/protocol.rs +++ /dev/null @@ -1,289 +0,0 @@ -/// The protocol messages exchanged between client and server -use crate::{ - content::Content, - prompt::{Prompt, PromptMessage}, - resource::Resource, - resource::ResourceContents, - tool::Tool, -}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcRequest { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcResponse { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub error: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcNotification { - pub jsonrpc: String, - pub method: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub params: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct JsonRpcError { - pub jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub id: Option, - pub error: ErrorData, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(untagged, try_from = "JsonRpcRaw")] -pub enum JsonRpcMessage { - Request(JsonRpcRequest), - Response(JsonRpcResponse), - Notification(JsonRpcNotification), - Error(JsonRpcError), - Nil, // used to respond to notifications -} - -#[derive(Debug, Serialize, Deserialize)] -struct JsonRpcRaw { - jsonrpc: String, - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - method: Option, - #[serde(skip_serializing_if = "Option::is_none")] - params: Option, - #[serde(skip_serializing_if = "Option::is_none")] - result: Option, - #[serde(skip_serializing_if = "Option::is_none")] - error: Option, -} - -impl TryFrom for JsonRpcMessage { - type Error = String; - - fn try_from(raw: JsonRpcRaw) -> Result>::Error> { - // If it has an error field, it's an error response - if raw.error.is_some() { - return Ok(JsonRpcMessage::Error(JsonRpcError { - jsonrpc: raw.jsonrpc, - id: raw.id, - error: raw.error.unwrap(), - })); - } - - // If it has a result field, it's a response - if raw.result.is_some() { - return Ok(JsonRpcMessage::Response(JsonRpcResponse { - jsonrpc: raw.jsonrpc, - id: raw.id, - result: raw.result, - error: None, - })); - } - - // If we have a method, it's either a notification or request - if let Some(method) = raw.method { - if raw.id.is_none() { - return Ok(JsonRpcMessage::Notification(JsonRpcNotification { - jsonrpc: raw.jsonrpc, - method, - params: raw.params, - })); - } - - return Ok(JsonRpcMessage::Request(JsonRpcRequest { - jsonrpc: raw.jsonrpc, - id: raw.id, - method, - params: raw.params, - })); - } - - // If we have no method and no result/error, it's a nil response - if raw.id.is_none() && raw.result.is_none() && raw.error.is_none() { - return Ok(JsonRpcMessage::Nil); - } - - // If we get here, something is wrong with the message - Err(format!( - "Invalid JSON-RPC message format: id={:?}, method={:?}, result={:?}, error={:?}", - raw.id, raw.method, raw.result, raw.error - )) - } -} - -// Standard JSON-RPC error codes -pub const PARSE_ERROR: i32 = -32700; -pub const INVALID_REQUEST: i32 = -32600; -pub const METHOD_NOT_FOUND: i32 = -32601; -pub const INVALID_PARAMS: i32 = -32602; -pub const INTERNAL_ERROR: i32 = -32603; - -/// Error information for JSON-RPC error responses. -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ErrorData { - /// The error type that occurred. - pub code: i32, - - /// A short description of the error. The message SHOULD be limited to a concise single sentence. - pub message: String, - - /// Additional information about the error. The value of this member is defined by the - /// sender (e.g. detailed error information, nested errors etc.). - #[serde(skip_serializing_if = "Option::is_none")] - pub data: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct InitializeResult { - pub protocol_version: String, - pub capabilities: ServerCapabilities, - pub server_info: Implementation, - #[serde(skip_serializing_if = "Option::is_none")] - pub instructions: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct Implementation { - pub name: String, - pub version: String, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ServerCapabilities { - #[serde(skip_serializing_if = "Option::is_none")] - pub prompts: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub resources: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub tools: Option, - // Add other capabilities as needed -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct PromptsCapability { - pub list_changed: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct ResourcesCapability { - pub subscribe: Option, - pub list_changed: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct ToolsCapability { - pub list_changed: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct ListResourcesResult { - pub resources: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ReadResourceResult { - pub contents: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct ListToolsResult { - pub tools: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub next_cursor: Option, -} - -#[derive(Debug, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CallToolResult { - pub content: Vec, - #[serde(skip_serializing_if = "Option::is_none")] - pub is_error: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct ListPromptsResult { - pub prompts: Vec, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub struct GetPromptResult { - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - pub messages: Vec, -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct EmptyResult {} - -#[cfg(test)] -mod tests { - use super::*; - use serde_json::json; - - #[test] - fn test_notification_conversion() { - let raw = JsonRpcRaw { - jsonrpc: "2.0".to_string(), - id: None, - method: Some("notify".to_string()), - params: Some(json!({"key": "value"})), - result: None, - error: None, - }; - - let message = JsonRpcMessage::try_from(raw).unwrap(); - match message { - JsonRpcMessage::Notification(n) => { - assert_eq!(n.jsonrpc, "2.0"); - assert_eq!(n.method, "notify"); - assert_eq!(n.params.unwrap(), json!({"key": "value"})); - } - _ => panic!("Expected Notification"), - } - } - - #[test] - fn test_request_conversion() { - let raw = JsonRpcRaw { - jsonrpc: "2.0".to_string(), - id: Some(1), - method: Some("request".to_string()), - params: Some(json!({"key": "value"})), - result: None, - error: None, - }; - - let message = JsonRpcMessage::try_from(raw).unwrap(); - match message { - JsonRpcMessage::Request(r) => { - assert_eq!(r.jsonrpc, "2.0"); - assert_eq!(r.id, Some(1)); - assert_eq!(r.method, "request"); - assert_eq!(r.params.unwrap(), json!({"key": "value"})); - } - _ => panic!("Expected Request"), - } - } -} diff --git a/crates/mcp-core/src/resource.rs b/crates/mcp-core/src/resource.rs deleted file mode 100644 index a81155c8..00000000 --- a/crates/mcp-core/src/resource.rs +++ /dev/null @@ -1,260 +0,0 @@ -/// Resources that servers provide to clients -use anyhow::{anyhow, Result}; -use chrono::{DateTime, Utc}; -use serde::{Deserialize, Serialize}; -use url::Url; - -use crate::content::Annotations; - -const EPSILON: f32 = 1e-6; // Tolerance for floating point comparison - -/// Represents a resource in the extension with metadata -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase")] -pub struct Resource { - /// URI representing the resource location (e.g., "file:///path/to/file" or "str:///content") - pub uri: String, - /// Name of the resource - pub name: String, - /// Optional description of the resource - #[serde(skip_serializing_if = "Option::is_none")] - pub description: Option, - /// MIME type of the resource content ("text" or "blob") - #[serde(default = "default_mime_type")] - pub mime_type: String, - #[serde(skip_serializing_if = "Option::is_none")] - pub annotations: Option, -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -#[serde(rename_all = "camelCase", untagged)] -pub enum ResourceContents { - TextResourceContents { - uri: String, - #[serde(skip_serializing_if = "Option::is_none")] - mime_type: Option, - text: String, - }, - BlobResourceContents { - uri: String, - #[serde(skip_serializing_if = "Option::is_none")] - mime_type: Option, - blob: String, - }, -} - -fn default_mime_type() -> String { - "text".to_string() -} - -impl Resource { - /// Creates a new Resource from a URI with explicit mime type - pub fn new>( - uri: S, - mime_type: Option, - name: Option, - ) -> Result { - let uri = uri.as_ref(); - let url = Url::parse(uri).map_err(|e| anyhow!("Invalid URI: {}", e))?; - - // Extract name from the path component of the URI - // Use provided name if available, otherwise extract from URI - let name = match name { - Some(n) => n, - None => url - .path_segments() - .and_then(|segments| segments.last()) - .unwrap_or("unnamed") - .to_string(), - }; - - // Use provided mime_type or default - let mime_type = match mime_type { - Some(t) if t == "text" || t == "blob" => t, - _ => default_mime_type(), - }; - - Ok(Self { - uri: uri.to_string(), - name, - description: None, - mime_type, - annotations: Some(Annotations::for_resource(0.0, Utc::now())), - }) - } - - /// Creates a new Resource with explicit URI, name, and priority - pub fn with_uri>( - uri: S, - name: S, - priority: f32, - mime_type: Option, - ) -> Result { - let uri_string = uri.into(); - Url::parse(&uri_string).map_err(|e| anyhow!("Invalid URI: {}", e))?; - - // Use provided mime_type or default - let mime_type = match mime_type { - Some(t) if t == "text" || t == "blob" => t, - _ => default_mime_type(), - }; - - Ok(Self { - uri: uri_string, - name: name.into(), - description: None, - mime_type, - annotations: Some(Annotations::for_resource(priority, Utc::now())), - }) - } - - /// Updates the resource's timestamp to the current time - pub fn update_timestamp(&mut self) { - self.annotations.as_mut().unwrap().timestamp = Some(Utc::now()); - } - - /// Sets the priority of the resource and returns self for method chaining - pub fn with_priority(mut self, priority: f32) -> Self { - self.annotations.as_mut().unwrap().priority = Some(priority); - self - } - - /// Mark the resource as active, i.e. set its priority to 1.0 - pub fn mark_active(self) -> Self { - self.with_priority(1.0) - } - - // Check if the resource is active - pub fn is_active(&self) -> bool { - if let Some(priority) = self.priority() { - (priority - 1.0).abs() < EPSILON - } else { - false - } - } - - /// Returns the priority of the resource, if set - pub fn priority(&self) -> Option { - self.annotations.as_ref().and_then(|a| a.priority) - } - - /// Returns the timestamp of the resource, if set - pub fn timestamp(&self) -> Option> { - self.annotations.as_ref().and_then(|a| a.timestamp) - } - - /// Returns the scheme of the URI - pub fn scheme(&self) -> Result { - let url = Url::parse(&self.uri)?; - Ok(url.scheme().to_string()) - } - - /// Sets the description of the resource - pub fn with_description>(mut self, description: S) -> Self { - self.description = Some(description.into()); - self - } - - /// Sets the MIME type of the resource - pub fn with_mime_type>(mut self, mime_type: S) -> Self { - let mime_type = mime_type.into(); - match mime_type.as_str() { - "text" | "blob" => self.mime_type = mime_type, - _ => self.mime_type = default_mime_type(), - } - self - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::io::Write; - use tempfile::NamedTempFile; - - #[test] - fn test_new_resource_with_file_uri() -> Result<()> { - let mut temp_file = NamedTempFile::new()?; - writeln!(temp_file, "test content")?; - - let uri = Url::from_file_path(temp_file.path()) - .map_err(|_| anyhow!("Invalid file path"))? - .to_string(); - - let resource = Resource::new(&uri, Some("text".to_string()), None)?; - assert!(resource.uri.starts_with("file:///")); - assert_eq!(resource.priority(), Some(0.0)); - assert_eq!(resource.mime_type, "text"); - assert_eq!(resource.scheme()?, "file"); - - Ok(()) - } - - #[test] - fn test_resource_with_str_uri() -> Result<()> { - let test_content = "Hello, world!"; - let uri = format!("str:///{}", test_content); - let resource = Resource::with_uri( - uri.clone(), - "test.txt".to_string(), - 0.5, - Some("text".to_string()), - )?; - - assert_eq!(resource.uri, uri); - assert_eq!(resource.name, "test.txt"); - assert_eq!(resource.priority(), Some(0.5)); - assert_eq!(resource.mime_type, "text"); - assert_eq!(resource.scheme()?, "str"); - - Ok(()) - } - - #[test] - fn test_mime_type_validation() -> Result<()> { - // Test valid mime types - let resource = Resource::new("file:///test.txt", Some("text".to_string()), None)?; - assert_eq!(resource.mime_type, "text"); - - let resource = Resource::new("file:///test.bin", Some("blob".to_string()), None)?; - assert_eq!(resource.mime_type, "blob"); - - // Test invalid mime type defaults to "text" - let resource = Resource::new("file:///test.txt", Some("invalid".to_string()), None)?; - assert_eq!(resource.mime_type, "text"); - - // Test None defaults to "text" - let resource = Resource::new("file:///test.txt", None, None)?; - assert_eq!(resource.mime_type, "text"); - - Ok(()) - } - - #[test] - fn test_with_description() -> Result<()> { - let resource = Resource::with_uri("file:///test.txt", "test.txt", 0.0, None)? - .with_description("A test resource"); - - assert_eq!(resource.description, Some("A test resource".to_string())); - Ok(()) - } - - #[test] - fn test_with_mime_type() -> Result<()> { - let resource = - Resource::with_uri("file:///test.txt", "test.txt", 0.0, None)?.with_mime_type("blob"); - - assert_eq!(resource.mime_type, "blob"); - - // Test invalid mime type defaults to "text" - let resource = resource.with_mime_type("invalid"); - assert_eq!(resource.mime_type, "text"); - Ok(()) - } - - #[test] - fn test_invalid_uri() { - let result = Resource::new("not-a-uri", None, None); - assert!(result.is_err()); - } -} diff --git a/crates/mcp-core/src/role.rs b/crates/mcp-core/src/role.rs deleted file mode 100644 index 38f3a872..00000000 --- a/crates/mcp-core/src/role.rs +++ /dev/null @@ -1,9 +0,0 @@ -/// Roles to describe the origin/ownership of content -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, -} diff --git a/crates/mcp-core/src/tool.rs b/crates/mcp-core/src/tool.rs deleted file mode 100644 index 3f6f4246..00000000 --- a/crates/mcp-core/src/tool.rs +++ /dev/null @@ -1,51 +0,0 @@ -/// Tools represent a routine that a server can execute -/// Tool calls represent requests from the client to execute one -use serde::{Deserialize, Serialize}; -use serde_json::Value; - -/// A tool that can be used by a model. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct Tool { - /// The name of the tool - pub name: String, - /// A description of what the tool does - pub description: String, - /// A JSON Schema object defining the expected parameters for the tool - pub input_schema: Value, -} - -impl Tool { - /// Create a new tool with the given name and description - pub fn new(name: N, description: D, input_schema: Value) -> Self - where - N: Into, - D: Into, - { - Tool { - name: name.into(), - description: description.into(), - input_schema, - } - } -} - -/// A tool call request that an extension can execute -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct ToolCall { - /// The name of the tool to execute - pub name: String, - /// The parameters for the execution - pub arguments: Value, -} - -impl ToolCall { - /// Create a new ToolUse with the given name and parameters - pub fn new>(name: S, arguments: Value) -> Self { - Self { - name: name.into(), - arguments, - } - } -} diff --git a/crates/mcp-macros/Cargo.toml b/crates/mcp-macros/Cargo.toml deleted file mode 100644 index dc5d8fd3..00000000 --- a/crates/mcp-macros/Cargo.toml +++ /dev/null @@ -1,26 +0,0 @@ -[package] -name = "mcp-macros" -version.workspace = true -edition.workspace = true -license.workspace = true -repository.workspace = true - -[lib] -proc-macro = true - -[dependencies] -syn = { version = "2.0", features = ["full", "extra-traits"] } -quote = "1.0" -proc-macro2 = "1.0" -serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" -mcp-core = { path = "../mcp-core" } -async-trait = "0.1" -schemars = "0.8" -convert_case = "0.6.0" - -[dev-dependencies] -tokio = { version = "1.0", features = ["full"] } -async-trait = "0.1" -serde_json = "1.0" -schemars = "0.8" diff --git a/crates/mcp-macros/src/lib.rs b/crates/mcp-macros/src/lib.rs deleted file mode 100644 index d918d07c..00000000 --- a/crates/mcp-macros/src/lib.rs +++ /dev/null @@ -1,152 +0,0 @@ -use convert_case::{Case, Casing}; -use proc_macro::TokenStream; -use quote::{format_ident, quote}; -use std::collections::HashMap; -use syn::{ - parse::Parse, parse::ParseStream, parse_macro_input, punctuated::Punctuated, Expr, ExprLit, - FnArg, ItemFn, Lit, Meta, Pat, PatType, Token, -}; - -struct MacroArgs { - name: Option, - description: Option, - param_descriptions: HashMap, -} - -impl Parse for MacroArgs { - fn parse(input: ParseStream) -> syn::Result { - let mut name = None; - let mut description = None; - let mut param_descriptions = HashMap::new(); - - let meta_list: Punctuated = Punctuated::parse_terminated(input)?; - - for meta in meta_list { - match meta { - Meta::NameValue(nv) => { - let ident = nv.path.get_ident().unwrap().to_string(); - if let Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) = nv.value - { - match ident.as_str() { - "name" => name = Some(lit_str.value()), - "description" => description = Some(lit_str.value()), - _ => {} - } - } - } - Meta::List(list) if list.path.is_ident("params") => { - let nested: Punctuated = - list.parse_args_with(Punctuated::parse_terminated)?; - - for meta in nested { - if let Meta::NameValue(nv) = meta { - if let Expr::Lit(ExprLit { - lit: Lit::Str(lit_str), - .. - }) = nv.value - { - let param_name = nv.path.get_ident().unwrap().to_string(); - param_descriptions.insert(param_name, lit_str.value()); - } - } - } - } - _ => {} - } - } - - Ok(MacroArgs { - name, - description, - param_descriptions, - }) - } -} - -#[proc_macro_attribute] -pub fn tool(args: TokenStream, input: TokenStream) -> TokenStream { - let args = parse_macro_input!(args as MacroArgs); - let input_fn = parse_macro_input!(input as ItemFn); - - // Extract function details - let fn_name = &input_fn.sig.ident; - let fn_name_str = fn_name.to_string(); - - // Generate PascalCase struct name from the function name - let struct_name = format_ident!("{}", { fn_name_str.to_case(Case::Pascal) }); - - // Use provided name or function name as default - let tool_name = args.name.unwrap_or(fn_name_str); - let tool_description = args.description.unwrap_or_default(); - - // Extract parameter names, types, and descriptions - let mut param_defs = Vec::new(); - let mut param_names = Vec::new(); - - for arg in input_fn.sig.inputs.iter() { - if let FnArg::Typed(PatType { pat, ty, .. }) = arg { - if let Pat::Ident(param_ident) = &**pat { - let param_name = ¶m_ident.ident; - let param_name_str = param_name.to_string(); - let description = args - .param_descriptions - .get(¶m_name_str) - .map(|s| s.as_str()) - .unwrap_or(""); - - param_names.push(param_name); - param_defs.push(quote! { - #[schemars(description = #description)] - #param_name: #ty - }); - } - } - } - - // Generate the implementation - let params_struct_name = format_ident!("{}Parameters", struct_name); - let expanded = quote! { - #[derive(serde::Deserialize, schemars::JsonSchema)] - struct #params_struct_name { - #(#param_defs,)* - } - - #input_fn - - #[derive(Default)] - struct #struct_name; - - #[async_trait::async_trait] - impl mcp_core::handler::ToolHandler for #struct_name { - fn name(&self) -> &'static str { - #tool_name - } - - fn description(&self) -> &'static str { - #tool_description - } - - fn schema(&self) -> serde_json::Value { - mcp_core::handler::generate_schema::<#params_struct_name>() - .expect("Failed to generate schema") - } - - async fn call(&self, params: serde_json::Value) -> Result { - let params: #params_struct_name = serde_json::from_value(params) - .map_err(|e| mcp_core::handler::ToolError::InvalidParameters(e.to_string()))?; - - // Extract parameters and call the function - let result = #fn_name(#(params.#param_names,)*).await - .map_err(|e| mcp_core::handler::ToolError::ExecutionError(e.to_string()))?; - - Ok(serde_json::to_value(result).expect("should serialize")) - - } - } - }; - - TokenStream::from(expanded) -} diff --git a/crates/mcp-server/Cargo.toml b/crates/mcp-server/Cargo.toml deleted file mode 100644 index 6eb5703c..00000000 --- a/crates/mcp-server/Cargo.toml +++ /dev/null @@ -1,25 +0,0 @@ -[package] -name = "mcp-server" -license.workspace = true -version.workspace = true -edition.workspace = true -repository.workspace = true -description = "Server SDK for the Model Context Protocol" - -[dependencies] -anyhow = "1.0.94" -thiserror = "1.0" -mcp-core = { workspace = true } -mcp-macros = { workspace = true } -serde = { version = "1.0.216", features = ["derive"] } -serde_json = "1.0.133" -schemars = "0.8" -tokio = { version = "1", features = ["io-util"] } -tower = { version = "0.4", features = ["timeout"] } -tower-service = "0.3" -futures = "0.3" -pin-project = "1.1" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -tracing-appender = "0.2" -async-trait = "0.1" diff --git a/crates/mcp-server/README.md b/crates/mcp-server/README.md deleted file mode 100644 index 1e4f0617..00000000 --- a/crates/mcp-server/README.md +++ /dev/null @@ -1,7 +0,0 @@ -### Test with MCP Inspector - -```bash -npx @modelcontextprotocol/inspector cargo run -p mcp-server -``` - -Then visit the Inspector in the browser window and test the different endpoints. \ No newline at end of file diff --git a/crates/mcp-server/src/errors.rs b/crates/mcp-server/src/errors.rs deleted file mode 100644 index 7ebe2534..00000000 --- a/crates/mcp-server/src/errors.rs +++ /dev/null @@ -1,104 +0,0 @@ -use thiserror::Error; - -pub type BoxError = Box; - -#[derive(Error, Debug)] -pub enum TransportError { - #[error("IO error: {0}")] - Io(#[from] std::io::Error), - - #[error("JSON serialization error: {0}")] - Json(#[from] serde_json::Error), - - #[error("Invalid UTF-8 sequence: {0}")] - Utf8(#[from] std::string::FromUtf8Error), - - #[error("Protocol error: {0}")] - Protocol(String), - - #[error("Invalid message format: {0}")] - InvalidMessage(String), -} - -#[derive(Error, Debug)] -pub enum ServerError { - #[error("Transport error: {0}")] - Transport(#[from] TransportError), - - #[error("Service error: {0}")] - Service(String), - - #[error("Internal error: {0}")] - Internal(String), - - #[error("Request timed out")] - Timeout(#[from] tower::timeout::error::Elapsed), -} - -#[derive(Error, Debug)] -pub enum RouterError { - #[error("Method not found: {0}")] - MethodNotFound(String), - - #[error("Invalid parameters: {0}")] - InvalidParams(String), - - #[error("Internal error: {0}")] - Internal(String), - - #[error("Tool not found: {0}")] - ToolNotFound(String), - - #[error("Resource not found: {0}")] - ResourceNotFound(String), - - #[error("Not found: {0}")] - PromptNotFound(String), -} - -impl From for mcp_core::protocol::ErrorData { - fn from(err: RouterError) -> Self { - use mcp_core::protocol::*; - match err { - RouterError::MethodNotFound(msg) => ErrorData { - code: METHOD_NOT_FOUND, - message: msg, - data: None, - }, - RouterError::InvalidParams(msg) => ErrorData { - code: INVALID_PARAMS, - message: msg, - data: None, - }, - RouterError::Internal(msg) => ErrorData { - code: INTERNAL_ERROR, - message: msg, - data: None, - }, - RouterError::ToolNotFound(msg) => ErrorData { - code: INVALID_REQUEST, - message: msg, - data: None, - }, - RouterError::ResourceNotFound(msg) => ErrorData { - code: INVALID_REQUEST, - message: msg, - data: None, - }, - RouterError::PromptNotFound(msg) => ErrorData { - code: INVALID_REQUEST, - message: msg, - data: None, - }, - } - } -} - -impl From for RouterError { - fn from(err: mcp_core::handler::ResourceError) -> Self { - match err { - mcp_core::handler::ResourceError::NotFound(msg) => RouterError::ResourceNotFound(msg), - _ => RouterError::Internal("Unknown resource error".to_string()), - } - } -} diff --git a/crates/mcp-server/src/lib.rs b/crates/mcp-server/src/lib.rs deleted file mode 100644 index 5e1a1bae..00000000 --- a/crates/mcp-server/src/lib.rs +++ /dev/null @@ -1,275 +0,0 @@ -use std::{ - pin::Pin, - task::{Context, Poll}, -}; - -use futures::{Future, Stream}; -use mcp_core::protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse}; -use pin_project::pin_project; -use tokio::io::{AsyncBufReadExt, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader}; -use tower_service::Service; - -mod errors; -pub use errors::{BoxError, RouterError, ServerError, TransportError}; - -pub mod router; -pub use router::Router; - -/// A transport layer that handles JSON-RPC messages over byte -#[pin_project] -pub struct ByteTransport { - // Reader is a BufReader on the underlying stream (stdin or similar) buffering - // the underlying data across poll calls, we clear one line (\n) during each - // iteration of poll_next from this buffer - #[pin] - reader: BufReader, - #[pin] - writer: W, -} - -impl ByteTransport -where - R: AsyncRead, - W: AsyncWrite, -{ - pub fn new(reader: R, writer: W) -> Self { - Self { - // Default BufReader capacity is 8 * 1024, increase this to 2MB to the file size limit - // allows the buffer to have the capacity to read very large calls - reader: BufReader::with_capacity(2 * 1024 * 1024, reader), - writer, - } - } -} - -impl Stream for ByteTransport -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - type Item = Result; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut this = self.project(); - let mut buf = Vec::new(); - - let mut reader = this.reader.as_mut(); - let mut read_future = Box::pin(reader.read_until(b'\n', &mut buf)); - match read_future.as_mut().poll(cx) { - Poll::Ready(Ok(0)) => Poll::Ready(None), // EOF - Poll::Ready(Ok(_)) => { - // Convert to UTF-8 string - let line = match String::from_utf8(buf) { - Ok(s) => s, - Err(e) => return Poll::Ready(Some(Err(TransportError::Utf8(e)))), - }; - // Log incoming message here before serde conversion to - // track incomplete chunks which are not valid JSON - tracing::debug!(json = %line, "incoming message"); - - // Parse JSON and validate message format - match serde_json::from_str::(&line) { - Ok(value) => { - // Validate basic JSON-RPC structure - if !value.is_object() { - return Poll::Ready(Some(Err(TransportError::InvalidMessage( - "Message must be a JSON object".into(), - )))); - } - let obj = value.as_object().unwrap(); // Safe due to check above - - // Check jsonrpc version field - if !obj.contains_key("jsonrpc") || obj["jsonrpc"] != "2.0" { - return Poll::Ready(Some(Err(TransportError::InvalidMessage( - "Missing or invalid jsonrpc version".into(), - )))); - } - - // Now try to parse as proper message - match serde_json::from_value::(value) { - Ok(msg) => Poll::Ready(Some(Ok(msg))), - Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), - } - } - Err(e) => Poll::Ready(Some(Err(TransportError::Json(e)))), - } - } - Poll::Ready(Err(e)) => Poll::Ready(Some(Err(TransportError::Io(e)))), - Poll::Pending => Poll::Pending, - } - } -} - -impl ByteTransport -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - pub async fn write_message( - self: &mut Pin<&mut Self>, - msg: JsonRpcMessage, - ) -> Result<(), std::io::Error> { - let json = serde_json::to_string(&msg)?; - - let mut this = self.as_mut().project(); - this.writer.write_all(json.as_bytes()).await?; - this.writer.write_all(b"\n").await?; - this.writer.flush().await?; - - Ok(()) - } -} - -/// The main server type that processes incoming requests -pub struct Server { - service: S, -} - -impl Server -where - S: Service + Send, - S::Error: Into, - S::Future: Send, -{ - pub fn new(service: S) -> Self { - Self { service } - } - - // TODO transport trait instead of byte transport if we implement others - pub async fn run(self, mut transport: ByteTransport) -> Result<(), ServerError> - where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, - { - use futures::StreamExt; - let mut service = self.service; - let mut transport = Pin::new(&mut transport); - - tracing::info!("Server started"); - while let Some(msg_result) = transport.next().await { - let _span = tracing::span!(tracing::Level::INFO, "message_processing"); - let _enter = _span.enter(); - match msg_result { - Ok(msg) => { - match msg { - JsonRpcMessage::Request(request) => { - // Serialize request for logging - let id = request.id; - let request_json = serde_json::to_string(&request) - .unwrap_or_else(|_| "Failed to serialize request".to_string()); - - tracing::debug!( - request_id = ?id, - method = ?request.method, - json = %request_json, - "Received request" - ); - - // Process the request using our service - let response = match service.call(request).await { - Ok(resp) => resp, - Err(e) => { - let error_msg = e.into().to_string(); - tracing::error!(error = %error_msg, "Request processing failed"); - JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id, - result: None, - error: Some(mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INTERNAL_ERROR, - message: error_msg, - data: None, - }), - } - } - }; - - // Serialize response for logging - let response_json = serde_json::to_string(&response) - .unwrap_or_else(|_| "Failed to serialize response".to_string()); - - tracing::debug!( - response_id = ?response.id, - json = %response_json, - "Sending response" - ); - // Send the response back - if let Err(e) = transport - .write_message(JsonRpcMessage::Response(response)) - .await - { - return Err(ServerError::Transport(TransportError::Io(e))); - } - } - JsonRpcMessage::Response(_) - | JsonRpcMessage::Notification(_) - | JsonRpcMessage::Nil - | JsonRpcMessage::Error(_) => { - // Ignore responses, notifications and nil messages for now - continue; - } - } - } - Err(e) => { - // Convert transport error to JSON-RPC error response - let error = match e { - TransportError::Json(_) | TransportError::InvalidMessage(_) => { - mcp_core::protocol::ErrorData { - code: mcp_core::protocol::PARSE_ERROR, - message: e.to_string(), - data: None, - } - } - TransportError::Protocol(_) => mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INVALID_REQUEST, - message: e.to_string(), - data: None, - }, - _ => mcp_core::protocol::ErrorData { - code: mcp_core::protocol::INTERNAL_ERROR, - message: e.to_string(), - data: None, - }, - }; - - let error_response = JsonRpcMessage::Error(JsonRpcError { - jsonrpc: "2.0".to_string(), - id: None, - error, - }); - - if let Err(e) = transport.write_message(error_response).await { - return Err(ServerError::Transport(TransportError::Io(e))); - } - } - } - } - - Ok(()) - } -} - -// Define a specific service implementation that we need for any -// Any router implements this -pub trait BoundedService: - Service< - JsonRpcRequest, - Response = JsonRpcResponse, - Error = BoxError, - Future = Pin> + Send>>, - > + Send - + 'static -{ -} - -// Implement it for any type that meets the bounds -impl BoundedService for T where - T: Service< - JsonRpcRequest, - Response = JsonRpcResponse, - Error = BoxError, - Future = Pin> + Send>>, - > + Send - + 'static -{ -} diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs deleted file mode 100644 index 2c277d1c..00000000 --- a/crates/mcp-server/src/router.rs +++ /dev/null @@ -1,431 +0,0 @@ -use std::{ - future::Future, - pin::Pin, - task::{Context, Poll}, -}; - -type PromptFuture = Pin> + Send + 'static>>; - -use mcp_core::{ - content::Content, - handler::{PromptError, ResourceError, ToolError}, - prompt::{Prompt, PromptMessage, PromptMessageRole}, - protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcRequest, - JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, - PromptsCapability, ReadResourceResult, ResourcesCapability, ServerCapabilities, - ToolsCapability, - }, - ResourceContents, -}; -use serde_json::Value; -use tower_service::Service; - -use crate::{BoxError, RouterError}; - -/// Builder for configuring and constructing capabilities -pub struct CapabilitiesBuilder { - tools: Option, - prompts: Option, - resources: Option, -} - -impl Default for CapabilitiesBuilder { - fn default() -> Self { - Self::new() - } -} - -impl CapabilitiesBuilder { - pub fn new() -> Self { - Self { - tools: None, - prompts: None, - resources: None, - } - } - - /// Add multiple tools to the router - pub fn with_tools(mut self, list_changed: bool) -> Self { - self.tools = Some(ToolsCapability { - list_changed: Some(list_changed), - }); - self - } - - /// Enable prompts capability - pub fn with_prompts(mut self, list_changed: bool) -> Self { - self.prompts = Some(PromptsCapability { - list_changed: Some(list_changed), - }); - self - } - - /// Enable resources capability - pub fn with_resources(mut self, subscribe: bool, list_changed: bool) -> Self { - self.resources = Some(ResourcesCapability { - subscribe: Some(subscribe), - list_changed: Some(list_changed), - }); - self - } - - /// Build the router with automatic capability inference - pub fn build(self) -> ServerCapabilities { - // Create capabilities based on what's configured - ServerCapabilities { - tools: self.tools, - prompts: self.prompts, - resources: self.resources, - } - } -} - -pub trait Router: Send + Sync + 'static { - fn name(&self) -> String; - // in the protocol, instructions are optional but we make it required - fn instructions(&self) -> String; - fn capabilities(&self) -> ServerCapabilities; - fn list_tools(&self) -> Vec; - fn call_tool( - &self, - tool_name: &str, - arguments: Value, - ) -> Pin, ToolError>> + Send + 'static>>; - fn list_resources(&self) -> Vec; - fn read_resource( - &self, - uri: &str, - ) -> Pin> + Send + 'static>>; - fn list_prompts(&self) -> Vec; - fn get_prompt(&self, prompt_name: &str) -> PromptFuture; - - // Helper method to create base response - fn create_response(&self, id: Option) -> JsonRpcResponse { - JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id, - result: None, - error: None, - } - } - - fn handle_initialize( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let result = InitializeResult { - protocol_version: "2024-11-05".to_string(), - capabilities: self.capabilities().clone(), - server_info: Implementation { - name: self.name(), - version: env!("CARGO_PKG_VERSION").to_string(), - }, - instructions: Some(self.instructions()), - }; - - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_tools_list( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let tools = self.list_tools(); - - let result = ListToolsResult { - tools, - next_cursor: None, - }; - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_tools_call( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; - - let name = params - .get("name") - .and_then(Value::as_str) - .ok_or_else(|| RouterError::InvalidParams("Missing tool name".into()))?; - - let arguments = params.get("arguments").cloned().unwrap_or(Value::Null); - - let result = match self.call_tool(name, arguments).await { - Ok(result) => CallToolResult { - content: result, - is_error: None, - }, - Err(err) => CallToolResult { - content: vec![Content::text(err.to_string())], - is_error: Some(true), - }, - }; - - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_resources_list( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let resources = self.list_resources(); - - let result = ListResourcesResult { - resources, - next_cursor: None, - }; - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_resources_read( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; - - let uri = params - .get("uri") - .and_then(Value::as_str) - .ok_or_else(|| RouterError::InvalidParams("Missing resource URI".into()))?; - - let contents = self.read_resource(uri).await.map_err(RouterError::from)?; - - let result = ReadResourceResult { - contents: vec![ResourceContents::TextResourceContents { - uri: uri.to_string(), - mime_type: Some("text/plain".to_string()), - text: contents, - }], - }; - - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_prompts_list( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - let prompts = self.list_prompts(); - - let result = ListPromptsResult { prompts }; - - let mut response = self.create_response(req.id); - response.result = - Some(serde_json::to_value(result).map_err(|e| { - RouterError::Internal(format!("JSON serialization error: {}", e)) - })?); - - Ok(response) - } - } - - fn handle_prompts_get( - &self, - req: JsonRpcRequest, - ) -> impl Future> + Send { - async move { - // Validate and extract parameters - let params = req - .params - .ok_or_else(|| RouterError::InvalidParams("Missing parameters".into()))?; - - // Extract "name" field - let prompt_name = params - .get("name") - .and_then(Value::as_str) - .ok_or_else(|| RouterError::InvalidParams("Missing prompt name".into()))?; - - // Extract "arguments" field - let arguments = params - .get("arguments") - .and_then(Value::as_object) - .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?; - - // Fetch the prompt definition first - let prompt = self - .list_prompts() - .into_iter() - .find(|p| p.name == prompt_name) - .ok_or_else(|| { - RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) - })?; - - // Validate required arguments - if let Some(args) = &prompt.arguments { - for arg in args { - if arg.required.is_some() - && arg.required.unwrap() - && (!arguments.contains_key(&arg.name) - || arguments - .get(&arg.name) - .and_then(Value::as_str) - .is_none_or(str::is_empty)) - { - return Err(RouterError::InvalidParams(format!( - "Missing required argument: '{}'", - arg.name - ))); - } - } - } - - // Now get the prompt content - let description = self - .get_prompt(prompt_name) - .await - .map_err(|e| RouterError::Internal(e.to_string()))?; - - // Validate prompt arguments for potential security issues from user text input - // Checks: - // - Prompt must be less than 10000 total characters - // - Argument keys must be less than 1000 characters - // - Argument values must be less than 1000 characters - // - Dangerous patterns, eg "../", "//", "\\\\", "