diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index 0df45708..cbaadcb0 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -226,7 +226,7 @@ func TestCmdTransport(t *testing.T) { &mcp.TextContent{Text: "Hi user"}, }, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { t.Errorf("greet returned unexpected content (-want +got):\n%s", diff) } if err := session.Close(); err != nil { diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go index 70cabfd7..8cc7bdbd 100644 --- a/mcp/content_nil_test.go +++ b/mcp/content_nil_test.go @@ -72,7 +72,7 @@ func TestContentUnmarshalNil(t *testing.T) { } // Verify that the Content field was properly populated - if cmp.Diff(tt.want, tt.content) != "" { + if cmp.Diff(tt.want, tt.content, ctrCmpOpts...) != "" { t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content)) } }) @@ -222,3 +222,5 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { }) } } + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index dd542d3d..fa941bf0 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -234,7 +234,7 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } @@ -253,7 +253,7 @@ func TestEndToEnd(t *testing.T) { &TextContent{Text: errTestFailure.Error()}, }, } - if diff := cmp.Diff(wantFail, gotFail); diff != "" { + if diff := cmp.Diff(wantFail, gotFail, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } @@ -1717,7 +1717,7 @@ func TestPointerArgEquivalence(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff(r0, r1); diff != "" { + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) } } @@ -1733,7 +1733,7 @@ func TestPointerArgEquivalence(t *testing.T) { if err != nil { t.Fatal(err) } - if diff := cmp.Diff(r0, r1); diff != "" { + if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" { t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) } }) @@ -1837,3 +1837,70 @@ func TestEmbeddedStructResponse(t *testing.T) { t.Errorf("CallTool() failed: %v", err) } } + +func TestToolErrorMiddleware(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}}, + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { + return nil, nil, errTestFailure + }) + + var middleErr error + s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + res, err := h(ctx, method, req) + if err == nil { + if ctr, ok := res.(*CallToolResult); ok { + middleErr = ctr.getError() + } + } + return res, err + } + }) + _, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + client := NewClient(&Implementation{Name: "test-client"}, nil) + clientSession, err := client.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer clientSession.Close() + + _, err = clientSession.CallTool(ctx, &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"Name": "al"}, + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if middleErr != nil { + t.Errorf("middleware got error %v, want nil", middleErr) + } + res, err := clientSession.CallTool(ctx, &CallToolParams{ + Name: "fail", + }) + if err != nil { + t.Errorf("CallTool() failed: %v", err) + } + if !res.IsError { + t.Fatal("want error, got none") + } + // Clients can't see the error, because it isn't marshaled. + if err := res.getError(); err != nil { + t.Fatalf("got %v, want nil", err) + } + if middleErr != errTestFailure { + t.Errorf("middleware got err %v, want errTestFailure", middleErr) + } +} + +var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} diff --git a/mcp/protocol.go b/mcp/protocol.go index 7be8ea17..3e3c544e 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -103,6 +103,12 @@ type CallToolResult struct { // tool handler returns an error, and the error string is included as text in // the Content field. IsError bool `json:"isError,omitempty"` + + // The error passed to setError, if any. + // It is not marshaled, and therefore it is only visible on the server. + // Its only use is in server sending middleware, where it can be accessed + // with getError. + err error } // TODO(#64): consider exposing setError (and getError), by adding an error @@ -110,6 +116,13 @@ type CallToolResult struct { func (r *CallToolResult) setError(err error) { r.Content = []Content{&TextContent{Text: err.Error()}} r.IsError = true + r.err = err +} + +// getError returns the error set with setError, or nil if none. +// This function always returns nil on clients. +func (r *CallToolResult) getError() error { + return r.err } func (*CallToolResult) isResult() {} diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index 28e97518..67d021d1 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -499,7 +499,7 @@ func TestContentUnmarshal(t *testing.T) { if err := json.Unmarshal(data, out); err != nil { t.Fatal(err) } - if diff := cmp.Diff(in, out); diff != "" { + if diff := cmp.Diff(in, out, ctrCmpOpts...); diff != "" { t.Errorf("mismatch (-want, +got):\n%s", diff) } } diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 408e92ec..32a20bf3 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -70,7 +70,7 @@ func TestSSEServer(t *testing.T) { &TextContent{Text: "hi user"}, }, } - if diff := cmp.Diff(wantHi, gotHi); diff != "" { + if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" { t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff) } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index e077308c..79e9645f 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -159,7 +159,7 @@ func TestStreamableTransports(t *testing.T) { want := &CallToolResult{ Content: []Content{&TextContent{Text: "hi foo"}}, } - if diff := cmp.Diff(want, got); diff != "" { + if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" { t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff) } @@ -550,8 +550,6 @@ func resp(id int64, result any, err error) *jsonrpc.Response { } } -var () - func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP