Skip to content

Commit 8a38a3a

Browse files
committed
mcp: add CallToolResult.getError
Provide a way for middleware to get the error from a tool call, demonstrating that we can add this functionality without making a breaking change. Leave getError unexported for now; we can export it (and setError) at any time. For #64.
1 parent 22f86c4 commit 8a38a3a

File tree

2 files changed

+78
-0
lines changed

2 files changed

+78
-0
lines changed

mcp/mcp_test.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1837,3 +1837,68 @@ func TestEmbeddedStructResponse(t *testing.T) {
18371837
t.Errorf("CallTool() failed: %v", err)
18381838
}
18391839
}
1840+
1841+
func TestToolErrorMiddleware(t *testing.T) {
1842+
ctx := context.Background()
1843+
ct, st := NewInMemoryTransports()
1844+
1845+
s := NewServer(testImpl, nil)
1846+
AddTool(s, &Tool{
1847+
Name: "greet",
1848+
Description: "say hi",
1849+
}, sayHi)
1850+
AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}},
1851+
func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) {
1852+
return nil, nil, errTestFailure
1853+
})
1854+
1855+
var middleErr error
1856+
s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler {
1857+
return func(ctx context.Context, method string, req Request) (Result, error) {
1858+
res, err := h(ctx, method, req)
1859+
if err == nil {
1860+
if ctr, ok := res.(*CallToolResult); ok {
1861+
middleErr = ctr.getError()
1862+
}
1863+
}
1864+
return res, err
1865+
}
1866+
})
1867+
_, err := s.Connect(ctx, st, nil)
1868+
if err != nil {
1869+
t.Fatal(err)
1870+
}
1871+
client := NewClient(&Implementation{Name: "test-client"}, nil)
1872+
clientSession, err := client.Connect(ctx, ct, nil)
1873+
if err != nil {
1874+
t.Fatal(err)
1875+
}
1876+
defer clientSession.Close()
1877+
1878+
res, err := clientSession.CallTool(ctx, &CallToolParams{
1879+
Name: "greet",
1880+
Arguments: map[string]any{"Name": "al"},
1881+
})
1882+
if err != nil {
1883+
t.Errorf("CallTool() failed: %v", err)
1884+
}
1885+
if middleErr != nil {
1886+
t.Errorf("middleware got error %v, want nil", middleErr)
1887+
}
1888+
res, err = clientSession.CallTool(ctx, &CallToolParams{
1889+
Name: "fail",
1890+
})
1891+
if err != nil {
1892+
t.Errorf("CallTool() failed: %v", err)
1893+
}
1894+
if !res.IsError {
1895+
t.Fatal("want error, got none")
1896+
}
1897+
// Clients can't see the error, because it isn't marshaled.
1898+
if err := res.getError(); err != nil {
1899+
t.Fatalf("got %v, want nil", err)
1900+
}
1901+
if middleErr != errTestFailure {
1902+
t.Errorf("middleware got err %v, want errTestFailure", middleErr)
1903+
}
1904+
}

mcp/protocol.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,26 @@ type CallToolResult struct {
103103
// tool handler returns an error, and the error string is included as text in
104104
// the Content field.
105105
IsError bool `json:"isError,omitempty"`
106+
107+
// The error passed to setError, if any.
108+
// It is not marshaled, and therefore it is only visible on the server.
109+
// Its only use is in server sending middleware, where it can be accessed
110+
// with getError.
111+
err error
106112
}
107113

108114
// TODO(#64): consider exposing setError (and getError), by adding an error
109115
// field on CallToolResult.
110116
func (r *CallToolResult) setError(err error) {
111117
r.Content = []Content{&TextContent{Text: err.Error()}}
112118
r.IsError = true
119+
r.err = err
120+
}
121+
122+
// getError returns the error set with setError, or nil if none.
123+
// This function always returns nil on clients.
124+
func (r *CallToolResult) getError() error {
125+
return r.err
113126
}
114127

115128
func (*CallToolResult) isResult() {}

0 commit comments

Comments
 (0)