Skip to content

Commit 140b939

Browse files
authored
mcp: add CallToolResult.getError (#481)
Provide a way for middleware to get the error from a tool call, demonstrating that we can add this functionality without making a breaking change. Leave getError unexported for now; we can export it (and setError) at any time. For #64.
1 parent b615fa4 commit 140b939

File tree

7 files changed

+91
-11
lines changed

7 files changed

+91
-11
lines changed

mcp/cmd_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ func TestCmdTransport(t *testing.T) {
226226
&mcp.TextContent{Text: "Hi user"},
227227
},
228228
}
229-
if diff := cmp.Diff(want, got); diff != "" {
229+
if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" {
230230
t.Errorf("greet returned unexpected content (-want +got):\n%s", diff)
231231
}
232232
if err := session.Close(); err != nil {

mcp/content_nil_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func TestContentUnmarshalNil(t *testing.T) {
7272
}
7373

7474
// Verify that the Content field was properly populated
75-
if cmp.Diff(tt.want, tt.content) != "" {
75+
if cmp.Diff(tt.want, tt.content, ctrCmpOpts...) != "" {
7676
t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content))
7777
}
7878
})
@@ -222,3 +222,5 @@ func TestContentUnmarshalNilWithInvalidContent(t *testing.T) {
222222
})
223223
}
224224
}
225+
226+
var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(mcp.CallToolResult{})}

mcp/mcp_test.go

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ func TestEndToEnd(t *testing.T) {
234234
&TextContent{Text: "hi user"},
235235
},
236236
}
237-
if diff := cmp.Diff(wantHi, gotHi); diff != "" {
237+
if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" {
238238
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
239239
}
240240

@@ -253,7 +253,7 @@ func TestEndToEnd(t *testing.T) {
253253
&TextContent{Text: errTestFailure.Error()},
254254
},
255255
}
256-
if diff := cmp.Diff(wantFail, gotFail); diff != "" {
256+
if diff := cmp.Diff(wantFail, gotFail, ctrCmpOpts...); diff != "" {
257257
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
258258
}
259259

@@ -1717,7 +1717,7 @@ func TestPointerArgEquivalence(t *testing.T) {
17171717
if err != nil {
17181718
t.Fatal(err)
17191719
}
1720-
if diff := cmp.Diff(r0, r1); diff != "" {
1720+
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
17211721
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
17221722
}
17231723
}
@@ -1733,7 +1733,7 @@ func TestPointerArgEquivalence(t *testing.T) {
17331733
if err != nil {
17341734
t.Fatal(err)
17351735
}
1736-
if diff := cmp.Diff(r0, r1); diff != "" {
1736+
if diff := cmp.Diff(r0, r1, ctrCmpOpts...); diff != "" {
17371737
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
17381738
}
17391739
})
@@ -1837,3 +1837,70 @@ func TestEmbeddedStructResponse(t *testing.T) {
18371837
t.Errorf("CallTool() failed: %v", err)
18381838
}
18391839
}
1840+
1841+
func TestToolErrorMiddleware(t *testing.T) {
1842+
ctx := context.Background()
1843+
ct, st := NewInMemoryTransports()
1844+
1845+
s := NewServer(testImpl, nil)
1846+
AddTool(s, &Tool{
1847+
Name: "greet",
1848+
Description: "say hi",
1849+
}, sayHi)
1850+
AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}},
1851+
func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) {
1852+
return nil, nil, errTestFailure
1853+
})
1854+
1855+
var middleErr error
1856+
s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler {
1857+
return func(ctx context.Context, method string, req Request) (Result, error) {
1858+
res, err := h(ctx, method, req)
1859+
if err == nil {
1860+
if ctr, ok := res.(*CallToolResult); ok {
1861+
middleErr = ctr.getError()
1862+
}
1863+
}
1864+
return res, err
1865+
}
1866+
})
1867+
_, err := s.Connect(ctx, st, nil)
1868+
if err != nil {
1869+
t.Fatal(err)
1870+
}
1871+
client := NewClient(&Implementation{Name: "test-client"}, nil)
1872+
clientSession, err := client.Connect(ctx, ct, nil)
1873+
if err != nil {
1874+
t.Fatal(err)
1875+
}
1876+
defer clientSession.Close()
1877+
1878+
_, err = clientSession.CallTool(ctx, &CallToolParams{
1879+
Name: "greet",
1880+
Arguments: map[string]any{"Name": "al"},
1881+
})
1882+
if err != nil {
1883+
t.Errorf("CallTool() failed: %v", err)
1884+
}
1885+
if middleErr != nil {
1886+
t.Errorf("middleware got error %v, want nil", middleErr)
1887+
}
1888+
res, err := clientSession.CallTool(ctx, &CallToolParams{
1889+
Name: "fail",
1890+
})
1891+
if err != nil {
1892+
t.Errorf("CallTool() failed: %v", err)
1893+
}
1894+
if !res.IsError {
1895+
t.Fatal("want error, got none")
1896+
}
1897+
// Clients can't see the error, because it isn't marshaled.
1898+
if err := res.getError(); err != nil {
1899+
t.Fatalf("got %v, want nil", err)
1900+
}
1901+
if middleErr != errTestFailure {
1902+
t.Errorf("middleware got err %v, want errTestFailure", middleErr)
1903+
}
1904+
}
1905+
1906+
var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})}

