Skip to content

Commit 8314ec0

Browse files
authored
mcp: validate tool output (#352)
Validate the handler's returned output against the output schema of a tool. Fixes #301.
1 parent d0c5943 commit 8314ec0

File tree

3 files changed

+35
-7
lines changed

3 files changed

+35
-7
lines changed

mcp/mcp_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,26 @@ func TestEndToEnd(t *testing.T) {
257257
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
258258
}
259259

260+
// Check output schema validation.
261+
badout := &Tool{
262+
Name: "badout",
263+
OutputSchema: &jsonschema.Schema{
264+
Type: "object",
265+
Properties: map[string]*jsonschema.Schema{
266+
"x": {Type: "string"},
267+
},
268+
},
269+
}
270+
AddTool(s, badout, func(_ context.Context, _ *CallToolRequest, arg map[string]any) (*CallToolResult, map[string]any, error) {
271+
return nil, map[string]any{"x": 1}, nil
272+
})
273+
_, err = cs.CallTool(ctx, &CallToolParams{Name: "badout"})
274+
wantMsg := `has type "integer", want "string"`
275+
if err == nil || !strings.Contains(err.Error(), wantMsg) {
276+
t.Errorf("\ngot %q\nwant error message containing %q", err, wantMsg)
277+
}
278+
279+
// Check tools-changed notifications.
260280
s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{Type: "object"}}, nopHandler)
261281
waitForNotification(t, "tools")
262282
s.RemoveTools("T")

mcp/server.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
247247
}, nil
248248
}
249249

250-
// TODO(v0.3.0): Validate out.
251-
_ = outputResolved
250+
// Validate output schema, if any.
251+
// Skip if out is nil: we've removed "null" from the output schema, so nil won't validate.
252+
if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() {
253+
} else if err := validateSchema(outputResolved, &out); err != nil {
254+
return nil, fmt.Errorf("tool output: %w", err)
255+
}
252256

253257
if res == nil {
254258
res = &CallToolResult{}

mcp/tool.go

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"encoding/json"
1111
"fmt"
12+
// "log"
1213

1314
"github.com/google/jsonschema-go/jsonschema"
1415
)
@@ -42,13 +43,16 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any)
4243
if err := dec.Decode(v); err != nil {
4344
return fmt.Errorf("unmarshaling: %w", err)
4445
}
45-
// TODO: test with nil args.
46+
return validateSchema(resolved, v)
47+
}
48+
49+
func validateSchema(resolved *jsonschema.Resolved, value any) error {
4650
if resolved != nil {
47-
if err := resolved.ApplyDefaults(v); err != nil {
48-
return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%s:\n%w", schemaJSON(resolved.Schema()), data, err)
51+
if err := resolved.ApplyDefaults(value); err != nil {
52+
return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%v:\n%w", schemaJSON(resolved.Schema()), value, err)
4953
}
50-
if err := resolved.Validate(v); err != nil {
51-
return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err)
54+
if err := resolved.Validate(value); err != nil {
55+
return fmt.Errorf("validating\n\t%v\nagainst\n\t %s:\n %w", value, schemaJSON(resolved.Schema()), err)
5256
}
5357
}
5458
return nil

0 commit comments

Comments
 (0)