Skip to content

Commit 20fc70c

Browse files
committed
mcp: propertly validate against JSON, independent of Go values
Our validation logic was avoiding double-unmarshalling as much as possible, by parsing before validation and validating the Go type. This only works if the Go type has the same structure as its JSON representation, which may not be the case in the presence of types with custom MarshalJSON or UnmarshalJSON methods (such as time.Time). But even if the Go type doesn't use any custom marshalling, validation is broken, because we can't differentiate zero values from missing values. Bite the bullet and use double-unmarshalling for both input and output schemas. Coincidentally, this fixes three bugs: - We were accepting case-insensitive JSON keys, since we parsed first, even though they should have been rejected. A number of tests were wrong. - Defaults were overriding present-yet-zero values, as noted in an incorrect test case. - When "arguments" was missing, validation wasn't performed, no defaults were applied, and unmarshalling failed even if all properties were optional. First unmarshalling to map[string]any allows us to fix all these bugs. Unfortunately, it means a 3x increase in the number of reflection operations (we need to unmarshal, apply defaults and validate, re-marshal with the defaults, and then unmarshal into the Go type). However, this is not likely to be a significant overhead, and we can always optimize in the future. Update github.com/google/jsonschema-go to pick up necessary improvements supporting this change. Additionally, fix the error codes for invalid tool parameters, to be consistent with other SDKs (Invalid Params: -32602). Fixes #447 Fixes #449 Updates #450
1 parent eddef06 commit 20fc70c

File tree

12 files changed

+287
-85
lines changed

12 files changed

+287
-85
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.23.0
44

