diff --git a/examples/server/toolschemas/main.go b/examples/server/toolschemas/main.go index a53acd47..abc1e3f5 100644 --- a/examples/server/toolschemas/main.go +++ b/examples/server/toolschemas/main.go @@ -126,15 +126,13 @@ func main() { // Add the 'greeting' tool in a few different ways. // First, we can just use [mcp.AddTool], and get the out-of-the-box handling - // it provides: + // it provides for schema inference, validation, parsing, and packing the + // result. mcp.AddTool(server, &mcp.Tool{Name: "simple greeting"}, simpleGreeting) - // Next, we can create our schemas entirely manually, and add them using - // [mcp.Server.AddTool]. Since we're working manually, we can add some - // constraints on the length of the name. - // - // We don't need to do all this work: below, we use jsonschema.For to start - // from the default schema. + // Alternatively, we can create our schemas entirely manually, and add them + // using [mcp.Server.AddTool]. Since we're using the 'raw' API, we have to do + // the parsing and validation ourselves manual, err := newManualGreeter() if err != nil { log.Fatal(err) @@ -145,6 +143,22 @@ func main() { OutputSchema: outputSchema, }, manual.greet) + // We can even use raw schema values. In this case, note that we're not + // validating the input at all. + server.AddTool(&mcp.Tool{ + Name: "unvalidated greeting", + InputSchema: json.RawMessage(`{"type":"object","properties":{"user":{"type":"string"}}}`), + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Note: no validation! + var args struct{ User string } + if err := json.Unmarshal(req.Params.Arguments, &args); err != nil { + return nil, err + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.User}}, + }, nil + }) + // Finally, note that we can also use custom schemas with a ToolHandlerFor. // We can do this in two ways: by using one of the schema values constructed // above, or by using jsonschema.For and adjusting the resulting schema. diff --git a/mcp/client.go b/mcp/client.go index dea3e854..3211977b 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -295,7 +295,8 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, } // Validate that the requested schema only contains top-level properties without nesting - if err := validateElicitSchema(req.Params.RequestedSchema); err != nil { + schema, err := validateElicitSchema(req.Params.RequestedSchema) + if err != nil { return nil, jsonrpc2.NewError(CodeInvalidParams, err.Error()) } @@ -305,11 +306,11 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, } // Validate elicitation result content against requested schema - if req.Params.RequestedSchema != nil && res.Content != nil { + if schema != nil && res.Content != nil { // TODO: is this the correct behavior if validation fails? // It isn't the *server's* params that are invalid, so why would we return // this code to the server? - resolved, err := req.Params.RequestedSchema.Resolve(nil) + resolved, err := schema.Resolve(nil) if err != nil { return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) } @@ -324,14 +325,19 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, // validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. // Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. -func validateElicitSchema(schema *jsonschema.Schema) error { - if schema == nil { - return nil // nil schema is allowed +func validateElicitSchema(wireSchema any) (*jsonschema.Schema, error) { + if wireSchema == nil { + return nil, nil // nil schema is allowed + } + + var schema *jsonschema.Schema + if err := remarshal(wireSchema, &schema); err != nil { + return nil, err } // The root schema must be of type "object" if specified if schema.Type != "" && schema.Type != "object" { - return fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + return nil, fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) } // Check if the schema has properties @@ -342,12 +348,12 @@ func validateElicitSchema(schema *jsonschema.Schema) error { } if err := validateElicitProperty(propName, propSchema); err != nil { - return err + return nil, err } } } - return nil + return schema, nil } // validateElicitProperty validates a single property in an elicitation schema. @@ -383,7 +389,7 @@ func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema if propSchema.Extra != nil { if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { // Type check enumNames - should be a slice - if enumNamesSlice, ok := enumNamesRaw.([]interface{}); ok { + if enumNamesSlice, ok := enumNamesRaw.([]any); ok { if len(enumNamesSlice) != len(propSchema.Enum) { return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) } diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 0183a733..6974ad5d 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -6,6 +6,7 @@ package mcp_test import ( "context" + "encoding/json" "iter" "log" "testing" @@ -41,7 +42,13 @@ func TestList(t *testing.T) { if err != nil { t.Fatal(err) } - tt.InputSchema = is + data, err := json.Marshal(is) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(data, &tt.InputSchema); err != nil { + t.Fatal(err) + } wantTools = append(wantTools, tt) } t.Run("list", func(t *testing.T) { diff --git a/mcp/protocol.go b/mcp/protocol.go index f3f23f58..e62867e3 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -13,8 +13,6 @@ package mcp import ( "encoding/json" "fmt" - - "github.com/google/jsonschema-go/jsonschema" ) // Optional annotations for the client. The client can use annotations to inform @@ -913,14 +911,38 @@ type Tool struct { // This can be used by clients to improve the LLM's understanding of available // tools. It can be thought of like a "hint" to the model. Description string `json:"description,omitempty"` - // A JSON Schema object defining the expected parameters for the tool. - InputSchema *jsonschema.Schema `json:"inputSchema"` + // InputSchema holds a JSON Schema object defining the expected parameters + // for the tool. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's input schema (a map[string]any). + InputSchema any `json:"inputSchema"` // Intended for programmatic or logical use, but used as a display name in past // specs or fallback (if title isn't present). Name string `json:"name"` - // An optional JSON Schema object defining the structure of the tool's output - // returned in the structuredContent field of a CallToolResult. - OutputSchema *jsonschema.Schema `json:"outputSchema,omitempty"` + // OutputSchema holds an optional JSON Schema object defining the structure + // of the tool's output returned in the StructuredContent field of a + // CallToolResult. + // + // From the server, this field may be set to any value that JSON-marshals to + // valid JSON schema (including json.RawMessage). However, for tools added + // using [AddTool], which automatically validates inputs and outputs, the + // schema must be in a draft the SDK understands. Currently, the SDK uses + // github.com/google/jsonschema-go for inference and validation, which only + // supports the 2020-12 draft of JSON schema. To do your own validation, use + // [Server.AddTool]. + // + // From the client, this field will hold the default JSON marshaling of the + // server's output schema (a map[string]any). + OutputSchema any `json:"outputSchema,omitempty"` // Intended for UI and end-user contexts — optimized to be human-readable and // easily understood, even by those unfamiliar with domain-specific terminology. // If not provided, Annotations.Title should be used for display if present, @@ -1022,9 +1044,18 @@ type ElicitParams struct { Meta `json:"_meta,omitempty"` // The message to present to the user. Message string `json:"message"` - // A restricted subset of JSON Schema. + // A JSON schema object defining the requested elicitation schema. + // + // From the server, this field may be set to any value that can JSON-marshal + // to valid JSON schema (including json.RawMessage for raw schema values). + // Internally, the SDK uses github.com/google/jsonschema-go for validation, + // which only supports the 2020-12 draft of the JSON schema spec. + // + // From the client, this field will use the default JSON marshaling (a + // map[string]any). + // // Only top-level properties are allowed, without nesting. - RequestedSchema *jsonschema.Schema `json:"requestedSchema"` + RequestedSchema any `json:"requestedSchema"` } func (x *ElicitParams) isParams() {} diff --git a/mcp/server.go b/mcp/server.go index 27de09a3..4d958df4 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -164,9 +164,9 @@ func (s *Server) RemovePrompts(names ...string) { // // The tool's input schema must be non-nil and have the type "object". For a tool // that takes no input, or one where any input is valid, set [Tool.InputSchema] to -// &jsonschema.Schema{Type: "object"}. +// `{"type": "object"}`, using your preferred library or `json.RawMessage`. // -// If present, the output schema must also have type "object". +// If present, [Tool.OutputSchema] must also have type "object". // // When the handler is invoked as part of a CallTool request, req.Params.Arguments // will be a json.RawMessage. @@ -189,11 +189,29 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // discovered until runtime, when the LLM sent bad data. panic(fmt.Errorf("AddTool %q: missing input schema", t.Name)) } - if t.InputSchema.Type != "object" { + if s, ok := t.InputSchema.(*jsonschema.Schema); ok && s.Type != "object" { panic(fmt.Errorf(`AddTool %q: input schema must have type "object"`, t.Name)) + } else { + var m map[string]any + if err := remarshal(t.InputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal input schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: input schema must have type "object" (got %v)`, t.Name, typ)) + } } - if t.OutputSchema != nil && t.OutputSchema.Type != "object" { - panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + if t.OutputSchema != nil { + if s, ok := t.OutputSchema.(*jsonschema.Schema); ok && s.Type != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object"`, t.Name)) + } else { + var m map[string]any + if err := remarshal(t.OutputSchema, &m); err != nil { + panic(fmt.Errorf("AddTool %q: can't marshal output schema to a JSON object: %v", t.Name, err)) + } + if typ := m["type"]; typ != "object" { + panic(fmt.Errorf(`AddTool %q: output schema must have type "object" (got %v)`, t.Name, typ)) + } + } } st := &serverTool{tool: t, handler: h} // Assume there was a change, since add replaces existing tools. @@ -331,7 +349,8 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // // TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we // should have a jsonschema.Zero(schema) helper? -func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { +func setSchema[T any](sfield *any, rfield **jsonschema.Resolved) (zero any, err error) { + var internalSchema *jsonschema.Schema if *sfield == nil { rt := reflect.TypeFor[T]() if rt.Kind() == reflect.Pointer { @@ -339,28 +358,41 @@ func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) zero = reflect.Zero(rt).Interface() } // TODO: we should be able to pass nil opts here. - *sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + internalSchema, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + if err == nil { + *sfield = internalSchema + } + } else { + if err := remarshal(*sfield, &internalSchema); err != nil { + return zero, err + } } if err != nil { return zero, err } - *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + *rfield, err = internalSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) return zero, err } // AddTool adds a tool and typed tool handler to the server. // // If the tool's input schema is nil, it is set to the schema inferred from the -// In type parameter, using [jsonschema.For]. The In type argument must be a -// map or a struct, so that its inferred JSON Schema has type "object". +// In type parameter. Types are inferred from Go types, and property +// descriptions are read from the 'jsonschema' struct tag. Internally, the SDK +// uses the github.com/google/jsonschema-go package for ineference and +// validation. The In type argument must be a map or a struct, so that its +// inferred JSON Schema has type "object", as required by the spec. As a +// special case, if the In type is 'any', the tool's input schema is set to an +// empty object schema value. // // If the tool's output schema is nil, and the Out type is not 'any', the // output schema is set to the schema inferred from the Out type argument, -// which also must be a map or struct. +// which must also be a map or struct. If the Out type is 'any', the output +// schema is omitted. // -// Unlike [Server.AddTool], AddTool does a lot automatically, and forces tools -// to conform to the MCP spec. See [ToolHandlerFor] for a detailed description -// of this automatic behavior. +// Unlike [Server.AddTool], AddTool does a lot automatically, and forces +// tools to conform to the MCP spec. See [ToolHandlerFor] for a detailed +// description of this automatic behavior. func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { tt, hh, err := toolForErr(t, h) if err != nil { diff --git a/mcp/server_test.go b/mcp/server_test.go index 249ef90b..ec6114d0 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -492,7 +492,7 @@ func TestAddTool(t *testing.T) { type schema = jsonschema.Schema -func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErrContaining string) { +func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut any, wantErrContaining string) { t.Helper() th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) { return nil, out, nil diff --git a/mcp/tool_example_test.go b/mcp/tool_example_test.go index e41250d4..8efc3ee0 100644 --- a/mcp/tool_example_test.go +++ b/mcp/tool_example_test.go @@ -16,6 +16,49 @@ import ( "github.com/modelcontextprotocol/go-sdk/mcp" ) +func ExampleServer_AddTool_rawSchema() { + // In some scenarios, you may want your server to be a pass-through, with + // JSON schema coming from another source. Or perhaps you want to implement + // tool validation using a different JSON schema library. + // + // For these cases, you can use [mcp.Server.AddTool], which is the "raw" form + // of the API. Note that it is the caller's responsibility to validate inputs + // and outputs. + server := mcp.NewServer(&mcp.Implementation{Name: "server", Version: "v0.0.1"}, nil) + server.AddTool(&mcp.Tool{ + Name: "greet", + InputSchema: json.RawMessage(`{"type":"object","properties":{"user":{"type":"string"}}}`), + }, func(_ context.Context, req *mcp.CallToolRequest) (*mcp.CallToolResult, error) { + // Note: no validation! + var args struct{ User string } + if err := json.Unmarshal(req.Params.Arguments, &args); err != nil { + // TODO: we should use a jsonrpc error here, to be consistent with other + // SDKs. + return nil, err + } + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.User}}, + }, nil + }) + + ctx := context.Background() + session, err := connect(ctx, server) + if err != nil { + log.Fatal(err) + } + defer session.Close() + + res, err := session.CallTool(ctx, &mcp.CallToolParams{ + Name: "greet", + Arguments: map[string]any{"user": "you"}, + }) + if err != nil { + log.Fatal(err) + } + fmt.Println(res.Content[0].(*mcp.TextContent).Text) + // Output: Hi you +} + func ExampleAddTool_customMarshalling() { // Sometimes when you want to customize the input or output schema for a // tool, you need to customize the schema of a single helper type that's used @@ -68,7 +111,7 @@ func ExampleAddTool_customMarshalling() { } // Output: // my_tool { - // "type": "object", + // "additionalProperties": false, // "properties": { // "end": { // "type": "string" @@ -80,7 +123,7 @@ func ExampleAddTool_customMarshalling() { // "type": "string" // } // }, - // "additionalProperties": false + // "type": "object" // } } @@ -200,9 +243,9 @@ func ExampleAddTool_complexSchema() { } // Formatting the entire schemas would be too much output. // Just check that our customizations were effective. - fmt.Println("max days:", *t.InputSchema.Properties["days"].Maximum) - fmt.Println("max confidence:", *t.OutputSchema.Properties["confidence"].Maximum) - fmt.Println("weather types:", t.OutputSchema.Properties["dailyForecast"].Items.Properties["type"].Enum) + fmt.Println("max days:", jsonPath(t.InputSchema, "properties", "days", "maximum")) + fmt.Println("max confidence:", jsonPath(t.OutputSchema, "properties", "confidence", "maximum")) + fmt.Println("weather types:", jsonPath(t.OutputSchema, "properties", "dailyForecast", "items", "properties", "type", "enum")) } // Output: // max days: 10 @@ -218,3 +261,10 @@ func connect(ctx context.Context, server *mcp.Server) (*mcp.ClientSession, error client := mcp.NewClient(&mcp.Implementation{Name: "client", Version: "v0.0.1"}, nil) return client.Connect(ctx, t2, nil) } + +func jsonPath(s any, path ...string) any { + if len(path) == 0 { + return s + } + return jsonPath(s.(map[string]any)[path[0]], path[1:]...) +} diff --git a/mcp/util.go b/mcp/util.go index 102c0885..5ada466e 100644 --- a/mcp/util.go +++ b/mcp/util.go @@ -6,6 +6,7 @@ package mcp import ( "crypto/rand" + "encoding/json" ) func assert(cond bool, msg string) { @@ -27,3 +28,16 @@ func randText() string { } return string(src) } + +// remarshal marshals from to JSON, and then unmarshals into to, which must be +// a pointer type. +func remarshal(from, to any) error { + data, err := json.Marshal(from) + if err != nil { + return err + } + if err := json.Unmarshal(data, to); err != nil { + return err + } + return nil +}