Skip to content

Commit 39d231e

Browse files
committed
mcp: validate tool output
1 parent c1c2292 commit 39d231e

File tree

2 files changed

+16
-7
lines changed

2 files changed

+16
-7
lines changed

mcp/server.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,8 +260,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
260260
}, nil
261261
}
262262

263-
// TODO(v0.3.0): Validate out.
264-
_ = outputResolved
263+
// Validate output schema, if any.
264+
// Skip if out is nil: we've removed "null" from the output schema, so nil won't validate.
265+
if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() {
266+
} else if err := validateSchema(outputResolved, &out); err != nil {
267+
return nil, err
268+
}
265269

266270
// TODO: return the serialized JSON in a TextContent block, as per spec?
267271
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content

mcp/tool.go

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"context"
1010
"encoding/json"
1111
"fmt"
12+
// "log"
1213

1314
"github.com/google/jsonschema-go/jsonschema"
1415
)
@@ -42,13 +43,17 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any)
4243
if err := dec.Decode(v); err != nil {
4344
return fmt.Errorf("unmarshaling: %w", err)
4445
}
45-
// TODO: test with nil args.
46+
return validateSchema(resolved, v)
47+
}
48+
49+
// TODO: test with nil args.
50+
func validateSchema(resolved *jsonschema.Resolved, value any) error {
4651
if resolved != nil {
47-
if err := resolved.ApplyDefaults(v); err != nil {
48-
return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%s:\n%w", schemaJSON(resolved.Schema()), data, err)
52+
if err := resolved.ApplyDefaults(value); err != nil {
53+
return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%v:\n%w", schemaJSON(resolved.Schema()), value, err)
4954
}
50-
if err := resolved.Validate(v); err != nil {
51-
return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err)
55+
if err := resolved.Validate(value); err != nil {
56+
return fmt.Errorf("validating\n\t%v\nagainst\n\t %s:\n %w", value, schemaJSON(resolved.Schema()), err)
5257
}
5358
}
5459
return nil

0 commit comments

Comments
 (0)