Skip to content

Commit 9d5e2f6

Browse files
committed
mcp: propertly validate against JSON, independent of Go values
Our validation logic was avoiding double-unmarshalling as much as possible, by parsing before validation and validating the Go type. This only works if the Go type has the same structure as its JSON representation, which may not be the case in the presence of types with custom MarshalJSON or UnmarshalJSON methods (such as time.Time). But even if the Go type doesn't use any custom marshalling, validation is broken, because we can't differentiate zero values from missing values. Bite the bullet and use double-unmarshalling for both input and output schemas. Coincidentally, this fixes three bugs: - We were accepting case-insensitive JSON keys, since we parsed first, even though they should have been rejected. A number of tests were wrong. - Defaults were overriding present-yet-zero values, as noted in an incorrect test case. - When "arguments" was missing, validation wasn't performed, no defaults were applied, and unmarshalling failed even if all properties were optional. First unmarshalling to map[string]any allows us to fix all these bugs. Unfortunately, it means a 3x increase in the number of reflection operations (we need to unmarshal, apply defaults and validate, re-marshal with the defaults, and then unmarshal into the Go type). However, this is not likely to be a significant overhead, and we can always optimize in the future. Update github.com/google/jsonschema-go to pick up necessary improvements supporting this change. Additionally, fix the error codes for invalid tool parameters, to be consistent with other SDKs (Invalid Params: -32602). Fixes #447 Fixes #449 Updates #450
1 parent eddef06 commit 9d5e2f6

File tree

12 files changed

+254
-66
lines changed

12 files changed

+254
-66
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ go 1.23.0
44

55
require (
66
github.com/google/go-cmp v0.7.0
7-
github.com/google/jsonschema-go v0.2.2
7+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2
88
github.com/yosida95/uritemplate/v3 v3.0.2
99
golang.org/x/tools v0.34.0
1010
)

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
22
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
3-
github.com/google/jsonschema-go v0.2.2 h1:qb9KM/pATIqIPuE9gEDwPsco8HHCTlA88IGFYHDl03A=
4-
github.com/google/jsonschema-go v0.2.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
3+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo=
4+
github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
55
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
66
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
77
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=

mcp/conformance_test.go

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ import (
2020
"strings"
2121
"testing"
2222
"testing/synctest"
23+
"time"
2324

2425
"github.com/google/go-cmp/cmp"
2526
"github.com/google/go-cmp/cmp/cmpopts"
27+
"github.com/google/jsonschema-go/jsonschema"
2628
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2729
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
2830
"golang.org/x/tools/txtar"
@@ -97,16 +99,40 @@ func TestServerConformance(t *testing.T) {
9799
}
98100
}
99101

