Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mcp/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func TestCmdTransport(t *testing.T) {
&mcp.TextContent{Text: "Hi user"},
},
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" {
t.Errorf("greet returned unexpected content (-want +got):\n%s", diff)
}
if err := session.Close(); err != nil {
Expand Down
4 changes: 3 additions & 1 deletion mcp/content_nil_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestContentUnmarshalNil(t *testing.T) {
}

// Verify that the Content field was properly populated
if cmp.Diff(tt.want, tt.content) != "" {
if cmp.Diff(tt.want, tt.content, ctrCmpOpts...) != "" {
t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content))
}
})
Expand Down Expand Up @@ -222,3 +222,5 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) {
})
}
}

var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})}
75 changes: 71 additions & 4 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ func TestEndToEnd(t *testing.T) {
&TextContent{Text: "hi user"},
},
}
if diff := cmp.Diff(wantHi, gotHi); diff != "" {
if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" {
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
}

Expand All @@ -253,7 +253,7 @@ func TestEndToEnd(t *testing.T) {
&TextContent{Text: errTestFailure.Error()},
},
}
if diff := cmp.Diff(wantFail, gotFail); diff != "" {
if diff := cmp.Diff(wantFail, gotFail, ctrCmpOpts...); diff != "" {
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
}

Expand Down Expand Up @@ -1717,7 +1717,7 @@ func TestPointerArgEquivalence(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1); diff != "" {
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
}
}
Expand All @@ -1733,7 +1733,7 @@ func TestPointerArgEquivalence(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1); diff != "" {
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
}
})
Expand Down Expand Up @@ -1837,3 +1837,70 @@ func TestEmbeddedStructResponse(t *testing.T) {
t.Errorf("CallTool() failed: %v", err)
}
}

func TestToolErrorMiddleware(t *testing.T) {
ctx := context.Background()
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
AddTool(s, &Tool{
Name: "greet",
Description: "say hi",
}, sayHi)
AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}},
func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) {
return nil, nil, errTestFailure
})

var middleErr error
s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler {
return func(ctx context.Context, method string, req Request) (Result, error) {
res, err := h(ctx, method, req)
if err == nil {
if ctr, ok := res.(*CallToolResult); ok {
middleErr = ctr.getError()
}
}
return res, err
}
})
_, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
client := NewClient(&Implementation{Name: "test-client"}, nil)
clientSession, err := client.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
defer clientSession.Close()

_, err = clientSession.CallTool(ctx, &CallToolParams{
Name: "greet",
Arguments: map[string]any{"Name": "al"},
})
if err != nil {
t.Errorf("CallTool() failed: %v", err)
}
if middleErr != nil {
t.Errorf("middleware got error %v, want nil", middleErr)
}
res, err := clientSession.CallTool(ctx, &CallToolParams{
Name: "fail",
})
if err != nil {
t.Errorf("CallTool() failed: %v", err)
}
if !res.IsError {
t.Fatal("want error, got none")
}
// Clients can't see the error, because it isn't marshaled.
if err := res.getError(); err != nil {
t.Fatalf("got %v, want nil", err)
}
if middleErr != errTestFailure {
t.Errorf("middleware got err %v, want errTestFailure", middleErr)
}
}

var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})}
13 changes: 13 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,26 @@ type CallToolResult struct {
// tool handler returns an error, and the error string is included as text in
// the Content field.
IsError bool `json:"isError,omitempty"`

// The error passed to setError, if any.
// It is not marshaled, and therefore it is only visible on the server.
// Its only use is in server sending middleware, where it can be accessed
// with getError.
err error
}

// TODO(#64): consider exposing setError (and getError), by adding an error
// field on CallToolResult.
func (r *CallToolResult) setError(err error) {
r.Content = []Content{&TextContent{Text: err.Error()}}
r.IsError = true
r.err = err
}

// getError returns the error set with setError, or nil if none.
// This function always returns nil on clients.
func (r *CallToolResult) getError() error {
return r.err
}

func (*CallToolResult) isResult() {}
Expand Down
2 changes: 1 addition & 1 deletion mcp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ func TestContentUnmarshal(t *testing.T) {
if err := json.Unmarshal(data, out); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(in, out); diff != "" {
if diff := cmp.Diff(in, out, ctrCmpOpts...); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}
Expand Down
2 changes: 1 addition & 1 deletion mcp/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func TestSSEServer(t *testing.T) {
&TextContent{Text: "hi user"},
},
}
if diff := cmp.Diff(wantHi, gotHi); diff != "" {
if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" {
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
}

Expand Down
4 changes: 1 addition & 3 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ func TestStreamableTransports(t *testing.T) {
want := &CallToolResult{
Content: []Content{&TextContent{Text: "hi foo"}},
}
if diff := cmp.Diff(want, got); diff != "" {
if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" {
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
}

Expand Down Expand Up @@ -550,8 +550,6 @@ func resp(id int64, result any, err error) *jsonrpc.Response {
}
}

var ()

func TestStreamableServerTransport(t *testing.T) {
// This test checks detailed behavior of the streamable server transport, by
// faking the behavior of a streamable client using a sequence of HTTP
Expand Down