diff --git a/go.mod b/go.mod index 0e18643d..d303ef0c 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,7 @@ go 1.23.0 require ( github.com/google/go-cmp v0.7.0 - github.com/google/jsonschema-go v0.2.2 + github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 github.com/yosida95/uritemplate/v3 v3.0.2 golang.org/x/tools v0.34.0 ) diff --git a/go.sum b/go.sum index 169c67ed..6903b659 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/jsonschema-go v0.2.2 h1:qb9KM/pATIqIPuE9gEDwPsco8HHCTlA88IGFYHDl03A= -github.com/google/jsonschema-go v0.2.2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= +github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2 h1:IIj7X4SH1HKy0WfPR4nNEj4dhIJWGdXM5YoBAbfpdoo= +github.com/google/jsonschema-go v0.2.3-0.20250911201137-bbdc431016d2/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= diff --git a/mcp/conformance_test.go b/mcp/conformance_test.go index a8da4fb7..3393efcb 100644 --- a/mcp/conformance_test.go +++ b/mcp/conformance_test.go @@ -20,9 +20,11 @@ import ( "strings" "testing" "testing/synctest" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" "golang.org/x/tools/txtar" @@ -97,16 +99,40 @@ func TestServerConformance(t *testing.T) { } } -type input struct { +type structuredInput struct { In string `jsonschema:"the input"` } -type output struct { +type structuredOutput struct { Out string `jsonschema:"the output"` } -func structuredTool(ctx context.Context, req *CallToolRequest, args *input) (*CallToolResult, *output, error) { - return nil, &output{"Ack " + args.In}, nil +func structuredTool(ctx context.Context, req *CallToolRequest, args *structuredInput) (*CallToolResult, *structuredOutput, error) { + return nil, &structuredOutput{"Ack " + args.In}, nil +} + +type tomorrowInput struct { + Now time.Time +} + +type tomorrowOutput struct { + Tomorrow time.Time +} + +func tomorrowTool(ctx context.Context, req *CallToolRequest, args tomorrowInput) (*CallToolResult, tomorrowOutput, error) { + return nil, tomorrowOutput{args.Now.Add(24 * time.Hour)}, nil +} + +type incInput struct { + X int `json:"x,omitempty"` +} + +type incOutput struct { + Y int `json:"y"` +} + +func incTool(_ context.Context, _ *CallToolRequest, args incInput) (*CallToolResult, incOutput, error) { + return nil, incOutput{args.X + 1}, nil } // runServerTest runs the server conformance test. @@ -124,6 +150,15 @@ func runServerTest(t *testing.T, test *conformanceTest) { }, sayHi) case "structured": AddTool(s, &Tool{Name: "structured"}, structuredTool) + case "tomorrow": + AddTool(s, &Tool{Name: "tomorrow"}, tomorrowTool) + case "inc": + inSchema, err := jsonschema.For[incInput](nil) + if err != nil { + t.Fatal(err) + } + inSchema.Properties["x"].Default = json.RawMessage(`6`) + AddTool(s, &Tool{Name: "inc", InputSchema: inSchema}, incTool) default: t.Fatalf("unknown tool %q", tn) } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 4b05ce7d..6191954c 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -224,7 +224,7 @@ func TestEndToEnd(t *testing.T) { // ListTools is tested in client_list_test.go. gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }) if err != nil { t.Fatal(err) @@ -648,7 +648,7 @@ func TestServerClosing(t *testing.T) { }() if _, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "user"}, + Arguments: map[string]any{"Name": "user"}, }); err != nil { t.Fatalf("after connecting: %v", err) } @@ -1646,7 +1646,7 @@ var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} // If anyone asks, we can add an option that controls how pointers are treated. func TestPointerArgEquivalence(t *testing.T) { type input struct { - In string + In string `json:",omitempty"` } type output struct { Out string diff --git a/mcp/protocol.go b/mcp/protocol.go index a8e4817d..aeb9adbd 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -105,6 +105,13 @@ type CallToolResult struct { IsError bool `json:"isError,omitempty"` } +// TODO(#64): consider exposing setError (and getError), by adding an error +// field on CallToolResult. +func (r *CallToolResult) setError(err error) { + r.Content = []Content{&TextContent{Text: err.Error()}} + r.IsError = true +} + func (*CallToolResult) isResult() {} // UnmarshalJSON handles the unmarshalling of content into the Content diff --git a/mcp/server.go b/mcp/server.go index af9c2ab4..69808ac7 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -189,7 +189,6 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { func() bool { s.tools.add(st); return true }) } -// TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { tt := *t @@ -221,11 +220,23 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan } th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { + var input json.RawMessage + if req.Params.Arguments != nil { + input = req.Params.Arguments + } + // Validate input and apply defaults. + var err error + input, err = applySchema(input, inputResolved) + if err != nil { + // TODO(#450): should this be considered a tool error? (and similar below) + return nil, fmt.Errorf("%w: validating \"arguments\": %v", jsonrpc2.ErrInvalidParams, err) + } + // Unmarshal and validate args. var in In - if req.Params.Arguments != nil { - if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil { - return nil, err + if input != nil { + if err := json.Unmarshal(input, &in); err != nil { + return nil, fmt.Errorf("%w: %v", jsonrpc2.ErrInvalidParams, err) } } @@ -241,22 +252,15 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan return nil, wireErr } // For regular errors, embed them in the tool result as per MCP spec - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - - // Validate output schema, if any. - // Skip if out is nil: we've removed "null" from the output schema, so nil won't validate. - if v := reflect.ValueOf(out); v.Kind() == reflect.Pointer && v.IsNil() { - } else if err := validateSchema(outputResolved, &out); err != nil { - return nil, fmt.Errorf("tool output: %w", err) + var errRes CallToolResult + errRes.setError(err) + return &errRes, nil } if res == nil { res = &CallToolResult{} } + // Marshal the output and put the RawMessage in the StructuredContent field. var outval any = out if elemZero != nil { @@ -272,7 +276,16 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan if err != nil { return nil, fmt.Errorf("marshaling output: %w", err) } - res.StructuredContent = json.RawMessage(outbytes) // avoid a second marshal over the wire + outJSON := json.RawMessage(outbytes) + // Validate the output JSON, and apply defaults. + // + // We validate against the JSON, rather than the output value, as + // some types may have custom JSON marshalling (issue #447). + outJSON, err = applySchema(outJSON, outputResolved) + if err != nil { + return nil, fmt.Errorf("validating tool output: %w", err) + } + res.StructuredContent = outJSON // avoid a second marshal over the wire // If the Content field isn't being used, return the serialized JSON in a // TextContent block, as the spec suggests: diff --git a/mcp/server_test.go b/mcp/server_test.go index e46be379..4456495f 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -9,6 +9,7 @@ import ( "encoding/json" "log" "slices" + "strings" "testing" "time" @@ -491,7 +492,7 @@ func TestAddTool(t *testing.T) { type schema = jsonschema.Schema -func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) { +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErrContaining string) { t.Helper() th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { return nil, out, nil @@ -513,34 +514,48 @@ func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out } _, err = goth(context.Background(), ctr) - if gotErr := err != nil; gotErr != wantErr { - t.Errorf("got error: %t, want error: %t", gotErr, wantErr) + if wantErrContaining != "" { + if err == nil { + t.Errorf("got nil error, want error containing %q", wantErrContaining) + } else { + if !strings.Contains(err.Error(), wantErrContaining) { + t.Errorf("got error %q, want containing %q", err, wantErrContaining) + } + } + } else if err != nil { + t.Errorf("got error %v, want no error", err) } } func TestToolForSchemas(t *testing.T) { - // Validate that ToolFor handles schemas properly. + // Validate that toolForErr handles schemas properly. + type in struct { + P int `json:"p,omitempty"` + } + type out struct { + B bool `json:"b,omitempty"` + } + + var ( + falseSchema = &schema{Not: &schema{}} + inSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "integer"}}} + inSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"p": {Type: "string"}}} + outSchema = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "boolean"}}} + outSchema2 = &schema{Type: "object", AdditionalProperties: falseSchema, Properties: map[string]*schema{"b": {Type: "integer"}}} + ) // Infer both schemas. - testToolForSchema[int](t, &Tool{}, "3", true, - &schema{Type: "integer"}, &schema{Type: "boolean"}, false) + testToolForSchema[in](t, &Tool{}, `{"p":3}`, out{true}, inSchema, outSchema, "") // Validate the input schema: expect an error if it's wrong. // We can't test that the output schema is validated, because it's typed. - testToolForSchema[int](t, &Tool{}, `"x"`, true, - &schema{Type: "integer"}, &schema{Type: "boolean"}, true) - + testToolForSchema[in](t, &Tool{}, `{"p":"x"}`, out{true}, inSchema, outSchema, `want "integer"`) // Ignore type any for output. - testToolForSchema[int, any](t, &Tool{}, "3", 0, - &schema{Type: "integer"}, nil, false) + testToolForSchema[in, any](t, &Tool{}, `{"p":3}`, 0, inSchema, nil, "") // Input is still validated. - testToolForSchema[int, any](t, &Tool{}, `"x"`, 0, - &schema{Type: "integer"}, nil, true) - + testToolForSchema[in, any](t, &Tool{}, `{"p":"x"}`, 0, inSchema, nil, `want "integer"`) // Tool sets input schema: that is what's used. - testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0, - &schema{Type: "string"}, nil, true) // error: 3 is not a string - + testToolForSchema[in, any](t, &Tool{InputSchema: inSchema2}, `{"p":3}`, 0, inSchema2, nil, `want "string"`) // Tool sets output schema: that is what's used, and validation happens. - testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x", - &schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer + testToolForSchema[in, any](t, &Tool{OutputSchema: outSchema2}, `{"p":3}`, out{true}, + inSchema, outSchema2, `want "integer"`) } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 7d777114..d06ea62b 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -15,7 +15,8 @@ import ( ) type AddParams struct { - X, Y int + X int `json:"x"` + Y int `json:"y"` } func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 8817784d..e077308c 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -144,7 +144,7 @@ func TestStreamableTransports(t *testing.T) { // The "greet" tool should just work. params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "foo"}, + Arguments: map[string]any{"Name": "foo"}, } got, err := session.CallTool(ctx, params) if err != nil { @@ -239,10 +239,11 @@ func TestStreamableServerShutdown(t *testing.T) { if err != nil { t.Fatal(err) } + defer clientSession.Close() params := &CallToolParams{ Name: "greet", - Arguments: map[string]any{"name": "foo"}, + Arguments: map[string]any{"Name": "foo"}, } // Verify that we can call a tool. if _, err := clientSession.CallTool(ctx, params); err != nil { diff --git a/mcp/testdata/conformance/server/tools.txtar b/mcp/testdata/conformance/server/tools.txtar index c39e3ec9..b582dda8 100644 --- a/mcp/testdata/conformance/server/tools.txtar +++ b/mcp/testdata/conformance/server/tools.txtar @@ -5,10 +5,17 @@ Fixed bugs: - "_meta" should not be nil - empty resource or prompts should not be returned as 'null' - the server should not crash when params are passed to tools/call +- missing required input fields should be rejected (#449) +- output and input should be validated against their actual json, not Go + representation +- When arguments are missing, the request should succeed if all properties are + optional, and observe any default values. -- tools -- greet structured +tomorrow +inc -- client -- { @@ -26,9 +33,14 @@ structured { "jsonrpc": "2.0", "id": 3, "method": "resources/list" } { "jsonrpc": "2.0", "id": 4, "method": "prompts/list" } { "jsonrpc": "2.0", "id": 5, "method": "tools/call" } -{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"name": "you"} } } +{ "jsonrpc": "2.0", "id": 6, "method": "tools/call", "params": {"name": "greet", "arguments": {"Name": "you"} } } { "jsonrpc": "2.0", "id": 1, "result": {} } { "jsonrpc": "2.0", "id": 7, "method": "tools/call", "params": {"name": "structured", "arguments": {"In": "input"} } } +{ "jsonrpc": "2.0", "id": 8, "method": "tools/call", "params": {"name": "structured", "arguments": {} } } +{ "jsonrpc": "2.0", "id": 9, "method": "tools/call", "params": {"name": "tomorrow", "arguments": { "Now": "2025-06-18T15:04:05Z" } } } +{ "jsonrpc": "2.0", "id": 10, "method": "tools/call", "params": {"name": "greet" } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc", "arguments": { "x": 3 } } } +{ "jsonrpc": "2.0", "id": 11, "method": "tools/call", "params": {"name": "inc" } } -- server -- { @@ -69,6 +81,31 @@ structured }, "name": "greet" }, + { + "inputSchema": { + "type": "object", + "properties": { + "x": { + "type": "integer", + "default": 6 + } + }, + "additionalProperties": false + }, + "name": "inc", + "outputSchema": { + "type": "object", + "required": [ + "y" + ], + "properties": { + "y": { + "type": "integer" + } + }, + "additionalProperties": false + } + }, { "inputSchema": { "type": "object", @@ -97,6 +134,33 @@ structured }, "additionalProperties": false } + }, + { + "inputSchema": { + "type": "object", + "required": [ + "Now" + ], + "properties": { + "Now": { + "type": "string" + } + }, + "additionalProperties": false + }, + "name": "tomorrow", + "outputSchema": { + "type": "object", + "required": [ + "Tomorrow" + ], + "properties": { + "Tomorrow": { + "type": "string" + } + }, + "additionalProperties": false + } } ] } @@ -155,3 +219,64 @@ structured } } } +{ + "jsonrpc": "2.0", + "id": 8, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"In\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 9, + "result": { + "content": [ + { + "type": "text", + "text": "{\"Tomorrow\":\"2025-06-19T15:04:05Z\"}" + } + ], + "structuredContent": { + "Tomorrow": "2025-06-19T15:04:05Z" + } + } +} +{ + "jsonrpc": "2.0", + "id": 10, + "error": { + "code": -32602, + "message": "invalid params: validating \"arguments\": validating root: required: missing properties: [\"Name\"]" + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":4}" + } + ], + "structuredContent": { + "y": 4 + } + } +} +{ + "jsonrpc": "2.0", + "id": 11, + "result": { + "content": [ + { + "type": "text", + "text": "{\"y\":7}" + } + ], + "structuredContent": { + "y": 7 + } + } +} diff --git a/mcp/tool.go b/mcp/tool.go index ffccbf30..12b02b7b 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -5,7 +5,6 @@ package mcp import ( - "bytes" "context" "encoding/json" "fmt" @@ -61,41 +60,44 @@ type serverTool struct { handler ToolHandler } -// unmarshalSchema unmarshals data into v and validates the result according to -// the given resolved schema. -func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) error { +// applySchema validates whether data is valid JSON according to the provided +// schema, after applying schema defaults. +// +// Returns the JSON value augmented with defaults. +func applySchema(data json.RawMessage, resolved *jsonschema.Resolved) (json.RawMessage, error) { // TODO: use reflection to create the struct type to unmarshal into. // Separate validation from assignment. - // Disallow unknown fields. - // Otherwise, if the tool was built with a struct, the client could send extra - // fields and json.Unmarshal would ignore them, so the schema would never get - // a chance to declare the extra args invalid. - dec := json.NewDecoder(bytes.NewReader(data)) - dec.DisallowUnknownFields() - if err := dec.Decode(v); err != nil { - return fmt.Errorf("unmarshaling: %w", err) - } - return validateSchema(resolved, v) -} - -func validateSchema(resolved *jsonschema.Resolved, value any) error { + // Use default JSON marshalling for validation. + // + // This avoids inconsistent representation due to custom marshallers, such as + // time.Time (issue #449). + // + // Additionally, unmarshalling into a map ensures that the resulting JSON is + // at least {}, even if data is empty. For example, arguments is technically + // an optional property of callToolParams, and we still want to apply the + // defaults in this case. + // + // TODO(rfindley): in which cases can resolved be nil? if resolved != nil { - if err := resolved.ApplyDefaults(value); err != nil { - return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%v:\n%w", schemaJSON(resolved.Schema()), value, err) + v := make(map[string]any) + if len(data) > 0 { + if err := json.Unmarshal(data, &v); err != nil { + return nil, fmt.Errorf("unmarshaling arguments: %w", err) + } } - if err := resolved.Validate(value); err != nil { - return fmt.Errorf("validating\n\t%v\nagainst\n\t %s:\n %w", value, schemaJSON(resolved.Schema()), err) + if err := resolved.ApplyDefaults(&v); err != nil { + return nil, fmt.Errorf("applying schema defaults:\n%w", err) + } + if err := resolved.Validate(&v); err != nil { + return nil, err + } + // We must re-marshal with the default values applied. + var err error + data, err = json.Marshal(v) + if err != nil { + return nil, fmt.Errorf("marshalling with defaults: %v", err) } } - return nil -} - -// schemaJSON returns the JSON value for s as a string, or a string indicating an error. -func schemaJSON(s *jsonschema.Schema) string { - m, err := json.Marshal(s) - if err != nil { - return fmt.Sprintf("", err) - } - return string(m) + return data, nil } diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 2722a9ac..ef26e9dc 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -17,7 +17,7 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) -func TestUnmarshalSchema(t *testing.T) { +func TestApplySchema(t *testing.T) { schema := &jsonschema.Schema{ Type: "object", Properties: map[string]*jsonschema.Schema{ @@ -39,20 +39,23 @@ func TestUnmarshalSchema(t *testing.T) { want any }{ {`{"x": 1}`, new(S), &S{X: 1}}, - {`{}`, new(S), &S{X: 3}}, // default applied - {`{"x": 0}`, new(S), &S{X: 3}}, // FAIL: should be 0. (requires double unmarshal) + {`{}`, new(S), &S{X: 3}}, // default applied + {`{"x": 0}`, new(S), &S{X: 0}}, {`{"x": 1}`, new(map[string]any), &map[string]any{"x": 1.0}}, {`{}`, new(map[string]any), &map[string]any{"x": 3.0}}, // default applied {`{"x": 0}`, new(map[string]any), &map[string]any{"x": 0.0}}, } { raw := json.RawMessage(tt.data) - if err := unmarshalSchema(raw, resolved, tt.v); err != nil { + raw, err = applySchema(raw, resolved) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(raw, &tt.v); err != nil { t.Fatal(err) } if !reflect.DeepEqual(tt.v, tt.want) { t.Errorf("got %#v, want %#v", tt.v, tt.want) } - } }