diff --git a/README.md b/README.md index d4900674..ce4d9f11 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,21 @@ # MCP Go SDK +***BREAKING CHANGES*** + +The latest version contains breaking changes: + +- Server.AddTools is replaced by Server.AddTool. + +- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can + set any field on the Tool that you want before associating it with a handler. + +- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one + as a struct literal, or using any other code that suits you. + +- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the + feature with a handler. The ServerXXX types have been removed. + [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) This repository contains an unreleased implementation of the official Go @@ -99,7 +114,7 @@ import ( ) type HiParams struct { - Name string `json:"name"` + Name string `json:"name", mcp:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { @@ -111,11 +126,8 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) + + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { log.Fatal(err) diff --git a/design/design.md b/design/design.md index b52e9c10..610de399 100644 --- a/design/design.md +++ b/design/design.md @@ -372,12 +372,11 @@ A server that can handle that client call would look like this: ```go // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) -server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) +mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects. -transport := mcp.NewStdioTransport() -session, err := server.Connect(ctx, transport) -... -return session.Wait() +if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { + log.Fatal(err) +} ``` For convenience, we provide `Server.Run` to handle the common case of running a session until the client disconnects: @@ -603,14 +602,14 @@ type ClientOptions struct { ### Tools -A `Tool` is a logical MCP tool, generated from the MCP spec, and a `ServerTool` is a tool bound to a tool handler. +A `Tool` is a logical MCP tool, generated from the MCP spec. -A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to make `CallToolParams` generic, with a type parameter `TArgs` for the tool argument type. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. The bound `ServerTool` type expects a `json.RawMessage` for its tool arguments, but the `NewServerTool` constructor described below provides a mechanism to bind a typed handler. +A tool handler accepts `CallToolParams` and returns a `CallToolResult`. However, since we want to bind tools to Go input types, it is convenient in associated APIs to have a generic version of `CallToolParams`, with a type parameter `In` for the tool argument type, as well as a generic version of for `CallToolResult`. This allows tool APIs to manage the marshalling and unmarshalling of tool inputs for their caller. ```go -type CallToolParams[TArgs any] struct { +type CallToolParamsFor[In any] struct { Meta Meta `json:"_meta,omitempty"` - Arguments TArgs `json:"arguments,omitempty"` + Arguments In `json:"arguments,omitempty"` Name string `json:"name"` } @@ -621,23 +620,31 @@ type Tool struct { Name string `json:"name"` } -type ToolHandler[TArgs] func(context.Context, *ServerSession, *CallToolParams[TArgs]) (*CallToolResult, error) +type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) +type ToolHandler = ToolHandlerFor[map[string]any, any] +``` -type ServerTool struct { - Tool Tool - Handler ToolHandler[json.RawMessage] -} +Add tools to a server with the `AddTool` method or function. The function is generic and infers schemas from the handler +arguments: + +```go +func (s *Server) AddTool(t *Tool, h ToolHandler) +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) ``` -Add tools to a server with `AddTools`: +```go +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) +mcp.AddTool(server, &mcp.Tool{Name: "subtract", Description: "subtract numbers"}, subHandler) +``` +The `AddTool` method requires an input schema, and optionally an output one. It will not modify them. +The handler should accept a `CallToolParams` and return a `CallToolResult`. ```go -server.AddTools( - mcp.NewServerTool("add", "add numbers", addHandler), - mcp.NewServerTool("subtract, subtract numbers", subHandler)) +t := &Tool{Name: ..., Description: ..., InputSchema: &jsonschema.Schema{...}} +server.AddTool(t, myHandler) ``` -Remove them by name with `RemoveTools`: +Tools can be removed by name with `RemoveTools`: ```go server.RemoveTools("add", "subtract") @@ -650,53 +657,30 @@ A tool's input schema, expressed as a [JSON Schema](https://json-schema.org), pr Both of these have their advantages and disadvantages. Reflection is nice, because it allows you to bind directly to a Go API, and means that the JSON schema of your API is compatible with your Go types by construction. It also means that concerns like parsing and validation can be handled automatically. However, it can become cumbersome to express the full breadth of JSON schema using Go types or struct tags, and sometimes you want to express things that aren’t naturally modeled by Go types, like unions. Explicit schemas are simple and readable, and give the caller full control over their tool definition, but involve significant boilerplate. -We have found that a hybrid model works well, where the _initial_ schema is derived using reflection, but any customization on top of that schema is applied using variadic options. We achieve this using a `NewServerTool` helper, which generates the schema from the input type, and wraps the handler to provide parsing and validation. The schema (and potentially other features) can be customized using ToolOptions. - -```go -// NewServerTool creates a Tool using reflection on the given handler. -func NewServerTool[TArgs any](name, description string, handler ToolHandler[TArgs], opts …ToolOption) *ServerTool +We provide both ways. The `jsonschema.For[T]` function will infer a schema, and it is called by the `AddTool` generic function. +Users can also call it themselves, or build a schema directly as a struct literal. They can still use the `AddTool` function to +create a typed handler, since `AddTool` doesn't touch schemas that are already present. -type ToolOption interface { /* ... */ } -``` -`NewServerTool` determines the input schema for a Tool from the `TArgs` type. Each struct field that would be marshaled by `encoding/json.Marshal` becomes a property of the schema. The property is required unless the field's `json` tag specifies "omitempty" or "omitzero" (new in Go 1.24). For example, given this struct: +If the tool's `InputSchema` is nil, it is inferred from the `In` type parameter. If the `OutputSchema` is nil, it is inferred from the `Out` type parameter (unless `Out` is `any`). +For example, given this handler: ```go -struct { - Name string `json:"name"` - Count int `json:"count,omitempty"` - Choices []string - Password []byte `json:"-"` +type AddParams struct { + X int `json:"x"` + Y int `json:"y"` } -``` - -"name" and "Choices" are required, while "count" is optional. - -As of this writing, the only `ToolOption` is `Input`, which allows customizing the input schema of the tool using schema options. These schema options are recursive, in the sense that they may also be applied to properties. - -```go -func Input(...SchemaOption) ToolOption - -type Property(name string, opts ...SchemaOption) SchemaOption -type Description(desc string) SchemaOption -// etc. -``` - -For example: -```go -NewServerTool(name, description, handler, - Input(Property("count", Description("size of the inventory")))) +func addHandler(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[AddParams]) (*mcp.CallToolResultFor[int], error) { + return &mcp.CallToolResultFor[int]{StructuredContent: params.Arguments.X + params.Arguments.Y}, nil +} ``` -The most recent JSON Schema spec defines over 40 keywords. Providing them all as options would bloat the API despite the fact that most would be very rarely used. For less common keywords, use the `Schema` option to set the schema explicitly: - +You can add it to a server like this: ```go -NewServerTool(name, description, handler, - Input(Property("Choices", Schema(&jsonschema.Schema{UniqueItems: true})))) +mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add numbers"}, addHandler) ``` - -Schemas are validated on the server before the tool handler is called. +The input schema will be inferred from `AddParams`, and the output schema from `int`. Since all the fields of the Tool struct are exported, a Tool can also be created directly with assignment or a struct literal. @@ -718,15 +702,7 @@ For registering tools, we provide only `AddTools`; mcp-go's `SetTools`, `AddTool ### Prompts -Use `NewServerPrompt` to create a prompt. As with tools, prompt argument schemas can be inferred from a struct, or obtained from options. - -```go -func NewServerPrompt[TReq any](name, description string, - handler func(context.Context, *ServerSession, TReq) (*GetPromptResult, error), - opts ...PromptOption) *ServerPrompt -``` - -Use `AddPrompts` to add prompts to the server, and `RemovePrompts` +Use `AddPrompt` to add a prompt to the server, and `RemovePrompts` to remove them by name. ```go @@ -734,11 +710,12 @@ type codeReviewArgs struct { Code string `json:"code"` } -func codeReviewHandler(context.Context, *ServerSession, codeReviewArgs) {...} +func codeReviewHandler(context.Context, *ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) {...} -server.AddPrompts( - NewServerPrompt("code_review", "review code", codeReviewHandler, - Argument("code", Description("the code to review")))) +server.AddPrompt( + &mcp.Prompt{Name: "code_review", Description: "review code"}, + codeReviewHandler, +) server.RemovePrompts("code_review") ``` @@ -757,25 +734,11 @@ type ResourceHandler func(context.Context, *ServerSession, *ReadResourceParams) The arguments include the `ServerSession` so the handler can observe the client's roots. The handler should return the resource contents in a `ReadResourceResult`, calling either `NewTextResourceContents` or `NewBlobResourceContents`. If the handler omits the URI or MIME type, the server will populate them from the resource. -The `ServerResource` and `ServerResourceTemplate` types hold the association between the resource and its handler: - -```go -type ServerResource struct { - Resource Resource - Handler ResourceHandler -} - -type ServerResourceTemplate struct { - Template ResourceTemplate - Handler ResourceHandler -} -``` - -To add a resource or resource template to a server, users call the `AddResources` and `AddResourceTemplates` methods with one or more `ServerResource`s or `ServerResourceTemplate`s. We also provide methods to remove them. +To add a resource or resource template to a server, users call the `AddResource` and `AddResourceTemplate` methods. We also provide methods to remove them. ```go -func (*Server) AddResources(...*ServerResource) -func (*Server) AddResourceTemplates(...*ServerResourceTemplate) +func (*Server) AddResource(*Resource, ResourceHandler) +func (*Server) AddResourceTemplate(*ResourceTemplate, ResourceHandler) func (s *Server) RemoveResources(uris ...string) func (s *Server) RemoveResourceTemplates(uriTemplates ...string) @@ -796,9 +759,7 @@ Here is an example: ```go // Safely read "/public/puppies.txt". -s.AddResources(&mcp.ServerResource{ - Resource: mcp.Resource{URI: "file:///puppies.txt"}, - Handler: s.FileReadResourceHandler("/public")}) +s.AddResource(&mcp.Resource{URI: "file:///puppies.txt"}, s.FileReadResourceHandler("/public")) ``` Server sessions also support the spec methods `ListResources` and `ListResourceTemplates`, and the corresponding iterator methods `Resources` and `ResourceTemplates`. diff --git a/examples/hello/main.go b/examples/hello/main.go index 9af34cc3..4db20cc8 100644 --- a/examples/hello/main.go +++ b/examples/hello/main.go @@ -19,7 +19,7 @@ import ( var httpAddr = flag.String("http", "", "if set, use streamable HTTP at this address, instead of stdin/stdout") type HiArgs struct { - Name string `json:"name"` + Name string `json:"name" mcp:"the name to say hi to"` } func SayHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.CallToolParamsFor[HiArgs]) (*mcp.CallToolResultFor[struct{}], error) { @@ -43,21 +43,13 @@ func main() { flag.Parse() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name to say hi to")), - ))) - server.AddPrompts(&mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: "greet"}, - Handler: PromptHi, - }) - server.AddResources(&mcp.ServerResource{ - Resource: &mcp.Resource{ - Name: "info", - MIMEType: "text/plain", - URI: "embedded:info", - }, - Handler: handleEmbeddedResource, - }) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) + server.AddPrompt(&mcp.Prompt{Name: "greet"}, PromptHi) + server.AddResource(&mcp.Resource{ + Name: "info", + MIMEType: "text/plain", + URI: "embedded:info", + }, handleEmbeddedResource) if *httpAddr != "" { handler := mcp.NewStreamableHTTPHandler(func(*http.Request) *mcp.Server { diff --git a/examples/sse/main.go b/examples/sse/main.go index 97ea1bd0..c93320ab 100644 --- a/examples/sse/main.go +++ b/examples/sse/main.go @@ -35,12 +35,12 @@ func main() { } server1 := mcp.NewServer("greeter1", "v0.0.1", nil) - server1.AddTools(mcp.NewServerTool("greet1", "say hi", SayHi)) + mcp.AddTool(server1, &mcp.Tool{Name: "greet1", Description: "say hi"}, SayHi) server2 := mcp.NewServer("greeter2", "v0.0.1", nil) - server2.AddTools(mcp.NewServerTool("greet2", "say hello", SayHi)) + mcp.AddTool(server2, &mcp.Tool{Name: "greet2", Description: "say hello"}, SayHi) - log.Printf("MCP servers serving at %s\n", *httpAddr) + log.Printf("MCP servers serving at %s", *httpAddr) handler := mcp.NewSSEHandler(func(request *http.Request) *mcp.Server { url := request.URL.Path log.Printf("Handling request for URL %s\n", url) diff --git a/internal/readme/README.src.md b/internal/readme/README.src.md index 629629a4..11d63110 100644 --- a/internal/readme/README.src.md +++ b/internal/readme/README.src.md @@ -1,5 +1,20 @@ # MCP Go SDK +***BREAKING CHANGES*** + +The latest version contains breaking changes: + +- Server.AddTools is replaced by Server.AddTool. + +- NewServerTool is replaced by AddTool. AddTool takes a Tool rather than a name and description, so you can + set any field on the Tool that you want before associating it with a handler. + +- Tool options have been removed. If you don't want AddTool to infer a JSON Schema for you, you can construct one + as a struct literal, or using any other code that suits you. + +- AddPrompts, AddResources and AddResourceTemplates are similarly replaced by singular methods which pair the + feature with a handler. The ServerXXX types have been removed. + [![PkgGoDev](https://pkg.go.dev/badge/github.com/modelcontextprotocol/go-sdk)](https://pkg.go.dev/github.com/modelcontextprotocol/go-sdk) This repository contains an unreleased implementation of the official Go diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 534e0798..1fe211ea 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -13,7 +13,7 @@ import ( ) type HiParams struct { - Name string `json:"name"` + Name string `json:"name", mcp:"the name of the person to greet"` } func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsFor[HiParams]) (*mcp.CallToolResultFor[any], error) { @@ -25,11 +25,8 @@ func SayHi(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParam func main() { // Create a server with a single tool. server := mcp.NewServer("greeter", "v1.0.0", nil) - server.AddTools( - mcp.NewServerTool("greet", "say hi", SayHi, mcp.Input( - mcp.Property("name", mcp.Description("the name of the person to greet")), - )), - ) + + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) // Run the server over stdin/stdout, until the client disconnects if err := server.Run(context.Background(), mcp.NewStdioTransport()); err != nil { log.Fatal(err) diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 7e6da95a..497a9cd0 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -22,12 +22,12 @@ func TestList(t *testing.T) { defer serverSession.Close() t.Run("tools", func(t *testing.T) { - toolA := mcp.NewServerTool("apple", "apple tool", SayHi) - toolB := mcp.NewServerTool("banana", "banana tool", SayHi) - toolC := mcp.NewServerTool("cherry", "cherry tool", SayHi) - tools := []*mcp.ServerTool{toolA, toolB, toolC} - wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool} - server.AddTools(tools...) + var wantTools []*mcp.Tool + for _, name := range []string{"apple", "banana", "cherry"} { + t := &mcp.Tool{Name: name, Description: name + " tool"} + wantTools = append(wantTools, t) + mcp.AddTool(server, t, SayHi) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListTools(ctx, nil) if err != nil { @@ -43,12 +43,13 @@ func TestList(t *testing.T) { }) t.Run("resources", func(t *testing.T) { - resourceA := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://apple"}} - resourceB := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://banana"}} - resourceC := &mcp.ServerResource{Resource: &mcp.Resource{URI: "http://cherry"}} - wantResources := []*mcp.Resource{resourceA.Resource, resourceB.Resource, resourceC.Resource} - resources := []*mcp.ServerResource{resourceA, resourceB, resourceC} - server.AddResources(resources...) + var wantResources []*mcp.Resource + for _, name := range []string{"apple", "banana", "cherry"} { + r := &mcp.Resource{URI: "http://" + name} + wantResources = append(wantResources, r) + server.AddResource(r, nil) + } + t.Run("list", func(t *testing.T) { res, err := clientSession.ListResources(ctx, nil) if err != nil { @@ -64,15 +65,12 @@ func TestList(t *testing.T) { }) t.Run("templates", func(t *testing.T) { - resourceTmplA := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://apple/{x}"}} - resourceTmplB := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://banana/{x}"}} - resourceTmplC := &mcp.ServerResourceTemplate{ResourceTemplate: &mcp.ResourceTemplate{URITemplate: "http://cherry/{x}"}} - wantResourceTemplates := []*mcp.ResourceTemplate{ - resourceTmplA.ResourceTemplate, resourceTmplB.ResourceTemplate, - resourceTmplC.ResourceTemplate, + var wantResourceTemplates []*mcp.ResourceTemplate + for _, name := range []string{"apple", "banana", "cherry"} { + rt := &mcp.ResourceTemplate{URITemplate: "http://" + name + "/{x}"} + wantResourceTemplates = append(wantResourceTemplates, rt) + server.AddResourceTemplate(rt, nil) } - resourceTemplates := []*mcp.ServerResourceTemplate{resourceTmplA, resourceTmplB, resourceTmplC} - server.AddResourceTemplates(resourceTemplates...) t.Run("list", func(t *testing.T) { res, err := clientSession.ListResourceTemplates(ctx, nil) if err != nil { @@ -88,12 +86,12 @@ func TestList(t *testing.T) { }) t.Run("prompts", func(t *testing.T) { - promptA := newServerPrompt("apple", "apple prompt") - promptB := newServerPrompt("banana", "banana prompt") - promptC := newServerPrompt("cherry", "cherry prompt") - wantPrompts := []*mcp.Prompt{promptA.Prompt, promptB.Prompt, promptC.Prompt} - prompts := []*mcp.ServerPrompt{promptA, promptB, promptC} - server.AddPrompts(prompts...) + var wantPrompts []*mcp.Prompt + for _, name := range []string{"apple", "banana", "cherry"} { + p := &mcp.Prompt{Name: name, Description: name + " prompt"} + wantPrompts = append(wantPrompts, p) + server.AddPrompt(p, testPromptHandler) + } t.Run("list", func(t *testing.T) { res, err := clientSession.ListPrompts(ctx, nil) if err != nil { @@ -123,14 +121,6 @@ func testIterator[T any](ctx context.Context, t *testing.T, seq iter.Seq2[*T, er } } -// testPromptHandler is used for type inference newServerPrompt. func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { panic("not implemented") } - -func newServerPrompt(name, desc string) *mcp.ServerPrompt { - return &mcp.ServerPrompt{ - Prompt: &mcp.Prompt{Name: name, Description: desc}, - Handler: testPromptHandler, - } -} diff --git a/mcp/cmd_test.go b/mcp/cmd_test.go index f66423d6..496694a5 100644 --- a/mcp/cmd_test.go +++ b/mcp/cmd_test.go @@ -31,8 +31,7 @@ func runServer() { ctx := context.Background() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) - + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil { log.Fatal(err) } diff --git a/mcp/features_test.go b/mcp/features_test.go index 5ffbce8c..e0165ecb 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -27,9 +27,9 @@ func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[Say } func TestFeatureSetOrder(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool @@ -52,9 +52,9 @@ func TestFeatureSetOrder(t *testing.T) { } func TestFeatureSetAbove(t *testing.T) { - toolA := NewServerTool("apple", "apple tool", SayHi).Tool - toolB := NewServerTool("banana", "banana tool", SayHi).Tool - toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool + toolA := &Tool{Name: "apple", Description: "apple tool"} + toolB := &Tool{Name: "banana", Description: "banana tool"} + toolC := &Tool{Name: "cherry", Description: "cherry tool"} testCases := []struct { tools []*Tool diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 5f42b1b9..70c79b58 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -21,7 +21,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) @@ -30,6 +29,9 @@ type hiParams struct { Name string } +// TODO(jba): after schemas are stateless (WIP), this can be a variable. +func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } + func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) { if err := ss.Ping(ctx, nil); err != nil { return nil, fmt.Errorf("ping failed: %v", err) @@ -63,9 +65,31 @@ func TestEndToEnd(t *testing.T) { }, } s := NewServer("testServer", "v1.0.0", sopts) - add(tools, s.AddTools, "greet", "fail") - add(prompts, s.AddPrompts, "code_review", "fail") - add(resources, s.AddResources, "info.txt", "fail.txt") + AddTool(s, &Tool{ + Name: "greet", + Description: "say hi", + }, sayHi) + s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, + func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { + return nil, errTestFailure + }) + s.AddPrompt(&Prompt{ + Name: "code_review", + Description: "do a code review", + Arguments: []*PromptArgument{{Name: "Code", Required: true}}, + }, func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { + return &GetPromptResult{ + Description: "Code review prompt", + Messages: []*PromptMessage{ + {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, + }, + }, nil + }) + s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { + return nil, errTestFailure + }) + s.AddResource(resource1, readHandler) + s.AddResource(resource2, readHandler) // Connect the server. ss, err := s.Connect(ctx, st) @@ -154,39 +178,14 @@ func TestEndToEnd(t *testing.T) { t.Errorf("fail returned unexpected error: got %v, want containing %v", err, errTestFailure) } - s.AddPrompts(&ServerPrompt{Prompt: &Prompt{Name: "T"}}) + s.AddPrompt(&Prompt{Name: "T"}, nil) waitForNotification(t, "prompts") s.RemovePrompts("T") waitForNotification(t, "prompts") }) t.Run("tools", func(t *testing.T) { - res, err := cs.ListTools(ctx, nil) - if err != nil { - t.Errorf("tools/list failed: %v", err) - } - wantTools := []*Tool{ - { - Name: "fail", - InputSchema: nil, - }, - { - Name: "greet", - Description: "say hi", - InputSchema: &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string"}, - }, - AdditionalProperties: falseSchema(), - }, - }, - } - if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("tools/list mismatch (-want +got):\n%s", diff) - } - + // ListTools is tested in client_list_test.go. gotHi, err := cs.CallTool(ctx, &CallToolParams{ Name: "greet", Arguments: map[string]any{"name": "user"}, @@ -222,7 +221,7 @@ func TestEndToEnd(t *testing.T) { t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff) } - s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler}) + s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler) waitForNotification(t, "tools") s.RemoveTools("T") waitForNotification(t, "tools") @@ -246,8 +245,7 @@ func TestEndToEnd(t *testing.T) { MIMEType: "text/template", URITemplate: "file:///{+filename}", // the '+' means that filename can contain '/' } - st := &ServerResourceTemplate{ResourceTemplate: template, Handler: readHandler} - s.AddResourceTemplates(st) + s.AddResourceTemplate(template, readHandler) tres, err := cs.ListResourceTemplates(ctx, nil) if err != nil { t.Fatal(err) @@ -292,7 +290,7 @@ func TestEndToEnd(t *testing.T) { } } - s.AddResources(&ServerResource{Resource: &Resource{URI: "http://U"}}) + s.AddResource(&Resource{URI: "http://U"}, nil) waitForNotification(t, "resources") s.RemoveResources("http://U") waitForNotification(t, "resources") @@ -434,40 +432,6 @@ func TestEndToEnd(t *testing.T) { var ( errTestFailure = errors.New("mcp failure") - tools = map[string]*ServerTool{ - "greet": NewServerTool("greet", "say hi", sayHi), - "fail": { - Tool: &Tool{Name: "fail"}, - Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) { - return nil, errTestFailure - }, - }, - } - - prompts = map[string]*ServerPrompt{ - "code_review": { - Prompt: &Prompt{ - Name: "code_review", - Description: "do a code review", - Arguments: []*PromptArgument{{Name: "Code", Required: true}}, - }, - Handler: func(_ context.Context, _ *ServerSession, params *GetPromptParams) (*GetPromptResult, error) { - return &GetPromptResult{ - Description: "Code review prompt", - Messages: []*PromptMessage{ - {Role: "user", Content: &TextContent{Text: "Please review the following code: " + params.Arguments["Code"]}}, - }, - }, nil - }, - }, - "fail": { - Prompt: &Prompt{Name: "fail"}, - Handler: func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) { - return nil, errTestFailure - }, - }, - } - resource1 = &Resource{ Name: "public", MIMEType: "text/plain", @@ -484,11 +448,6 @@ var ( URI: "embedded:info", } readHandler = fileResourceHandler("testdata/files") - resources = map[string]*ServerResource{ - "info.txt": {resource1, readHandler}, - "fail.txt": {resource2, readHandler}, - "info": {resource3, handleEmbeddedResource}, - } ) var embeddedResources = map[string]string{ @@ -540,21 +499,21 @@ func errorCode(err error) int64 { return -1 } -// basicConnection returns a new basic client-server connection configured with -// the provided tools. +// basicConnection returns a new basic client-server connection, with the server +// configured via the provided function. // // The caller should cancel either the client connection or server connection // when the connections are no longer needed. -func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) { +func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) { t.Helper() ctx := context.Background() ct, st := NewInMemoryTransports() s := NewServer("testServer", "v1.0.0", nil) - - // The 'greet' tool says hi. - s.AddTools(tools...) + if config != nil { + config(s) + } ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) @@ -569,7 +528,9 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Clien } func TestServerClosing(t *testing.T) { - cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi)) + cc, cs := basicConnection(t, func(s *Server) { + AddTool(s, greetTool(), sayHi) + }) defer cs.Close() ctx := context.Background() @@ -651,11 +612,9 @@ func TestCancellation(t *testing.T) { } return nil, nil } - st := &ServerTool{ - Tool: &Tool{Name: "slow"}, - Handler: slowRequest, - } - _, cs := basicConnection(t, st) + _, cs := basicConnection(t, func(s *Server) { + s.AddTool(&Tool{Name: "slow"}, slowRequest) + }) defer cs.Close() ctx, cancel := context.WithCancel(context.Background()) @@ -852,7 +811,7 @@ func TestKeepAlive(t *testing.T) { KeepAlive: 100 * time.Millisecond, } s := NewServer("testServer", "v1.0.0", serverOpts) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) if err != nil { @@ -897,7 +856,7 @@ func TestKeepAliveFailure(t *testing.T) { // Server without keepalive (to test one-sided keepalive) s := NewServer("testServer", "v1.0.0", nil) - s.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(s, greetTool(), sayHi) ss, err := s.Connect(ctx, st) if err != nil { t.Fatal(err) diff --git a/mcp/prompt.go b/mcp/prompt.go index e2db7b27..0ecf5528 100644 --- a/mcp/prompt.go +++ b/mcp/prompt.go @@ -11,8 +11,7 @@ import ( // A PromptHandler handles a call to prompts/get. type PromptHandler func(context.Context, *ServerSession, *GetPromptParams) (*GetPromptResult, error) -// A Prompt is a prompt definition bound to a prompt handler. -type ServerPrompt struct { - Prompt *Prompt - Handler PromptHandler +type serverPrompt struct { + prompt *Prompt + handler PromptHandler } diff --git a/mcp/resource.go b/mcp/resource.go index 18e0bec4..4202fdac 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -20,16 +20,16 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/util" ) -// A ServerResource associates a Resource with its handler. -type ServerResource struct { - Resource *Resource - Handler ResourceHandler +// A serverResource associates a Resource with its handler. +type serverResource struct { + resource *Resource + handler ResourceHandler } -// A ServerResourceTemplate associates a ResourceTemplate with its handler. -type ServerResourceTemplate struct { - ResourceTemplate *ResourceTemplate - Handler ResourceHandler +// A serverResourceTemplate associates a ResourceTemplate with its handler. +type serverResourceTemplate struct { + resourceTemplate *ResourceTemplate + handler ResourceHandler } // A ResourceHandler is a function that reads a resource. @@ -156,8 +156,8 @@ func fileRoot(root *Root) (_ string, err error) { // Matches reports whether the receiver's uri template matches the uri. // TODO: use "github.com/yosida95/uritemplate/v3" -func (sr *ServerResourceTemplate) Matches(uri string) bool { - re, err := uriTemplateToRegexp(sr.ResourceTemplate.URITemplate) +func (sr *serverResourceTemplate) Matches(uri string) bool { + re, err := uriTemplateToRegexp(sr.resourceTemplate.URITemplate) if err != nil { return false } diff --git a/mcp/server.go b/mcp/server.go index 69666a6a..cd8f808b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -12,6 +12,7 @@ import ( "encoding/json" "fmt" "iter" + "log" "net/url" "path/filepath" "slices" @@ -20,7 +21,6 @@ import ( "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/internal/util" - "github.com/modelcontextprotocol/go-sdk/jsonschema" ) const DefaultPageSize = 1000 @@ -36,10 +36,10 @@ type Server struct { opts ServerOptions mu sync.Mutex - prompts *featureSet[*ServerPrompt] - tools *featureSet[*ServerTool] - resources *featureSet[*ServerResource] - resourceTemplates *featureSet[*ServerResourceTemplate] + prompts *featureSet[*serverPrompt] + tools *featureSet[*serverTool] + resources *featureSet[*serverResource] + resourceTemplates *featureSet[*serverResourceTemplate] sessions []*ServerSession sendingMethodHandler_ MethodHandler[*ServerSession] receivingMethodHandler_ MethodHandler[*ServerSession] @@ -87,28 +87,23 @@ func NewServer(name, version string, opts *ServerOptions) *Server { name: name, version: version, opts: *opts, - prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }), - tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }), - resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }), - resourceTemplates: newFeatureSet(func(t *ServerResourceTemplate) string { return t.ResourceTemplate.URITemplate }), + prompts: newFeatureSet(func(p *serverPrompt) string { return p.prompt.Name }), + tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }), + resources: newFeatureSet(func(r *serverResource) string { return r.resource.URI }), + resourceTemplates: newFeatureSet(func(t *serverResourceTemplate) string { return t.resourceTemplate.URITemplate }), sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession], receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], } } -// AddPrompts adds the given prompts to the server, -// replacing any with the same names. -func (s *Server) AddPrompts(prompts ...*ServerPrompt) { - // Only notify if something could change. - if len(prompts) == 0 { - return - } - // Assume there was a change, since add replaces existing roots. - // (It's possible a root was replaced with an identical one, but not worth checking.) +// AddPrompt adds a [Prompt] to the server, or replaces one with the same name. +func (s *Server) AddPrompt(p *Prompt, h PromptHandler) { + // Assume there was a change, since add replaces existing items. + // (It's possible an item was replaced with an identical one, but not worth checking.) s.changeAndNotify( notificationPromptListChanged, &PromptListChangedParams{}, - func() bool { s.prompts.add(prompts...); return true }) + func() bool { s.prompts.add(&serverPrompt{p, h}); return true }) } // RemovePrompts removes the prompts with the given names. @@ -118,55 +113,44 @@ func (s *Server) RemovePrompts(names ...string) { func() bool { return s.prompts.remove(names...) }) } -// AddTools adds the given tools to the server, -// replacing any with the same names. -// The arguments must not be modified after this call. -// -// AddTools panics if errors are detected. -func (s *Server) AddTools(tools ...*ServerTool) { - if err := s.addToolsErr(tools...); err != nil { +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// The tool's input schema must be non-nil. +// The Tool argument must not be modified after this call. +func (s *Server) AddTool(t *Tool, h ToolHandler) { + // TODO(jba): This is a breaking behavior change. Add before v0.2.0? + if t.InputSchema == nil { + log.Printf("mcp: tool %q has a nil input schema. This will panic in a future release.", t.Name) + // panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) + } + if err := addToolErr(s, t, h); err != nil { panic(err) } } -// addToolsErr is like [AddTools], but returns an error instead of panicking. -func (s *Server) addToolsErr(tools ...*ServerTool) error { - // Only notify if something could change. - if len(tools) == 0 { - return nil +// AddTool adds a [Tool] to the server, or replaces one with the same name. +// If the tool's input schema is nil, it is set to the schema inferred from the In +// type parameter, using [jsonschema.For]. +// If the tool's output schema is nil and the Out type parameter is not the empty +// interface, then the output schema is set to the schema inferred from Out. +// The Tool argument must not be modified after this call. +func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { + if err := addToolErr(s, t, h); err != nil { + panic(err) } - // Wrap the user's Handlers with rawHandlers that take a json.RawMessage. - for _, st := range tools { - if st.rawHandler == nil { - // This ServerTool was not created with NewServerTool. - if st.Handler == nil { - return fmt.Errorf("AddTools: tool %q has no handler", st.Tool.Name) - } - st.rawHandler = newRawHandler(st) - // Resolve the schemas, with no base URI. We don't expect tool schemas to - // refer outside of themselves. - if st.Tool.InputSchema != nil { - r, err := st.Tool.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return err - } - st.inputResolved = r - } +} - // if st.Tool.OutputSchema != nil { - // st.outputResolved, err := st.Tool.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - // if err != nil { - // return err - // } - // } - } +func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) { + defer util.Wrapf(&err, "adding tool %q", t.Name) + st, err := newServerTool(t, h) + if err != nil { + return err } - // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) - // TODO: surface notify error here? + // TODO: Batch these changes by size and time? The typescript SDK doesn't. + // TODO: Surface notify error here? best not, in case we need to batch. s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, - func() bool { s.tools.add(tools...); return true }) + func() bool { s.tools.add(st); return true }) return nil } @@ -177,26 +161,19 @@ func (s *Server) RemoveTools(names ...string) { func() bool { return s.tools.remove(names...) }) } -// AddResources adds the given resources to the server. -// If a resource with the same URI already exists, it is replaced. -// AddResources panics if a resource URI is invalid or not absolute (has an empty scheme). -func (s *Server) AddResources(resources ...*ServerResource) { - // Only notify if something could change. - if len(resources) == 0 { - return - } +// AddResource adds a [Resource] to the server, or replaces one with the same URI. +// AddResource panics if the resource URI is invalid or not absolute (has an empty scheme). +func (s *Server) AddResource(r *Resource, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, r := range resources { - u, err := url.Parse(r.Resource.URI) - if err != nil { - panic(err) // url.Parse includes the URI in the error - } - if !u.IsAbs() { - panic(fmt.Errorf("URI %s needs a scheme", r.Resource.URI)) - } - s.resources.add(r) + u, err := url.Parse(r.URI) + if err != nil { + panic(err) // url.Parse includes the URI in the error + } + if !u.IsAbs() { + panic(fmt.Errorf("URI %s needs a scheme", r.URI)) } + s.resources.add(&serverResource{r, h}) return true }) } @@ -208,20 +185,13 @@ func (s *Server) RemoveResources(uris ...string) { func() bool { return s.resources.remove(uris...) }) } -// AddResourceTemplates adds the given resource templates to the server. -// If a resource template with the same URI template already exists, it will be replaced. -// AddResourceTemplates panics if a URI template is invalid or not absolute (has an empty scheme). -func (s *Server) AddResourceTemplates(templates ...*ServerResourceTemplate) { - // Only notify if something could change. - if len(templates) == 0 { - return - } +// AddResourceTemplate adds a [ResourceTemplate] to the server, or replaces on with the same URI. +// AddResourceTemplate panics if a URI template is invalid or not absolute (has an empty scheme). +func (s *Server) AddResourceTemplate(t *ResourceTemplate, h ResourceHandler) { s.changeAndNotify(notificationResourceListChanged, &ResourceListChangedParams{}, func() bool { - for _, t := range templates { - // TODO: check template validity. - s.resourceTemplates.add(t) - } + // TODO: check template validity. + s.resourceTemplates.add(&serverResourceTemplate{t, h}) return true }) } @@ -268,10 +238,10 @@ func (s *Server) listPrompts(_ context.Context, _ *ServerSession, params *ListPr if params == nil { params = &ListPromptsParams{} } - return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*ServerPrompt) { + return paginateList(s.prompts, s.opts.PageSize, params, &ListPromptsResult{}, func(res *ListPromptsResult, prompts []*serverPrompt) { res.Prompts = []*Prompt{} // avoid JSON null for _, p := range prompts { - res.Prompts = append(res.Prompts, p.Prompt) + res.Prompts = append(res.Prompts, p.prompt) } }) } @@ -284,7 +254,7 @@ func (s *Server) getPrompt(ctx context.Context, cc *ServerSession, params *GetPr // TODO: surface the error code over the wire, instead of flattening it into the string. return nil, fmt.Errorf("%s: unknown prompt %q", jsonrpc2.ErrInvalidParams, params.Name) } - return prompt.Handler(ctx, cc, params) + return prompt.handler(ctx, cc, params) } func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListToolsParams) (*ListToolsResult, error) { @@ -293,22 +263,22 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool if params == nil { params = &ListToolsParams{} } - return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) { + return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) { res.Tools = []*Tool{} // avoid JSON null for _, t := range tools { - res.Tools = append(res.Tools, t.Tool) + res.Tools = append(res.Tools, t.tool) } }) } func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { s.mu.Lock() - tool, ok := s.tools.get(params.Name) + st, ok := s.tools.get(params.Name) s.mu.Unlock() if !ok { return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name) } - return tool.rawHandler(ctx, cc, params) + return st.handler(ctx, cc, params) } func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) { @@ -317,10 +287,10 @@ func (s *Server) listResources(_ context.Context, _ *ServerSession, params *List if params == nil { params = &ListResourcesParams{} } - return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*ServerResource) { + return paginateList(s.resources, s.opts.PageSize, params, &ListResourcesResult{}, func(res *ListResourcesResult, resources []*serverResource) { res.Resources = []*Resource{} // avoid JSON null for _, r := range resources { - res.Resources = append(res.Resources, r.Resource) + res.Resources = append(res.Resources, r.resource) } }) } @@ -332,10 +302,10 @@ func (s *Server) listResourceTemplates(_ context.Context, _ *ServerSession, para params = &ListResourceTemplatesParams{} } return paginateList(s.resourceTemplates, s.opts.PageSize, params, &ListResourceTemplatesResult{}, - func(res *ListResourceTemplatesResult, rts []*ServerResourceTemplate) { + func(res *ListResourceTemplatesResult, rts []*serverResourceTemplate) { res.ResourceTemplates = []*ResourceTemplate{} // avoid JSON null for _, rt := range rts { - res.ResourceTemplates = append(res.ResourceTemplates, rt.ResourceTemplate) + res.ResourceTemplates = append(res.ResourceTemplates, rt.resourceTemplate) } }) } @@ -376,12 +346,12 @@ func (s *Server) lookupResourceHandler(uri string) (ResourceHandler, string, boo defer s.mu.Unlock() // Try resources first. if r, ok := s.resources.get(uri); ok { - return r.Handler, r.Resource.MIMEType, true + return r.handler, r.resource.MIMEType, true } // Look for matching template. for rt := range s.resourceTemplates.all() { if rt.Matches(uri) { - return rt.Handler, rt.ResourceTemplate.MIMEType, true + return rt.handler, rt.resourceTemplate.MIMEType, true } } return nil, "", false diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 9e982374..fd6eea00 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -29,7 +29,7 @@ func ExampleServer() { clientTransport, serverTransport := mcp.NewInMemoryTransports() server := mcp.NewServer("greeter", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi)) + mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi) serverSession, err := server.Connect(ctx, serverTransport) if err != nil { diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 5a1d5d02..f319d80e 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -12,7 +12,7 @@ import ( ) // TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. -func TestNewServerToolValidate(t *testing.T) { +func TestToolValidate(t *testing.T) { // Check that the tool returned from NewServerTool properly validates its input schema. type req struct { @@ -26,9 +26,10 @@ func TestNewServerToolValidate(t *testing.T) { return nil, nil } - tool := NewServerTool("test", "test", dummyHandler) - // Need to add the tool to a server to get resolved schemas. - // s := NewServer("", "", nil) + st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) + if err != nil { + t.Fatal(err) + } for _, tt := range []struct { desc string @@ -71,7 +72,7 @@ func TestNewServerToolValidate(t *testing.T) { if err != nil { t.Fatal(err) } - _, err = tool.rawHandler(context.Background(), nil, + _, err = st.handler(context.Background(), nil, &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}) if err == nil && tt.want != "" { t.Error("got success, wanted failure") diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 70f84c3e..816e0134 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -28,7 +28,7 @@ func Add(ctx context.Context, cc *mcp.ServerSession, params *mcp.CallToolParamsF func ExampleSSEHandler() { server := mcp.NewServer("adder", "v0.0.1", nil) - server.AddTools(mcp.NewServerTool("add", "add two numbers", Add)) + mcp.AddTool(server, &mcp.Tool{Name: "add", Description: "add two numbers"}, Add) handler := mcp.NewSSEHandler(func(*http.Request) *mcp.Server { return server }) httpServer := httptest.NewServer(handler) diff --git a/mcp/sse_test.go b/mcp/sse_test.go index 23621931..e1df9536 100644 --- a/mcp/sse_test.go +++ b/mcp/sse_test.go @@ -20,7 +20,7 @@ func TestSSEServer(t *testing.T) { t.Run(fmt.Sprintf("closeServerFirst=%t", closeServerFirst), func(t *testing.T) { ctx := context.Background() server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) + AddTool(server, &Tool{Name: "greet"}, sayHi) sseHandler := NewSSEHandler(func(*http.Request) *Server { return server }) diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a8c916e8..8925b3da 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -32,8 +32,7 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer("testServer", "v1.0.0", nil) - server.AddTools(NewServerTool("greet", "say hi", sayHi)) - + AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a // cookie-checking middleware. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) @@ -323,13 +322,12 @@ func TestStreamableServerTransport(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. server := NewServer("testServer", "v1.0.0", nil) - tool := NewServerTool("tool", "test tool", func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { if test.tool != nil { test.tool(t, ctx, ss) } return &CallToolResultFor[any]{}, nil }) - server.AddTools(tool) // Start the streamable handler. handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) diff --git a/mcp/tool.go b/mcp/tool.go index a6f228eb..fc154991 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -9,7 +9,7 @@ import ( "context" "encoding/json" "fmt" - "slices" + "reflect" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) @@ -17,8 +17,7 @@ import ( // A ToolHandler handles a call to tools/call. // [CallToolParams.Arguments] will contain a map[string]any that has been validated // against the input schema. -// TODO: Perhaps this should be an alias for ToolHandlerFor[map[string]any, map[string]any]? -type ToolHandler func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) +type ToolHandler = ToolHandlerFor[map[string]any, any] // A ToolHandlerFor handles a call to tools/call with typed arguments and results. type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) @@ -26,62 +25,33 @@ type ToolHandlerFor[In, Out any] func(context.Context, *ServerSession, *CallTool // A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. type rawToolHandler = func(context.Context, *ServerSession, *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) -// A ServerTool is a tool definition that is bound to a tool handler. -type ServerTool struct { - Tool *Tool - Handler ToolHandler - // Set in NewServerTool or Server.addToolsErr. - rawHandler rawToolHandler - // Resolved tool schemas. Set in Server.addToolsErr. +// A serverTool is a tool definition that is bound to a tool handler. +type serverTool struct { + tool *Tool + handler rawToolHandler + // Resolved tool schemas. Set in newServerTool. inputResolved, outputResolved *jsonschema.Resolved } -// NewServerTool is a helper to make a tool using reflection on the given type parameters. -// When the tool is called, CallToolParams.Arguments will be of type In. -// -// If provided, variadic [ToolOption] values may be used to customize the tool. -// -// The input schema for the tool is extracted from the request type for the -// handler, and used to unmmarshal and validate requests to the handler. This -// schema may be customized using the [Input] option. -// -// TODO(jba): check that structured content is set in response. -func NewServerTool[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) *ServerTool { - st, err := newServerToolErr[In, Out](name, description, handler, opts...) - if err != nil { - panic(fmt.Errorf("NewServerTool(%q): %w", name, err)) - } - return st -} +// newServerTool creates a serverTool from a tool and a handler. +// If the tool doesn't have an input schema, it is inferred from In. +// If the tool doesn't have an output schema and Out != any, it is inferred from Out. +func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool, error) { + st := &serverTool{tool: t} -func newServerToolErr[In, Out any](name, description string, handler ToolHandlerFor[In, Out], opts ...ToolOption) (*ServerTool, error) { - // TODO: check that In is a struct. - ischema, err := jsonschema.For[In]() - if err != nil { + if err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil { return nil, err } - // TODO: uncomment when output schemas drop. - // oschema, err := jsonschema.For[TRes]() - // if err != nil { - // return nil, err - // } - - t := &ServerTool{ - Tool: &Tool{ - Name: name, - Description: description, - InputSchema: ischema, - // OutputSchema: oschema, - }, - } - for _, opt := range opts { - opt.set(t) + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + if err := setSchema[Out](&t.OutputSchema, &st.outputResolved); err != nil { + return nil, err + } } - t.rawHandler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { + st.handler = func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { var args In if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, t.inputResolved, &args); err != nil { + if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { return nil, err } } @@ -91,55 +61,41 @@ func newServerToolErr[In, Out any](name, description string, handler ToolHandler Name: rparams.Name, Arguments: args, } - res, err := handler(ctx, ss, params) + res, err := h(ctx, ss, params) + // TODO(rfindley): investigate why server errors are embedded in this strange way, + // rather than returned as jsonrpc2 server errors. if err != nil { - return nil, err + return &CallToolResult{ + Content: []Content{&TextContent{Text: err.Error()}}, + IsError: true, + }, nil } - var ctr CallToolResult + // TODO(jba): What if res == nil? Is that valid? + // TODO(jba): if t.OutputSchema != nil, check that StructuredContent is present and validates. if res != nil { // TODO(jba): future-proof this copy. ctr.Meta = res.Meta ctr.Content = res.Content ctr.IsError = res.IsError + ctr.StructuredContent = res.StructuredContent } return &ctr, nil } - return t, nil + + return st, nil } -// newRawHandler creates a rawToolHandler for tools not created through NewServerTool. -// It unmarshals the arguments into a map[string]any and validates them against the -// schema, then calls the ServerTool's handler. -func newRawHandler(st *ServerTool) rawToolHandler { - if st.Handler == nil { - panic("st.Handler is nil") +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { + var err error + if *sfield == nil { + *sfield, err = jsonschema.For[T]() } - return func(ctx context.Context, ss *ServerSession, rparams *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) { - // Unmarshal the args into what should be a map. - var args map[string]any - if rparams.Arguments != nil { - if err := unmarshalSchema(rparams.Arguments, st.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO: generate copy - params := &CallToolParamsFor[map[string]any]{ - Meta: rparams.Meta, - Name: rparams.Name, - Arguments: args, - } - res, err := st.Handler(ctx, ss, params) - // TODO(rfindley): investigate why server errors are embedded in this strange way, - // rather than returned as jsonrpc2 server errors. - if err != nil { - return &CallToolResult{ - Content: []Content{&TextContent{Text: err.Error()}}, - IsError: true, - }, nil - } - return res, nil + if err != nil { + return err } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return err } // unmarshalSchema unmarshals data into v and validates the result according to @@ -169,105 +125,6 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) return nil } -// A ToolOption configures the behavior of a Tool. -type ToolOption interface { - set(*ServerTool) -} - -type toolSetter func(*ServerTool) - -func (s toolSetter) set(t *ServerTool) { s(t) } - -// Input applies the provided [SchemaOption] configuration to the tool's input -// schema. -func Input(opts ...SchemaOption) ToolOption { - return toolSetter(func(t *ServerTool) { - for _, opt := range opts { - opt.set(t.Tool.InputSchema) - } - }) -} - -// A SchemaOption configures a jsonschema.Schema. -type SchemaOption interface { - set(s *jsonschema.Schema) -} - -type schemaSetter func(*jsonschema.Schema) - -func (s schemaSetter) set(schema *jsonschema.Schema) { s(schema) } - -// Property configures the schema for the property of the given name. -// If there is no such property in the schema, it is created. -func Property(name string, opts ...SchemaOption) SchemaOption { - return schemaSetter(func(schema *jsonschema.Schema) { - propSchema, ok := schema.Properties[name] - if !ok { - propSchema = new(jsonschema.Schema) - schema.Properties[name] = propSchema - } - // Apply the options, with special handling for Required, as it needs to be - // set on the parent schema. - for _, opt := range opts { - if req, ok := opt.(required); ok { - if req { - if !slices.Contains(schema.Required, name) { - schema.Required = append(schema.Required, name) - } - } else { - schema.Required = slices.DeleteFunc(schema.Required, func(s string) bool { - return s == name - }) - } - } else { - opt.set(propSchema) - } - } - }) -} - -// Required sets whether the associated property is required. It is only valid -// when used in a [Property] option: using Required outside of Property panics. -func Required(v bool) SchemaOption { - return required(v) -} - -// required must be a distinguished type as it needs special handling to mutate -// the parent schema, and to mutate prompt arguments. -type required bool - -func (required) set(s *jsonschema.Schema) { - panic("use of required outside of Property") -} - -// Enum sets the provided values as the "enum" value of the schema. -func Enum(values ...any) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - s.Enum = values - }) -} - -// Description sets the provided schema description. -func Description(desc string) SchemaOption { - return description(desc) -} - -// description must be a distinguished type so that it can be handled by prompt -// options. -type description string - -func (d description) set(s *jsonschema.Schema) { - s.Description = string(d) -} - -// Schema overrides the inferred schema with a shallow copy of the given -// schema. -func Schema(schema *jsonschema.Schema) SchemaOption { - return schemaSetter(func(s *jsonschema.Schema) { - *s = *schema - }) -} - // schemaJSON returns the JSON value for s as a string, or a string indicating an error. func schemaJSON(s *jsonschema.Schema) string { m, err := json.Marshal(s) diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 85775e9b..4d0a329b 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -16,76 +16,80 @@ import ( ) // testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[T any](context.Context, *ServerSession, *CallToolParamsFor[T]) (*CallToolResultFor[any], error) { +func testToolHandler[In, Out any](context.Context, *ServerSession, *CallToolParamsFor[In]) (*CallToolResultFor[Out], error) { panic("not implemented") } +func srvTool[In, Out any](t *testing.T, tool *Tool, handler ToolHandlerFor[In, Out]) *serverTool { + t.Helper() + st, err := newServerTool(tool, handler) + if err != nil { + t.Fatal(err) + } + return st +} + func TestNewServerTool(t *testing.T) { + type ( + Name struct { + Name string `json:"name"` + } + Size struct { + Size int `json:"size"` + } + ) + + nameSchema := &jsonschema.Schema{ + Type: "object", + Required: []string{"name"}, + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, + } + sizeSchema := &jsonschema.Schema{ + Type: "object", + Required: []string{"size"}, + Properties: map[string]*jsonschema.Schema{ + "size": {Type: "integer"}, + }, + AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, + } + tests := []struct { - tool *ServerTool - want *jsonschema.Schema + tool *serverTool + wantIn, wantOut *jsonschema.Schema }{ { - NewServerTool("basic", "", testToolHandler[struct { - Name string `json:"name"` - }]), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name"}, - Properties: map[string]*jsonschema.Schema{ - "name": {Type: "string"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{Name: "basic"}, testToolHandler[Name, Size]), + nameSchema, + sizeSchema, }, { - NewServerTool("enum", "", testToolHandler[struct{ Name string }], Input( - Property("Name", Enum("x", "y", "z")), - )), - &jsonschema.Schema{ - Type: "object", - Required: []string{"Name"}, - Properties: map[string]*jsonschema.Schema{ - "Name": {Type: "string", Enum: []any{"x", "y", "z"}}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{ + Name: "in untouched", + InputSchema: &jsonschema.Schema{}, + }, testToolHandler[Name, Size]), + &jsonschema.Schema{}, + sizeSchema, }, { - NewServerTool("required", "", testToolHandler[struct { - Name string `json:"name"` - Language string `json:"language"` - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Property("x", Required(true)))), - &jsonschema.Schema{ - Type: "object", - Required: []string{"name", "language", "x"}, - Properties: map[string]*jsonschema.Schema{ - "language": {Type: "string"}, - "name": {Type: "string"}, - "x": {Type: "integer"}, - "y": {Type: "integer"}, - }, - AdditionalProperties: &jsonschema.Schema{Not: new(jsonschema.Schema)}, - }, + srvTool(t, &Tool{Name: "out untouched", OutputSchema: &jsonschema.Schema{}}, testToolHandler[Name, Size]), + nameSchema, + &jsonschema.Schema{}, }, { - NewServerTool("set_schema", "", testToolHandler[struct { - X int `json:"x,omitempty"` - Y int `json:"y,omitempty"` - }], Input( - Schema(&jsonschema.Schema{Type: "object"})), - ), - &jsonschema.Schema{ - Type: "object", - }, + srvTool(t, &Tool{Name: "nil out"}, testToolHandler[Name, any]), + nameSchema, + nil, }, } for _, test := range tests { - if diff := cmp.Diff(test.want, test.tool.Tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Errorf("NewServerTool(%v) mismatch (-want +got):\n%s", test.tool.Tool.Name, diff) + if diff := cmp.Diff(test.wantIn, test.tool.tool.InputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("newServerTool(%q) input schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) + } + if diff := cmp.Diff(test.wantOut, test.tool.tool.OutputSchema, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Errorf("newServerTool(%q) output schema mismatch (-want +got):\n%s", test.tool.tool.Name, diff) } } }