mcp/protocol.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,26 @@ type CallToolResult struct {
103103
// tool handler returns an error, and the error string is included as text in
104104
// the Content field.
105105
IsError bool `json:"isError,omitempty"`
106+
107+
// The error passed to setError, if any.
108+
// It is not marshaled, and therefore it is only visible on the server.
109+
// Its only use is in server sending middleware, where it can be accessed
110+
// with getError.
111+
err error
106112
}
107113

108114
// TODO(#64): consider exposing setError (and getError), by adding an error
109115
// field on CallToolResult.
110116
func (r *CallToolResult) setError(err error) {
111117
r.Content = []Content{&TextContent{Text: err.Error()}}
112118
r.IsError = true
119+
r.err = err
120+
}
121+
122+
// getError returns the error set with setError, or nil if none.
123+
// This function always returns nil on clients.
124+
func (r *CallToolResult) getError() error {
125+
return r.err
113126
}
114127

115128
func (*CallToolResult) isResult() {}

mcp/protocol_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ func TestContentUnmarshal(t *testing.T) {
499499
if err := json.Unmarshal(data, out); err != nil {
500500
t.Fatal(err)
501501
}
502-
if diff := cmp.Diff(in, out); diff != "" {
502+
if diff := cmp.Diff(in, out, ctrCmpOpts...); diff != "" {
503503
t.Errorf("mismatch (-want, +got):\n%s", diff)
504504
}
505505
}

mcp/sse_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ func TestSSEServer(t *testing.T) {
7070
&TextContent{Text: "hi user"},
7171
},
7272
}
73-
if diff := cmp.Diff(wantHi, gotHi); diff != "" {
73+
if diff := cmp.Diff(wantHi, gotHi, ctrCmpOpts...); diff != "" {
7474
t.Errorf("tools/call 'greet' mismatch (-want +got):\n%s", diff)
7575
}
7676

mcp/streamable_test.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ func TestStreamableTransports(t *testing.T) {
159159
want := &CallToolResult{
160160
Content: []Content{&TextContent{Text: "hi foo"}},
161161
}
162-
if diff := cmp.Diff(want, got); diff != "" {
162+
if diff := cmp.Diff(want, got, ctrCmpOpts...); diff != "" {
163163
t.Errorf("CallTool() returned unexpected content (-want +got):\n%s", diff)
164164
}
165165

@@ -550,8 +550,6 @@ func resp(id int64, result any, err error) *jsonrpc.Response {
550550
}
551551
}
552552

553-
var ()
554-
555553
func TestStreamableServerTransport(t *testing.T) {
556554
// This test checks detailed behavior of the streamable server transport, by
557555
// faking the behavior of a streamable client using a sequence of HTTP

0 commit comments

Comments
 (0)