Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions mcp/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,16 @@ mod tests {

#[test]
fn test_mcp_response_constructors() {
let success_response = McpResponse::success(Some(json!(1)), json!({"test": "data"}));
let success_response = McpResponse::success(json!(1), json!({"test": "data"}));
assert_eq!(success_response.jsonrpc, "2.0");
assert_eq!(success_response.id, Some(json!(1)));
assert_eq!(success_response.id, json!(1));
assert!(success_response.result.is_some());
assert!(success_response.error.is_none());

let error_response =
McpResponse::error(Some(json!(2)), -32602, "Invalid params".to_string());
McpResponse::error(json!(2), -32602, "Invalid params".to_string());
assert_eq!(error_response.jsonrpc, "2.0");
assert_eq!(error_response.id, Some(json!(2)));
assert_eq!(error_response.id, json!(2));
assert!(error_response.result.is_none());
assert!(error_response.error.is_some());
assert_eq!(error_response.error.unwrap().code, -32602);
Expand Down
6 changes: 3 additions & 3 deletions mcp/src/mcp_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct McpRequest {
#[derive(Debug, Serialize)]
pub struct McpResponse {
pub jsonrpc: String,
pub id: Option<serde_json::Value>,
pub id: serde_json::Value, // Always required - never None for valid responses
#[serde(skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
Expand Down Expand Up @@ -270,7 +270,7 @@ pub struct GetPatternsOutput {
// TestStructure and TestStep are now imported from parser crate

impl McpResponse {
pub fn success(id: Option<serde_json::Value>, result: serde_json::Value) -> Self {
pub fn success(id: serde_json::Value, result: serde_json::Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
Expand All @@ -279,7 +279,7 @@ impl McpResponse {
}
}

pub fn error(id: Option<serde_json::Value>, code: i32, message: String) -> Self {
pub fn error(id: serde_json::Value, code: i32, message: String) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
Expand Down
57 changes: 44 additions & 13 deletions mcp/src/server/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,30 @@ impl McpServer {
}

// Parse JSON and handle errors properly
let response = match serde_json::from_str::<McpRequest>(line.trim()) {
let response_opt = match serde_json::from_str::<McpRequest>(line.trim()) {
Ok(request) => self.handle_request(request).await,
Err(_) => {
// Send error response for malformed JSON
McpResponse::error(None, -32700, "Parse error: Invalid JSON".to_string())
// Try to extract the id field from the raw JSON for error response
// If we can't extract id or it's null, it might be a notification, so skip sending response
let extracted_id = serde_json::from_str::<serde_json::Value>(line.trim())
.ok()
.and_then(|v| v.get("id").cloned());

let normalized_id = Self::normalize_id(extracted_id);

// Only send error response if we found a valid non-null id (meaning it's a request, not a notification)
normalized_id.map(|id| {
McpResponse::error(id, -32700, "Parse error: Invalid JSON".to_string())
})
}
};

// Only send response if we have one (skip notifications)
let response = match response_opt {
Some(resp) => resp,
None => continue, // Notification - no response to send
};

// Send response with proper error handling
if let Err(e) = self.send_response(&mut stdout, &response).await {
// Check if it's a broken pipe or connection issue
Expand Down Expand Up @@ -104,20 +120,35 @@ impl McpServer {
Ok(())
}

async fn handle_request(&mut self, request: McpRequest) -> McpResponse {
match request.method.as_str() {
"initialize" => self.handle_initialize(request.id, request.params),
"tools/list" => self.handle_tools_list(request.id),
"tools/call" => self.handle_tools_call(request.id, request.params).await,
async fn handle_request(&mut self, request: McpRequest) -> Option<McpResponse> {
// Normalize null IDs to None (MCP doesn't accept null as a valid response id)
let normalized_id = Self::normalize_id(request.id);

// If request has no valid id (or null id), treat as notification and don't send response
let id = match normalized_id {
Some(id) => id,
None => return None, // Notification - no response
};

Some(match request.method.as_str() {
"initialize" => self.handle_initialize(id, request.params),
"tools/list" => self.handle_tools_list(id),
"tools/call" => self.handle_tools_call(id, request.params).await,
_ => McpResponse::error(
request.id,
id,
-32601,
format!("Method not found: {}", request.method),
),
}
})
}

/// Normalize ID: convert null IDs to None to avoid sending id: null in responses
/// (MCP protocol doesn't accept null as a valid response id value)
fn normalize_id(id: Option<Value>) -> Option<Value> {
id.and_then(|v| if v.is_null() { None } else { Some(v) })
}

fn handle_initialize(&self, id: Option<Value>, _params: Option<Value>) -> McpResponse {
fn handle_initialize(&self, id: Value, _params: Option<Value>) -> McpResponse {
let result = InitializeResult {
protocol_version: "2024-11-05".to_string(),
capabilities: ServerCapabilities {
Expand All @@ -132,7 +163,7 @@ impl McpServer {
McpResponse::success(id, json!(result))
}

fn handle_tools_list(&self, id: Option<Value>) -> McpResponse {
fn handle_tools_list(&self, id: Value) -> McpResponse {
let tools = self.tool_definitions.get_tools();
let result = json!({
"tools": tools
Expand All @@ -141,7 +172,7 @@ impl McpServer {
McpResponse::success(id, result)
}

async fn handle_tools_call(&mut self, id: Option<Value>, params: Option<Value>) -> McpResponse {
async fn handle_tools_call(&mut self, id: Value, params: Option<Value>) -> McpResponse {
let params = match params {
Some(p) => p,
None => return McpResponse::error(id, -32602, "Missing parameters".to_string()),
Expand Down