Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,11 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar
prompt, ok := s.prompts.get(req.Params.Name)
s.mu.Unlock()
if !ok {
// TODO: surface the error code over the wire, instead of flattening it into the string.
return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, req.Params.Name)
// Return a proper JSON-RPC error with the correct error code
return nil, &jsonrpc2.WireError{
Code: -32602, // ErrInvalidParams code
Message: fmt.Sprintf("unknown prompt %q", req.Params.Name),
}
}
return prompt.handler(ctx, req.Session, req.Params)
}
Expand All @@ -340,7 +343,10 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam
st, ok := s.tools.get(req.Params.Name)
s.mu.Unlock()
if !ok {
return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name)
return nil, &jsonrpc2.WireError{
Code: -32602, // ErrInvalidParams code
Message: fmt.Sprintf("unknown tool %q", req.Params.Name),
}
}
return st.handler(ctx, req)
}
Expand Down
12 changes: 10 additions & 2 deletions mcp/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"reflect"

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

// A ToolHandler handles a call to tools/call.
Expand Down Expand Up @@ -69,9 +70,16 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
Session: req.Session,
Params: params,
})
// TODO(rfindley): investigate why server errors are embedded in this strange way,
// rather than returned as jsonrpc2 server errors.
// Handle server errors appropriately:
// - If the handler returns a structured error (like jsonrpc2.WireError), return it directly
// - If the handler returns a regular error, wrap it in a CallToolResult with IsError=true
// - This allows tools to distinguish between protocol errors and tool execution errors
if err != nil {
// Check if this is already a structured JSON-RPC error
if wireErr, ok := err.(*jsonrpc2.WireError); ok {
return nil, wireErr
}
// For regular errors, embed them in the tool result as per MCP spec
return &CallToolResult{
Content: []Content{&TextContent{Text: err.Error()}},
IsError: true,
Expand Down
108 changes: 108 additions & 0 deletions mcp/tool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ package mcp
import (
"context"
"encoding/json"
"errors"
"fmt"
"reflect"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
)

// testToolHandler is used for type inference in TestNewServerTool.
Expand Down Expand Up @@ -132,3 +136,107 @@ func TestUnmarshalSchema(t *testing.T) {

}
}

func TestToolErrorHandling(t *testing.T) {
// Test that structured JSON-RPC errors are returned directly
t.Run("structured_error", func(t *testing.T) {
server := NewServer(testImpl, nil)

// Create a tool that returns a structured error
structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) {
return nil, &jsonrpc2.WireError{
Code: -32603, // ErrInternal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Define this one too, same place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Message: "internal server error",
}
}

AddTool(server, &Tool{Name: "error_tool", Description: "returns structured error"}, structuredErrorHandler)

// Connect and test
ct, st := NewInMemoryTransports()
_, err := server.Connect(context.Background(), st, nil)
if err != nil {
t.Fatal(err)
}

client := NewClient(testImpl, nil)
cs, err := client.Connect(context.Background(), ct, nil)
if err != nil {
t.Fatal(err)
}
defer cs.Close()

// Call the tool
_, err = cs.CallTool(context.Background(), &CallToolParams{
Name: "error_tool",
Arguments: map[string]any{},
})

// Should get the structured error directly
if err == nil {
t.Fatal("expected error, got nil")
}

var wireErr *jsonrpc2.WireError
if !errors.As(err, &wireErr) {
t.Fatalf("expected WireError, got %T: %v", err, err)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

%T, %[1]%v
then you only need one arg

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

}

if wireErr.Code != -32603 {
t.Errorf("expected error code -32603, got %d", wireErr.Code)
}
})

// Test that regular errors are embedded in tool results
t.Run("regular_error", func(t *testing.T) {
server := NewServer(testImpl, nil)

// Create a tool that returns a regular error
regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResultFor[any], error) {
return nil, fmt.Errorf("tool execution failed")
}

AddTool(server, &Tool{Name: "regular_error_tool", Description: "returns regular error"}, regularErrorHandler)

// Connect and test
ct, st := NewInMemoryTransports()
_, err := server.Connect(context.Background(), st, nil)
if err != nil {
t.Fatal(err)
}

client := NewClient(testImpl, nil)
cs, err := client.Connect(context.Background(), ct, nil)
if err != nil {
t.Fatal(err)
}
defer cs.Close()

// Call the tool
result, err := cs.CallTool(context.Background(), &CallToolParams{
Name: "regular_error_tool",
Arguments: map[string]any{},
})

// Should not get an error at the protocol level
if err != nil {
t.Fatalf("unexpected protocol error: %v", err)
}

// Should get a result with IsError=true
if !result.IsError {
t.Error("expected IsError=true, got false")
}

// Should have error message in content
if len(result.Content) == 0 {
t.Error("expected error content, got empty")
}

if textContent, ok := result.Content[0].(*TextContent); !ok {
t.Error("expected TextContent")
} else if !strings.Contains(textContent.Text, "tool execution failed") {
t.Errorf("expected error message in content, got: %s", textContent.Text)
}
})
}