55
require (
66
github.com/google/go-cmp v0.7.0
7-
github.com/google/jsonschema-go v0.2.2
7+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2
88
github.com/yosida95/uritemplate/v3 v3.0.2
99
golang.org/x/tools v0.34.0
1010
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
22
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
3-
github.com/google/jsonschema-go v0.2.2 h1:qb9KM/pATIqIPuE9gEDwPsco8HHCTlA88IGFYHDl03A=
4-
github.com/google/jsonschema-go v0.2.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
3+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo=
4+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
55
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
66
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
77
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=

mcp/conformance_test.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ import (
2020
"strings"
2121
"testing"
2222
"testing/synctest"
23+
"time"
2324

2425
"github.com/google/go-cmp/cmp"
2526
"github.com/google/go-cmp/cmp/cmpopts"
27+
"github.com/google/jsonschema-go/jsonschema"
2628
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2729
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2830
"golang.org/x/tools/txtar"
@@ -97,16 +99,40 @@ func TestServerConformance(t *testing.T) {
9799
}
98100
}
99101

100-
type input struct {
102+
type structuredInput struct {
101103
In string `jsonschema:"the input"`
102104
}
103105

104-
type output struct {
106+
type structuredOutput struct {
105107
Out string `jsonschema:"the output"`
106108
}
107109

108-
func structuredTool(ctx context.Context, req *CallToolRequest, args *input) (*CallToolResult, *output, error) {
109-
return nil, &output{"Ack " + args.In}, nil
110+
func structuredTool(ctx context.Context, req *CallToolRequest, args *structuredInput) (*CallToolResult, *structuredOutput, error) {
111+
return nil, &structuredOutput{"Ack " + args.In}, nil
112+
}
113+
114+
type tomorrowInput struct {
115+
Now time.Time
116+
}
117+
118+
type tomorrowOutput struct {
119+
Tomorrow time.Time
120+
}
121+
122+
func tomorrowTool(ctx context.Context, req *CallToolRequest, args tomorrowInput) (*CallToolResult, tomorrowOutput, error) {
123+
return nil, tomorrowOutput{args.Now.Add(24 * time.Hour)}, nil
124+
}
125+
126+
type incInput struct {
127+
X int `json:"x,omitempty"`
128+
}
129+
130+
type incOutput struct {
131+
Y int `json:"y"`
132+
}
133+
134+
func incTool(_ context.Context, _ *CallToolRequest, args incInput) (*CallToolResult, incOutput, error) {
135+
return nil, incOutput{args.X + 1}, nil
110136
}
111137

112138
// runServerTest runs the server conformance test.
@@ -124,6 +150,15 @@ func runServerTest(t *testing.T, test *conformanceTest) {
124150
}, sayHi)
125151
case "structured":
126152
AddTool(s, &Tool{Name: "structured"}, structuredTool)
153+
case "tomorrow":
154+
AddTool(s, &Tool{Name: "tomorrow"}, tomorrowTool)
155+
case "inc":
156+
inSchema, err := jsonschema.For[incInput](nil)
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
inSchema.Properties["x"].Default = json.RawMessage(`6`)
161+
AddTool(s, &Tool{Name: "inc", InputSchema: inSchema}, incTool)
127162
default:
128163
t.Fatalf("unknown tool %q", tn)
129164
}

mcp/mcp_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ func TestEndToEnd(t *testing.T) {
224224
// ListTools is tested in client_list_test.go.
225225
gotHi, err := cs.CallTool(ctx, &CallToolParams{
226226
Name: "greet",
227-
Arguments: map[string]any{"name": "user"},
227+
Arguments: map[string]any{"Name": "user"},
228228
})
229229
if err != nil {
230230
t.Fatal(err)
@@ -648,7 +648,7 @@ func TestServerClosing(t *testing.T) {
648648
}()
649649
if _, err := cs.CallTool(ctx, &CallToolParams{
650650
Name: "greet",
651-
Arguments: map[string]any{"name": "user"},
651+
Arguments: map[string]any{"Name": "user"},
652652
}); err != nil {
653653
t.Fatalf("after connecting: %v", err)
654654
}
@@ -1646,7 +1646,7 @@ var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}
16461646
// If anyone asks, we can add an option that controls how pointers are treated.
16471647
func TestPointerArgEquivalence(t *testing.T) {
16481648
type input struct {
1649-
In string
1649+
In string `json:",omitempty"`
16501650
}
16511651
type output struct {
16521652
Out string

mcp/protocol.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ type CallToolResult struct {
105105
IsError bool `json:"isError,omitempty"`
106106
}
107107

108+
// TODO(#64): consider exposing setError (and getError), by adding an error
109+
// field on CallToolResult.
110+
func (r *CallToolResult) setError(err error) {
111+
r.Content = []Content{&TextContent{Text: err.Error()}}
112+
r.IsError = true
113+
}
114+
108115
func (*CallToolResult) isResult() {}
109116

110117
// UnmarshalJSON handles the unmarshalling of content into the Content

mcp/server.go

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,6 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
189189
func() bool { s.tools.add(st); return true })
190190
}
191191

192-
// TODO(v0.3.0): test
193192
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
194193
tt := *t
195194

@@ -221,11 +220,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
221220
}
222221

223222
th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
223+
var input json.RawMessage
224+
if req.Params.Arguments != nil {
225+
input = req.Params.Arguments
226+
}
227+
// Validate input and apply defaults.
228+
var err error
229+
input, err = applySchema(input, inputResolved)
230+
if err != nil {
231+
// TODO(#450): should this be considered a tool error? (and similar below)
232+
return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err)
233+
}
234+
224235
// Unmarshal and validate args.
225236
var in In
226-
if req.Params.Arguments != nil {
227-
if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil {
228-
return nil, err
237+
if input != nil {
238+
if err := json.Unmarshal(input, &in); err != nil {
239+
return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err)
229240
}
230241
}
231242

@@ -241,22 +252,15 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
241252
return nil, wireErr
242253
}
243254
// For regular errors, embed them in the tool result as per MCP spec
244-
return &CallToolResult{
245-
Content: []Content{&TextContent{Text: err.Error()}},
246-
IsError: true,
247-
}, nil
248-
}
249-
250-
// Validate output schema, if any.
251-
// Skip if out is nil: we've removed "null" from the output schema, so nil won't validate.
252-
if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() {
253-
} else if err := validateSchema(outputResolved, &out); err != nil {
254-
return nil, fmt.Errorf("tool output: %w", err)
255+
var errRes CallToolResult
256+
errRes.setError(err)
257+
return &errRes, nil
255258
}
256259