100-
type input struct {
102+
type structuredInput struct {
101103
In string `jsonschema:"the input"`
102104
}
103105

104-
type output struct {
106+
type structuredOutput struct {
105107
Out string `jsonschema:"the output"`
106108
}
107109

108-
func structuredTool(ctx context.Context, req *CallToolRequest, args *input) (*CallToolResult, *output, error) {
109-
return nil, &output{"Ack " + args.In}, nil
110+
func structuredTool(ctx context.Context, req *CallToolRequest, args *structuredInput) (*CallToolResult, *structuredOutput, error) {
111+
return nil, &structuredOutput{"Ack " + args.In}, nil
112+
}
113+
114+
type tomorrowInput struct {
115+
Now time.Time
116+
}
117+
118+
type tomorrowOutput struct {
119+
Tomorrow time.Time
120+
}
121+
122+
func tomorrowTool(ctx context.Context, req *CallToolRequest, args tomorrowInput) (*CallToolResult, tomorrowOutput, error) {
123+
return nil, tomorrowOutput{args.Now.Add(24 * time.Hour)}, nil
124+
}
125+
126+
type incInput struct {
127+
X int `json:"x,omitempty"`
128+
}
129+
130+
type incOutput struct {
131+
Y int `json:"y"`
132+
}
133+
134+
func incTool(_ context.Context, _ *CallToolRequest, args incInput) (*CallToolResult, incOutput, error) {
135+
return nil, incOutput{args.X + 1}, nil
110136
}
111137

112138
// runServerTest runs the server conformance test.
@@ -124,6 +150,15 @@ func runServerTest(t *testing.T, test *conformanceTest) {
124150
}, sayHi)
125151
case "structured":
126152
AddTool(s, &Tool{Name: "structured"}, structuredTool)
153+
case "tomorrow":
154+
AddTool(s, &Tool{Name: "tomorrow"}, tomorrowTool)
155+
case "inc":
156+
inSchema, err := jsonschema.For[incInput](nil)
157+
if err != nil {
158+
t.Fatal(err)
159+
}
160+
inSchema.Properties["x"].Default = json.RawMessage(`6`)
161+
AddTool(s, &Tool{Name: "inc", InputSchema: inSchema}, incTool)
127162
default:
128163
t.Fatalf("unknown tool %q", tn)
129164
}

mcp/mcp_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ func TestEndToEnd(t *testing.T) {
224224
// ListTools is tested in client_list_test.go.
225225
gotHi, err := cs.CallTool(ctx, &CallToolParams{
226226
Name: "greet",
227-
Arguments: map[string]any{"name": "user"},
227+
Arguments: map[string]any{"Name": "user"},
228228
})
229229
if err != nil {
230230
t.Fatal(err)
@@ -648,7 +648,7 @@ func TestServerClosing(t *testing.T) {
648648
}()
649649
if _, err := cs.CallTool(ctx, &CallToolParams{
650650
Name: "greet",
651-
Arguments: map[string]any{"name": "user"},
651+
Arguments: map[string]any{"Name": "user"},
652652
}); err != nil {
653653
t.Fatalf("after connecting: %v", err)
654654
}
@@ -1646,7 +1646,7 @@ var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}
16461646
// If anyone asks, we can add an option that controls how pointers are treated.
16471647
func TestPointerArgEquivalence(t *testing.T) {
16481648
type input struct {
1649-
In string
1649+
In string `json:",omitempty"`
16501650
}
16511651
type output struct {
16521652
Out string

mcp/protocol.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ type CallToolResult struct {
105105
IsError bool `json:"isError,omitempty"`
106106
}
107107

108+
// TODO(#64): consider exposing setError (and getError), by adding an error
109+
// field on CallToolResult.
110+
func (r *CallToolResult) setError(err error) {
111+
r.Content = []Content{&TextContent{Text: err.Error()}}
112+
r.IsError = true
113+
}
114+
108115
func (*CallToolResult) isResult() {}
109116

110117
// UnmarshalJSON handles the unmarshalling of content into the Content

mcp/server.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
221221
}
222222

223223
th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
224+
var input json.RawMessage
225+
if req.Params.Arguments != nil {
226+
input = req.Params.Arguments
227+
}
228+
// Validate input and apply defaults.
229+
var err error
230+
input, err = applySchema(input, inputResolved)
231+
if err != nil {
232+
// TODO(#450): should this be considered a tool error? (and similar below)
233+
return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err)
234+
}
235+
224236
// Unmarshal and validate args.
225237
var in In
226-
if req.Params.Arguments != nil {
227-
if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil {
228-
return nil, err
238+
if input != nil {
239+
if err := json.Unmarshal(input, &in); err != nil {
240+
return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err)
229241
}
230242
}
231243

@@ -241,22 +253,15 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
241253
return nil, wireErr
242254
}
243255
// For regular errors, embed them in the tool result as per MCP spec
244-
return &CallToolResult{
245-
Content: []Content{&TextContent{Text: err.Error()}},
246-
IsError: true,
247-
}, nil
248-
}
249-
250-
// Validate output schema, if any.
251-
// Skip if out is nil: we've removed "null" from the output schema, so nil won't validate.
252-
if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() {
253-
} else if err := validateSchema(outputResolved, &out); err != nil {
254-
return nil, fmt.Errorf("tool output: %w", err)
256+
var errRes CallToolResult
257+
errRes.setError(err)
258+
return &errRes, nil
255259
}
256260

257261
if res == nil {
258262
res = &CallToolResult{}
259263
}
264+
260265
// Marshal the output and put the RawMessage in the StructuredContent field.
261266
var outval any = out
262267
if elemZero != nil {
@@ -272,7 +277,16 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
272277
if err != nil {
273278
return nil, fmt.Errorf("marshaling output: %w", err)
274279
}
275-
res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire
280+
outJSON := json.RawMessage(outbytes)
281+
// Validate the output JSON, and apply defaults.
282+
//
283+
// We validate against the JSON, rather than the output value, as
284+
// some types may have custom JSON marshalling (issue #447).
285+
outJSON, err = applySchema(outJSON, outputResolved)
286+
if err != nil {
287+
return nil, fmt.Errorf("validating tool output: %w", err)
288+
}
289+
res.StructuredContent = outJSON // avoid a second marshal over the wire
276290

277291
// If the Content field isn't being used, return the serialized JSON in a
278292
// TextContent block, as the spec suggests:

mcp/server_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out
514514
_, err = goth(context.Background(), ctr)
515515

516516
if gotErr := err != nil; gotErr != wantErr {
517-
t.Errorf("got error: %t, want error: %t", gotErr, wantErr)
517+
t.Errorf("got error %v, want error: %t", err, wantErr)
518518
}
519519
}
520520

mcp/sse_example_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ import (
1515
)
1616

1717
type AddParams struct {
18-
X, Y int
18+
X int `json:"x"`
19+
Y int `json:"y"`
1920
}
2021

2122
func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) {

mcp/streamable_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ func TestStreamableTransports(t *testing.T) {
144144
// The "greet" tool should just work.
145145
params := &CallToolParams{
146146
Name: "greet",
147-
Arguments: map[string]any{"name": "foo"},
147+
Arguments: map[string]any{"Name": "foo"},
148148
}
149149
got, err := session.CallTool(ctx, params)
150150
if err != nil {
@@ -239,10 +239,11 @@ func TestStreamableServerShutdown(t *testing.T) {
239239
if err != nil {
240240
t.Fatal(err)
241241
}
242+
defer clientSession.Close()
242243

243244
params := &CallToolParams{
244245
Name: "greet",
245-
Arguments: map[string]any{"name": "foo"},
246+
Arguments: map[string]any{"Name": "foo"},
246247
}
247248
// Verify that we can call a tool.
248249
if _, err := clientSession.CallTool(ctx, params); err != nil {

0 commit comments

Comments
 (0)