Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1837,3 +1837,68 @@
t.Errorf("CallTool() failed: %v", err)
}
}

func TestToolErrorMiddleware(t *testing.T) {
ctx := context.Background()
ct, st := NewInMemoryTransports()

s := NewServer(testImpl, nil)
AddTool(s, &Tool{
Name: "greet",
Description: "say hi",
}, sayHi)
AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{Type: "object"}},
func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) {
return nil, nil, errTestFailure
})

var middleErr error
s.AddReceivingMiddleware(func(h MethodHandler) MethodHandler {
return func(ctx context.Context, method string, req Request) (Result, error) {
res, err := h(ctx, method, req)
if err == nil {
if ctr, ok := res.(*CallToolResult); ok {
middleErr = ctr.getError()
}
}
return res, err
}
})
_, err := s.Connect(ctx, st, nil)
if err != nil {
t.Fatal(err)
}
client := NewClient(&Implementation{Name: "test-client"}, nil)
clientSession, err := client.Connect(ctx, ct, nil)
if err != nil {
t.Fatal(err)
}
defer clientSession.Close()

res, err := clientSession.CallTool(ctx, &CallToolParams{

Check failure on line 1878 in mcp/mcp_test.go

View workflow job for this annotation

GitHub Actions / lint

this value of res is never used (SA4006)
Name: "greet",
Arguments: map[string]any{"Name": "al"},
})
if err != nil {
t.Errorf("CallTool() failed: %v", err)
}
if middleErr != nil {
t.Errorf("middleware got error %v, want nil", middleErr)
}
res, err = clientSession.CallTool(ctx, &CallToolParams{
Name: "fail",
})
if err != nil {
t.Errorf("CallTool() failed: %v", err)
}
if !res.IsError {
t.Fatal("want error, got none")
}
// Clients can't see the error, because it isn't marshaled.
if err := res.getError(); err != nil {
t.Fatalf("got %v, want nil", err)
}
if middleErr != errTestFailure {
t.Errorf("middleware got err %v, want errTestFailure", middleErr)
}
}
13 changes: 13 additions & 0 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,26 @@ type CallToolResult struct {
// tool handler returns an error, and the error string is included as text in
// the Content field.
IsError bool `json:"isError,omitempty"`

// The error passed to setError, if any.
// It is not marshaled, and therefore it is only visible on the server.
// Its only use is in server sending middleware, where it can be accessed
// with getError.
err error
}

// TODO(#64): consider exposing setError (and getError), by adding an error
// field on CallToolResult.
func (r *CallToolResult) setError(err error) {
r.Content = []Content{&TextContent{Text: err.Error()}}
r.IsError = true
r.err = err
}

// getError returns the error set with setError, or nil if none.
// This function always returns nil on clients.
func (r *CallToolResult) getError() error {
return r.err
}

func (*CallToolResult) isResult() {}
Expand Down
Loading