257260
if res == nil {
258261
res = &CallToolResult{}
259262
}
263+
260264
// Marshal the output and put the RawMessage in the StructuredContent field.
261265
var outval any = out
262266
if elemZero != nil {
@@ -272,7 +276,16 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
272276
if err != nil {
273277
return nil, fmt.Errorf("marshaling output: %w", err)
274278
}
275-
res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire
279+
outJSON := json.RawMessage(outbytes)
280+
// Validate the output JSON, and apply defaults.
281+
//
282+
// We validate against the JSON, rather than the output value, as
283+
// some types may have custom JSON marshalling (issue #447).
284+
outJSON, err = applySchema(outJSON, outputResolved)
285+
if err != nil {
286+
return nil, fmt.Errorf("validating tool output: %w", err)
287+
}
288+
res.StructuredContent = outJSON // avoid a second marshal over the wire
276289

277290
// If the Content field isn't being used, return the serialized JSON in a
278291
// TextContent block, as the spec suggests:

mcp/server_test.go

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"encoding/json"
1010
"log"
1111
"slices"
12+
"strings"
1213
"testing"
1314
"time"
1415

@@ -491,7 +492,7 @@ func TestAddTool(t *testing.T) {
491492

492493
type schema = jsonschema.Schema
493494

494-
func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) {
495+
func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErrContaining string) {
495496
t.Helper()
496497
th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) {
497498
return nil, out, nil
@@ -513,34 +514,48 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out
513514
}
514515
_, err = goth(context.Background(), ctr)
515516

516-
if gotErr := err != nil; gotErr != wantErr {
517-
t.Errorf("got error: %t, want error: %t", gotErr, wantErr)
517+
if wantErrContaining != "" {
518+
if err == nil {
519+
t.Errorf("got nil error, want error containing %q", wantErrContaining)
520+
} else {
521+
if !strings.Contains(err.Error(), wantErrContaining) {
522+
t.Errorf("got error %q, want containing %q", err, wantErrContaining)
523+
}
524+
}
525+
} else if err != nil {
526+
t.Errorf("got error %v, want no error", err)
518527
}
519528
}
520529

521530
func TestToolForSchemas(t *testing.T) {
522-
// Validate that ToolFor handles schemas properly.
531+
// Validate that toolForErr handles schemas properly.
532+
type in struct {
533+
P int `json:"p,omitempty"`
534+
}
535+
type out struct {
536+
B bool `json:"b,omitempty"`
537+
}
538+
539+
var (
540+
falseSchema = &schema{Not: &schema{}}
541+
inSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "integer"}}}
542+
inSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "string"}}}
543+
outSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "boolean"}}}
544+
outSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "integer"}}}
545+
)
523546

524547
// Infer both schemas.
525-
testToolForSchema[int](t, &Tool{}, "3", true,
526-
&schema{Type: "integer"}, &schema{Type: "boolean"}, false)
548+
testToolForSchema[in](t, &Tool{}, `{"p":3}`, out{true}, inSchema, outSchema, "")
527549
// Validate the input schema: expect an error if it's wrong.
528550
// We can't test that the output schema is validated, because it's typed.
529-
testToolForSchema[int](t, &Tool{}, `"x"`, true,
530-
&schema{Type: "integer"}, &schema{Type: "boolean"}, true)
531-
551+
testToolForSchema[in](t, &Tool{}, `{"p":"x"}`, out{true}, inSchema, outSchema, `want "integer"`)
532552
// Ignore type any for output.
533-
testToolForSchema[int, any](t, &Tool{}, "3", 0,
534-
&schema{Type: "integer"}, nil, false)
553+
testToolForSchema[in, any](t, &Tool{}, `{"p":3}`, 0, inSchema, nil, "")
535554
// Input is still validated.
536-
testToolForSchema[int, any](t, &Tool{}, `"x"`, 0,
537-
&schema{Type: "integer"}, nil, true)
538-
555+
testToolForSchema[in, any](t, &Tool{}, `{"p":"x"}`, 0, inSchema, nil, `want "integer"`)
539556
// Tool sets input schema: that is what's used.
540-
testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0,
541-
&schema{Type: "string"}, nil, true) // error: 3 is not a string
542-
557+
testToolForSchema[in, any](t, &Tool{InputSchema: inSchema2}, `{"p":3}`, 0, inSchema2, nil, `want "string"`)
543558
// Tool sets output schema: that is what's used, and validation happens.
544-
testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x",
545-
&schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer
559+
testToolForSchema[in, any](t, &Tool{OutputSchema: outSchema2}, `{"p":3}`, out{true},
560+
inSchema, outSchema2, `want "integer"`)
546561
}

mcp/sse_example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import (
1515
)
1616

1717
type AddParams struct {
18-
X, Y int
18+
X int `json:"x"`
19+
Y int `json:"y"`
1920
}
2021

2122
func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) {

mcp/streamable_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func TestStreamableTransports(t *testing.T) {
144144
// The "greet" tool should just work.
145145
params := &CallToolParams{
146146
Name: "greet",
147-
Arguments: map[string]any{"name": "foo"},
147+
Arguments: map[string]any{"Name": "foo"},
148148
}
149149
got, err := session.CallTool(ctx, params)
150150
if err != nil {
@@ -239,10 +239,11 @@ func TestStreamableServerShutdown(t *testing.T) {
239239
if err != nil {
240240
t.Fatal(err)
241241
}
242+
defer clientSession.Close()
242243

243244
params := &CallToolParams{
244245
Name: "greet",
245-
Arguments: map[string]any{"name": "foo"},
246+
Arguments: map[string]any{"Name": "foo"},
246247
}
247248
// Verify that we can call a tool.
248249
if _, err := clientSession.CallTool(ctx, params); err != nil {

0 commit comments

Comments
 (0)