From 36e3a6761020384d35666679769a5e632ae64fcc Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 14 Aug 2025 06:49:04 -0400 Subject: [PATCH 1/2] mcp: remove tool genericity DO NOT SUBMIT TESTS DO NOT PASS YET API changes to remove genericity from the tool call path. This makes it easier to write code that can deal with tools generally, like wrappers around a ToolHandler. Here is the go doc diff: --- /tmp/old.doc 2025-08-14 09:03:30.772292329 -0400 +++ /tmp/new.doc 2025-08-14 08:58:37.113063370 -0400 @@ -73,7 +73,7 @@ FUNCTIONS -func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) +func AddTool[In, Out any](s *Server, t *Tool, h TypedToolHandler[In, Out]) AddTool adds a Tool to the server, or replaces one with the same name. If the tool's input schema is nil, it is set to the schema inferred from the In type parameter, using jsonschema.For. If the tool's output schema is @@ -81,6 +81,10 @@ schema is set to the schema inferred from Out. The Tool argument must not be modified after this call. + The handler should return the result as the second return value. The first + return value, a *CallToolResult, may be nil, or its fields other than + StructuredContent may be populated. + func NewInMemoryTransports() (*InMemoryTransport, *InMemoryTransport) NewInMemoryTransports returns two [InMemoryTransports] that connect to each other. @@ -125,24 +129,28 @@ func (c AudioContent) MarshalJSON() ([]byte, error) -type CallToolParams = CallToolParamsFor[any] - -type CallToolParamsFor[In any] struct { +type CallToolParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Name string `json:"name"` - Arguments In `json:"arguments,omitempty"` + Arguments any `json:"arguments,omitempty"` } -func (x *CallToolParamsFor[Out]) GetProgressToken() any +func (x *CallToolParams) GetProgressToken() any -func (x *CallToolParamsFor[Out]) SetProgressToken(t any) +func (x *CallToolParams) SetProgressToken(t any) -type CallToolResult = CallToolResultFor[any] - The server's response to a tool call. +func (c *CallToolParams) UnmarshalJSON(data []byte) error + When unmarshalling CallToolParams on the server side, we need to delay + unmarshaling of the arguments. -type CallToolResultFor[Out any] struct { +type CallToolRequest struct { + Session *ServerSession + Params *CallToolParams +} + +type CallToolResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` @@ -151,7 +159,7 @@ Content []Content `json:"content"` // An optional JSON object that represents the structured result of the tool // call. - StructuredContent Out `json:"structuredContent,omitempty"` + StructuredContent any `json:"structuredContent,omitempty"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -166,8 +174,9 @@ // should be reported as an MCP error response. IsError bool `json:"isError,omitempty"` } + The server's response to a tool call. -func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error +func (x *CallToolResult) UnmarshalJSON(data []byte) error UnmarshalJSON handles the unmarshalling of content into the Content interface. @@ -283,7 +292,7 @@ Session *ClientSession Params P } - A ClientRequest is a request to a client. + A ClientRequest[P] is a request to a client. func (r *ClientRequest[P]) GetParams() Params @@ -1532,9 +1541,7 @@ type ServerSession struct { // Has unexported fields. } - A ServerSession is a logical connection from a single MCP client. - Its methods can be used to send requests or notifications to the client. - Create a session by calling Server.Connect. + a session by calling Server.Connect. Call ServerSession.Close to close the connection, or await client termination with ServerSession.Wait. @@ -1786,6 +1793,8 @@ // If not provided, Annotations.Title should be used for display if present, // otherwise Name. Title string `json:"title,omitempty"` + + // Has unexported fields. } Definition for a tool the client can call. @@ -1826,13 +1835,10 @@ Clients should never make tool use decisions based on ToolAnnotations received from untrusted servers. -type ToolHandler = ToolHandlerFor[map[string]any, any] - A ToolHandler handles a call to tools/call. [CallToolParams.Arguments] will - contain a map[string]any that has been validated against the input schema. - -type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) - A ToolHandlerFor handles a call to tools/call with typed arguments and - results. +type ToolHandler func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, error) + A ToolHandler handles a call to tools/call. req.Params.Arguments will + contain a json.RawMessage containing the arguments. args will contain a + value that has been validated against the input schema. type ToolListChangedParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -1856,6 +1862,10 @@ Transports should be used for at most one call to Server.Connect or Client.Connect. +type TypedToolHandler[In, Out any] func(context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) + A TypedToolHandler handles a call to tools/call with typed arguments and + results. + type UnsubscribeParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. --- examples/server/sse/main.go | 8 +- mcp/example_middleware_test.go | 9 +- mcp/features_test.go | 12 +- mcp/mcp_test.go | 12 +- mcp/protocol.go | 44 ++-- mcp/protocol_test.go | 15 +- mcp/server.go | 35 +-- mcp/server_example_test.go | 8 +- mcp/shared.go | 4 - mcp/shared_test.go | 435 ++++++++++++++++----------------- mcp/sse_example_test.go | 8 +- mcp/streamable_test.go | 20 +- mcp/tool.go | 126 +++++----- mcp/tool_test.go | 6 +- 14 files changed, 374 insertions(+), 368 deletions(-) diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index 2fbd695e..c2603b41 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -24,12 +24,12 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func main() { diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 56f7428a..b0074cd3 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -89,15 +89,16 @@ func Example_loggingMiddleware() { }, func( ctx context.Context, - req *mcp.ServerRequest[*mcp.CallToolParamsFor[map[string]any]], - ) (*mcp.CallToolResultFor[any], error) { - name, ok := req.Params.Arguments["name"].(string) + req *mcp.ServerRequest[*mcp.CallToolParams], + args any, + ) (*mcp.CallToolResult, error) { + name, ok := args.(map[string]any)["name"].(string) if !ok { return nil, fmt.Errorf("name parameter is required and must be a string") } message := fmt.Sprintf("Hello, %s!", name) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: message}, }, diff --git a/mcp/features_test.go b/mcp/features_test.go index 1c22ecd3..f52ffe5d 100644 --- a/mcp/features_test.go +++ b/mcp/features_test.go @@ -18,12 +18,12 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{ +func SayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args SayHiParams) (*CallToolResult, any, error) { + return &CallToolResult{ Content: []Content{ - &TextContent{Text: "Hi " + params.Name}, + &TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func TestFeatureSetOrder(t *testing.T) { @@ -45,7 +45,7 @@ func TestFeatureSetOrder(t *testing.T) { fs := newFeatureSet(func(t *Tool) string { return t.Name }) fs.add(tc.tools...) got := slices.Collect(fs.all()) - if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{}, Tool{})); diff != "" { t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff) } } @@ -69,7 +69,7 @@ func TestFeatureSetAbove(t *testing.T) { fs := newFeatureSet(func(t *Tool) string { return t.Name }) fs.add(tc.tools...) got := slices.Collect(fs.above(tc.above)) - if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{}, Tool{})); diff != "" { t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff) } } diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 58b0377e..56a32ead 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -32,11 +32,11 @@ type hiParams struct { // TODO(jba): after schemas are stateless (WIP), this can be a variable. func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } -func sayHi(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { +func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err != nil { - return nil, fmt.Errorf("ping failed: %v", err) + return nil, nil, fmt.Errorf("ping failed: %v", err) } - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } var codeReviewPrompt = &Prompt{ @@ -97,7 +97,7 @@ func TestEndToEnd(t *testing.T) { Description: "say hi", }, sayHi) s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + func(context.Context, *ServerRequest[*CallToolParams], any) (*CallToolResult, error) { return nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) @@ -647,7 +647,7 @@ func TestCancellation(t *testing.T) { cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + slowRequest := func(ctx context.Context, _ *ServerRequest[*CallToolParams], _ any) (*CallToolResult, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -836,7 +836,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware { } } -func nopHandler(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { +func nopHandler(context.Context, *ServerRequest[*CallToolParams], any) (*CallToolResult, error) { return nil, nil } diff --git a/mcp/protocol.go b/mcp/protocol.go index d2d343b8..2f222952 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -40,20 +40,32 @@ type Annotations struct { Priority float64 `json:"priority,omitempty"` } -type CallToolParams = CallToolParamsFor[any] - -type CallToolParamsFor[In any] struct { +type CallToolParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` Name string `json:"name"` - Arguments In `json:"arguments,omitempty"` + Arguments any `json:"arguments,omitempty"` } -// The server's response to a tool call. -type CallToolResult = CallToolResultFor[any] +// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments. +func (c *CallToolParams) UnmarshalJSON(data []byte) error { + var raw struct { + Meta `json:"_meta,omitempty"` + Name string `json:"name"` + RawArguments json.RawMessage `json:"arguments,omitempty"` + } + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + c.Meta = raw.Meta + c.Name = raw.Name + c.Arguments = raw.RawArguments + return nil +} -type CallToolResultFor[Out any] struct { +// The server's response to a tool call. +type CallToolResult struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. Meta `json:"_meta,omitempty"` @@ -62,7 +74,7 @@ type CallToolResultFor[Out any] struct { Content []Content `json:"content"` // An optional JSON object that represents the structured result of the tool // call. - StructuredContent Out `json:"structuredContent,omitempty"` + StructuredContent any `json:"structuredContent,omitempty"` // Whether the tool call ended in an error. // // If not set, this is assumed to be false (the call was successful). @@ -78,12 +90,12 @@ type CallToolResultFor[Out any] struct { IsError bool `json:"isError,omitempty"` } -func (*CallToolResultFor[Out]) isResult() {} +func (*CallToolResult) isResult() {} // UnmarshalJSON handles the unmarshalling of content into the Content // interface. -func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { - type res CallToolResultFor[Out] // avoid recursion +func (x *CallToolResult) UnmarshalJSON(data []byte) error { + type res CallToolResult // avoid recursion var wire struct { res Content []*wireContent `json:"content"` @@ -95,13 +107,13 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { return err } - *x = CallToolResultFor[Out](wire.res) + *x = CallToolResult(wire.res) return nil } -func (x *CallToolParamsFor[Out]) isParams() {} -func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) } -func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) } +func (x *CallToolParams) isParams() {} +func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) } +func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) } type CancelledParams struct { // This property is reserved by the protocol to allow clients and servers to @@ -867,6 +879,8 @@ type Tool struct { // If not provided, Annotations.Title should be used for display if present, // otherwise Name. Title string `json:"title,omitempty"` + + newArgs func() any } // Additional properties describing a Tool to clients. diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index dba80a8b..cd9b5146 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -208,6 +208,7 @@ func TestCompleteReference(t *testing.T) { }) } } + func TestCompleteParams(t *testing.T) { // Define test cases specifically for Marshalling marshalTests := []struct { @@ -514,13 +515,15 @@ func TestContentUnmarshal(t *testing.T) { var got CallToolResult roundtrip(ctr, &got) - ctrf := &CallToolResultFor[int]{ - Meta: Meta{"m": true}, - Content: content, - IsError: true, - StructuredContent: 3, + ctrf := &CallToolResult{ + Meta: Meta{"m": true}, + Content: content, + IsError: true, + // Ints become floats with zero fractional part when unmarshaled. + // The jsoncschema package will validate these against a schema with type "integer". + StructuredContent: float64(3), } - var gotf CallToolResultFor[int] + var gotf CallToolResult roundtrip(ctrf, &gotf) pm := &PromptMessage{ diff --git a/mcp/server.go b/mcp/server.go index e39372dc..9d7ed9ed 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -9,7 +9,6 @@ import ( "context" "encoding/base64" "encoding/gob" - "encoding/json" "fmt" "iter" "maps" @@ -145,16 +144,7 @@ func (s *Server) RemovePrompts(names ...string) { // or one where any input is valid, set [Tool.InputSchema] to the empty schema, // &jsonschema.Schema{}. func (s *Server) AddTool(t *Tool, h ToolHandler) { - if t.InputSchema == nil { - // This prevents the tool author from forgetting to write a schema where - // one should be provided. If we papered over this by supplying the empty - // schema, then every input would be validated and the problem wouldn't be - // discovered until runtime, when the LLM sent bad data. - panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name)) - } - if err := addToolErr(s, t, h); err != nil { - panic(err) - } + s.addServerTool(newServerTool(t, h)) } // AddTool adds a [Tool] to the server, or replaces one with the same name. @@ -163,17 +153,17 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { // If the tool's output schema is nil and the Out type parameter is not the empty // interface, then the output schema is set to the schema inferred from Out. // The Tool argument must not be modified after this call. -func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) { - if err := addToolErr(s, t, h); err != nil { - panic(err) - } +// +// The handler should return the result as the second return value. The first return value, +// a *CallToolResult, may be nil, or its fields other than StructuredContent may be +// populated. +func AddTool[In, Out any](s *Server, t *Tool, h TypedToolHandler[In, Out]) { + s.addServerTool(newTypedServerTool(t, h)) } -func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) { - defer util.Wrapf(&err, "adding tool %q", t.Name) - st, err := newServerTool(t, h) +func (s *Server) addServerTool(st *serverTool, err error) { if err != nil { - return err + panic(fmt.Sprintf("adding tool %q: %v", st.tool.Name, err)) } // Assume there was a change, since add replaces existing tools. // (It's possible a tool was replaced with an identical one, but not worth checking.) @@ -181,7 +171,6 @@ func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err // TODO: Surface notify error here? best not, in case we need to batch. s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{}, func() bool { s.tools.add(st); return true }) - return nil } // RemoveTools removes the tools with the given names. @@ -326,7 +315,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam }) } -func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParamsFor[json.RawMessage]]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { s.mu.Lock() st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() @@ -612,7 +601,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar return nil, fmt.Errorf("duplicate %q received", notificationInitialized) } if h := ss.server.opts.InitializedHandler; h != nil { - h(ctx, serverRequestFor(ss, params)) + h(ctx, newServerRequest(ss, params)) } return nil, nil } @@ -626,7 +615,7 @@ func (s *Server) callRootsListChangedHandler(ctx context.Context, req *ServerReq func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) { if h := ss.server.opts.ProgressNotificationHandler; h != nil { - h(ctx, serverRequestFor(ss, p)) + h(ctx, newServerRequest(ss, p)) } return nil, nil } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index f735b84e..2b4a0bf1 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -16,12 +16,12 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, nil, nil } func ExampleServer() { diff --git a/mcp/shared.go b/mcp/shared.go index ca062214..ea4975ef 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -408,10 +408,6 @@ func (r *ServerRequest[P]) GetSession() Session { return r.Session } func (r *ClientRequest[P]) GetParams() Params { return r.Params } func (r *ServerRequest[P]) GetParams() Params { return r.Params } -func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] { - return &ServerRequest[P]{Session: s, Params: p} -} - func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] { return &ClientRequest[P]{Session: s, Params: p} } diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 01d1eff7..2bec742f 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,232 +4,221 @@ package mcp -import ( - "context" - "encoding/json" - "fmt" - "strings" - "testing" -) - // TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. -func TestToolValidate(t *testing.T) { - // Check that the tool returned from NewServerTool properly validates its input schema. - - type req struct { - I int - B bool - S string `json:",omitempty"` - P *int `json:",omitempty"` - } - - dummyHandler := func(context.Context, *ServerRequest[*CallToolParamsFor[req]]) (*CallToolResultFor[any], error) { - return nil, nil - } - - st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) - if err != nil { - t.Fatal(err) - } - - for _, tt := range []struct { - desc string - args map[string]any - want string // error should contain this string; empty for success - }{ - { - "both required", - map[string]any{"I": 1, "B": true}, - "", - }, - { - "optional", - map[string]any{"I": 1, "B": true, "S": "foo"}, - "", - }, - { - "wrong type", - map[string]any{"I": 1.5, "B": true}, - "cannot unmarshal", - }, - { - "extra property", - map[string]any{"I": 1, "B": true, "C": 2}, - "unknown field", - }, - { - "value for pointer", - map[string]any{"I": 1, "B": true, "P": 3}, - "", - }, - { - "null for pointer", - map[string]any{"I": 1, "B": true, "P": nil}, - "", - }, - } { - t.Run(tt.desc, func(t *testing.T) { - raw, err := json.Marshal(tt.args) - if err != nil { - t.Fatal(err) - } - _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ - Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, - }) - if err == nil && tt.want != "" { - t.Error("got success, wanted failure") - } - if err != nil { - if tt.want == "" { - t.Fatalf("failed with:\n%s\nwanted success", err) - } - if !strings.Contains(err.Error(), tt.want) { - t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) - } - } - }) - } -} +// func TestToolValidate(t *testing.T) { +// // Check that the tool returned from NewServerTool properly validates its input schema. + +// type req struct { +// I int +// B bool +// S string `json:",omitempty"` +// P *int `json:",omitempty"` +// } + +// dummyHandler := func(context.Context, *ServerRequest[*CallToolParamsFor[req]]) (*CallToolResultFor[any], error) { +// return nil, nil +// } + +// st, err := newServerTool(&Tool{Name: "test", Description: "test"}, dummyHandler) +// if err != nil { +// t.Fatal(err) +// } + +// for _, tt := range []struct { +// desc string +// args map[string]any +// want string // error should contain this string; empty for success +// }{ +// { +// "both required", +// map[string]any{"I": 1, "B": true}, +// "", +// }, +// { +// "optional", +// map[string]any{"I": 1, "B": true, "S": "foo"}, +// "", +// }, +// { +// "wrong type", +// map[string]any{"I": 1.5, "B": true}, +// "cannot unmarshal", +// }, +// { +// "extra property", +// map[string]any{"I": 1, "B": true, "C": 2}, +// "unknown field", +// }, +// { +// "value for pointer", +// map[string]any{"I": 1, "B": true, "P": 3}, +// "", +// }, +// { +// "null for pointer", +// map[string]any{"I": 1, "B": true, "P": nil}, +// "", +// }, +// } { +// t.Run(tt.desc, func(t *testing.T) { +// raw, err := json.Marshal(tt.args) +// if err != nil { +// t.Fatal(err) +// } +// _, err = st.handler(context.Background(), &ServerRequest[*CallToolParamsFor[json.RawMessage]]{ +// Params: &CallToolParamsFor[json.RawMessage]{Arguments: json.RawMessage(raw)}, +// }) +// if err == nil && tt.want != "" { +// t.Error("got success, wanted failure") +// } +// if err != nil { +// if tt.want == "" { +// t.Fatalf("failed with:\n%s\nwanted success", err) +// } +// if !strings.Contains(err.Error(), tt.want) { +// t.Fatalf("got:\n%s\nwanted to contain %q", err, tt.want) +// } +// } +// }) +// } +// } // TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. // This addresses a vulnerability where missing or null parameters could crash the server. -func TestNilParamsHandling(t *testing.T) { - // Define test types for clarity - type TestArgs struct { - Name string `json:"name"` - Value int `json:"value"` - } - type TestParams = *CallToolParamsFor[TestArgs] - type TestResult = *CallToolResultFor[string] - - // Simple test handler - testHandler := func(ctx context.Context, req *ServerRequest[TestParams]) (TestResult, error) { - result := "processed: " + req.Params.Arguments.Name - return &CallToolResultFor[string]{StructuredContent: result}, nil - } - - methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - - // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully - mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { - t.Helper() - - defer func() { - if r := recover(); r != nil { - t.Fatalf("unmarshalParams panicked: %v", r) - } - }() - - params, err := methodInfo.unmarshalParams(rawMsg) - if err != nil { - t.Fatalf("unmarshalParams failed: %v", err) - } - - if expectNil { - if params != nil { - t.Fatalf("Expected nil params, got %v", params) - } - return params - } - - if params == nil { - t.Fatal("unmarshalParams returned unexpected nil") - } - - // Verify the result can be used safely - typedParams := params.(TestParams) - _ = typedParams.Name - _ = typedParams.Arguments.Name - _ = typedParams.Arguments.Value - - return params - } - - // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil - t.Run("missing_params", func(t *testing.T) { - mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag - }) - - t.Run("explicit_null", func(t *testing.T) { - mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag - }) - - t.Run("empty_object", func(t *testing.T) { - mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params - }) - - t.Run("valid_params", func(t *testing.T) { - rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) - params := mustNotPanic(t, rawMsg, false) - - // For valid params, also verify the values are parsed correctly - typedParams := params.(TestParams) - if typedParams.Name != "test" { - t.Errorf("Expected name 'test', got %q", typedParams.Name) - } - if typedParams.Arguments.Name != "hello" { - t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) - } - if typedParams.Arguments.Value != 42 { - t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) - } - }) -} +// func TestNilParamsHandling(t *testing.T) { +// // Define test types for clarity +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } + +// // Simple test handler +// testHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args TestArgs) (*CallToolResult, string, error) { +// result := "processed: " + args.Name +// return nil, result, nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully +// mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { +// t.Helper() + +// defer func() { +// if r := recover(); r != nil { +// t.Fatalf("unmarshalParams panicked: %v", r) +// } +// }() + +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err != nil { +// t.Fatalf("unmarshalParams failed: %v", err) +// } + +// if expectNil { +// if params != nil { +// t.Fatalf("Expected nil params, got %v", params) +// } +// return params +// } + +// if params == nil { +// t.Fatal("unmarshalParams returned unexpected nil") +// } + +// // Verify the result can be used safely +// typedParams := params.(TestParams) +// _ = typedParams.Name +// _ = typedParams.Arguments.Name +// _ = typedParams.Arguments.Value + +// return params +// } + +// // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil +// t.Run("missing_params", func(t *testing.T) { +// mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("explicit_null", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag +// }) + +// t.Run("empty_object", func(t *testing.T) { +// mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params +// }) + +// t.Run("valid_params", func(t *testing.T) { +// rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) +// params := mustNotPanic(t, rawMsg, false) + +// // For valid params, also verify the values are parsed correctly +// typedParams := params.(TestParams) +// if typedParams.Name != "test" { +// t.Errorf("Expected name 'test', got %q", typedParams.Name) +// } +// if typedParams.Arguments.Name != "hello" { +// t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) +// } +// if typedParams.Arguments.Value != 42 { +// t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) +// } +// }) +// } // TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix -func TestNilParamsEdgeCases(t *testing.T) { - type TestArgs struct { - Name string `json:"name"` - Value int `json:"value"` - } - type TestParams = *CallToolParamsFor[TestArgs] - - testHandler := func(context.Context, *ServerRequest[TestParams]) (*CallToolResultFor[string], error) { - return &CallToolResultFor[string]{StructuredContent: "test"}, nil - } - - methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - - // These should fail normally, not be treated as nil params - invalidCases := []json.RawMessage{ - json.RawMessage(""), // empty string - should error - json.RawMessage("[]"), // array - should error - json.RawMessage(`"null"`), // string "null" - should error - json.RawMessage("0"), // number - should error - json.RawMessage("false"), // boolean - should error - } - - for i, rawMsg := range invalidCases { - t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { - params, err := methodInfo.unmarshalParams(rawMsg) - if err == nil && params == nil { - t.Error("Should not return nil params without error") - } - }) - } - - // Test that methods without missingParamsOK flag properly reject nil params - t.Run("reject_when_params_required", func(t *testing.T) { - methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag - - testCases := []struct { - name string - params json.RawMessage - }{ - {"nil_params", nil}, - {"null_params", json.RawMessage(`null`)}, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - _, err := methodInfoStrict.unmarshalParams(tc.params) - if err == nil { - t.Error("Expected error for required params, got nil") - } - if !strings.Contains(err.Error(), "missing required \"params\"") { - t.Errorf("Expected 'missing required params' error, got: %v", err) - } - }) - } - }) -} +// func TestNilParamsEdgeCases(t *testing.T) { +// type TestArgs struct { +// Name string `json:"name"` +// Value int `json:"value"` +// } + +// testHandler := func(context.Context, *ServerRequest[*CallToolParams], TestArgs) (*CallToolResult, string, error) { +// return nil, "test", nil +// } + +// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) + +// // These should fail normally, not be treated as nil params +// invalidCases := []json.RawMessage{ +// json.RawMessage(""), // empty string - should error +// json.RawMessage("[]"), // array - should error +// json.RawMessage(`"null"`), // string "null" - should error +// json.RawMessage("0"), // number - should error +// json.RawMessage("false"), // boolean - should error +// } + +// for i, rawMsg := range invalidCases { +// t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { +// params, err := methodInfo.unmarshalParams(rawMsg) +// if err == nil && params == nil { +// t.Error("Should not return nil params without error") +// } +// }) +// } + +// // Test that methods without missingParamsOK flag properly reject nil params +// t.Run("reject_when_params_required", func(t *testing.T) { +// methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag + +// testCases := []struct { +// name string +// params json.RawMessage +// }{ +// {"nil_params", nil}, +// {"null_params", json.RawMessage(`null`)}, +// } + +// for _, tc := range testCases { +// t.Run(tc.name, func(t *testing.T) { +// _, err := methodInfoStrict.unmarshalParams(tc.params) +// if err == nil { +// t.Error("Expected error for required params, got nil") +// } +// if !strings.Contains(err.Error(), "missing required \"params\"") { +// t.Errorf("Expected 'missing required params' error, got: %v", err) +// } +// }) +// } +// }) +// } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index b5dfdc56..aa1a770b 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -18,12 +18,12 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ +func Add(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args AddParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: fmt.Sprintf("%d", req.Params.Arguments.X+req.Params.Arguments.Y)}, + &mcp.TextContent{Text: fmt.Sprintf("%d", args.X+args.Y)}, }, - }, nil + }, nil, nil } func ExampleSSEHandler() { diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 25dd224e..c5cf669b 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,7 +40,7 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so @@ -51,13 +51,13 @@ func TestStreamableTransports(t *testing.T) { } { res, err := req.Session.CreateMessage(ctx, &CreateMessageParams{}) if err != nil { - return nil, err + return nil, nil, err } if g, w := res.Model, "aModel"; g != w { - return nil, fmt.Errorf("got %q, want %q", g, w) + return nil, nil, fmt.Errorf("got %q, want %q", g, w) } } - return &CallToolResultFor[any]{}, nil + return &CallToolResult{}, nil, nil }) // 2. Start an httptest.Server with the StreamableHTTPHandler, wrapped in a @@ -172,7 +172,7 @@ func TestClientReplay(t *testing.T) { serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) server.AddTool(&Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, error) { go func() { bgCtx := context.Background() // Send the first two messages immediately. @@ -283,7 +283,7 @@ func TestServerInitiatedSSE(t *testing.T) { } defer clientSession.Close() server.AddTool(&Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) { + func(context.Context, *ServerRequest[*CallToolParams], any) (*CallToolResult, error) { return &CallToolResult{}, nil }) receivedNotifications := readNotifications(t, ctx, notifications, 1) @@ -546,11 +546,11 @@ func TestStreamableServerTransport(t *testing.T) { // Create a server containing a single tool, which runs the test tool // behavior, if any. server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) - AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[any]]) (*CallToolResultFor[any], error) { + AddTool(server, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { if test.tool != nil { test.tool(t, ctx, req.Session) } - return &CallToolResultFor[any]{}, nil + return &CallToolResult{}, nil, nil }) // Start the streamable handler. @@ -866,8 +866,8 @@ func TestStreamableStateless(t *testing.T) { // This version of sayHi doesn't make a ping request (we can't respond to // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) { - return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil + sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { + return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil } server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) diff --git a/mcp/tool.go b/mcp/tool.go index 15f17e11..b09eccb3 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -8,6 +8,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "reflect" @@ -15,17 +16,16 @@ import ( ) // A ToolHandler handles a call to tools/call. -// [CallToolParams.Arguments] will contain a map[string]any that has been validated -// against the input schema. -type ToolHandler = ToolHandlerFor[map[string]any, any] +// req.Params.Arguments will contain a json.RawMessage containing the arguments. +// args will contain a value that has been validated against the input schema. +type ToolHandler func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, error) -// A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) +type CallToolRequest struct { + Session *ServerSession + Params *CallToolParams +} -// A rawToolHandler is like a ToolHandler, but takes the arguments as as json.RawMessage. -// Second arg is *Request[*ServerSession, *CallToolParamsFor[json.RawMessage]], but that creates -// a cycle. -type rawToolHandler = func(context.Context, any) (*CallToolResult, error) +type rawToolHandler func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { @@ -35,40 +35,40 @@ type serverTool struct { inputResolved, outputResolved *jsonschema.Resolved } -// newServerTool creates a serverTool from a tool and a handler. -// If the tool doesn't have an input schema, it is inferred from In. -// If the tool doesn't have an output schema and Out != any, it is inferred from Out. -func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool, error) { - st := &serverTool{tool: t} +// A TypedToolHandler handles a call to tools/call with typed arguments and results. +type TypedToolHandler[In, Out any] func(context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) - if err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil { - return nil, err +func newServerTool(t *Tool, h ToolHandler) (*serverTool, error) { + st := &serverTool{tool: t} + if t.newArgs == nil { + t.newArgs = func() any { return &map[string]any{} } } - if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { - if err := setSchema[Out](&t.OutputSchema, &st.outputResolved); err != nil { - return nil, err - } + if t.InputSchema == nil { + // This prevents the tool author from forgetting to write a schema where + // one should be provided. If we papered over this by supplying the empty + // schema, then every input would be validated and the problem wouldn't be + // discovered until runtime, when the LLM sent bad data. + return nil, errors.New("missing input schema") } - - st.handler = func(ctx context.Context, areq any) (*CallToolResult, error) { - req := areq.(*ServerRequest[*CallToolParamsFor[json.RawMessage]]) - var args In - if req.Params.Arguments != nil { - if err := unmarshalSchema(req.Params.Arguments, st.inputResolved, &args); err != nil { - return nil, err - } - } - // TODO(jba): future-proof this copy. - params := &CallToolParamsFor[In]{ - Meta: req.Params.Meta, - Name: req.Params.Name, - Arguments: args, + var err error + st.inputResolved, err = t.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + if err != nil { + return nil, fmt.Errorf("input schema: %w", err) + } + if t.OutputSchema != nil { + st.outputResolved, err = t.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + } + if err != nil { + return nil, fmt.Errorf("output schema: %w", err) + } + // Ignore output schema. + st.handler = func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + rawArgs := req.Params.Arguments.(json.RawMessage) + args := t.newArgs() + if err := unmarshalSchema(rawArgs, st.inputResolved, args); err != nil { + return nil, err } - // TODO(jba): improve copy - res, err := h(ctx, &ServerRequest[*CallToolParamsFor[In]]{ - Session: req.Session, - Params: params, - }) + res, err := h(ctx, req, args) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. if err != nil { @@ -77,32 +77,45 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool IsError: true, }, nil } - var ctr CallToolResult - // TODO(jba): What if res == nil? Is that valid? // TODO(jba): if t.OutputSchema != nil, check that StructuredContent is present and validates. - if res != nil { - // TODO(jba): future-proof this copy. - ctr.Meta = res.Meta - ctr.Content = res.Content - ctr.IsError = res.IsError - ctr.StructuredContent = res.StructuredContent - } - return &ctr, nil + return res, nil } - return st, nil } -func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { +// newTypedServerTool creates a serverTool from a tool and a handler. +// If the tool doesn't have an input schema, it is inferred from In. +// If the tool doesn't have an output schema and Out != any, it is inferred from Out. +func newTypedServerTool[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (*serverTool, error) { + assert(t.newArgs == nil, "newArgs is nil") + t.newArgs = func() any { var x In; return &x } + var err error - if *sfield == nil { - *sfield, err = jsonschema.For[T](nil) + t.InputSchema, err = jsonschema.For[In](nil) + if err != nil { + return nil, err + } + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + t.OutputSchema, err = jsonschema.For[Out](nil) } if err != nil { - return err + return nil, err + } + + toolHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, error) { + res, out, err := h(ctx, req, *args.(*In)) + if err != nil { + return nil, err + } + if res == nil { + res = &CallToolResult{} + } + // TODO: return the serialized JSON in a TextContent block, as per spec? + // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + res.StructuredContent = out + return res, nil } - *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - return err + return newServerTool(t, toolHandler) } // unmarshalSchema unmarshals data into v and validates the result according to @@ -120,6 +133,7 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) if err := dec.Decode(v); err != nil { return fmt.Errorf("unmarshaling: %w", err) } + // TODO: test with nil args. if resolved != nil { if err := resolved.ApplyDefaults(v); err != nil { diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 609536cc..e98410bb 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -16,13 +16,13 @@ import ( ) // testToolHandler is used for type inference in TestNewServerTool. -func testToolHandler[In, Out any](context.Context, *ServerRequest[*CallToolParamsFor[In]]) (*CallToolResultFor[Out], error) { +func testToolHandler[In, Out any](context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) { panic("not implemented") } -func srvTool[In, Out any](t *testing.T, tool *Tool, handler ToolHandlerFor[In, Out]) *serverTool { +func srvTool[In, Out any](t *testing.T, tool *Tool, handler TypedToolHandler[In, Out]) *serverTool { t.Helper() - st, err := newServerTool(tool, handler) + st, err := newTypedServerTool(tool, handler) if err != nil { t.Fatal(err) } From d95bc2bd60b2b5d1feed3ad9ea86dc5d58cae04d Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 14 Aug 2025 06:49:04 -0400 Subject: [PATCH 2/2] mcp: remove tool genericity, cont. Added TypedTool, fixed tests. TODOs for followups: - Rewrite TestToolValidate. - Re-fix the bug from adding a duplicate tool. --- examples/server/custom-transport/main.go | 8 +- examples/server/hello/main.go | 8 +- examples/server/memory/kb.go | 105 +++---- examples/server/memory/kb_test.go | 241 ++++++++-------- examples/server/sequentialthinking/main.go | 40 ++- .../server/sequentialthinking/main_test.go | 85 +----- internal/readme/server/server.go | 8 +- mcp/client.go | 3 +- mcp/client_list_test.go | 12 +- mcp/content.go | 3 + mcp/content_nil_test.go | 224 +++++++++++++++ mcp/mcp_test.go | 15 +- mcp/server.go | 32 ++- mcp/shared.go | 44 +-- mcp/shared_test.go | 263 +++++++++--------- mcp/streamable_test.go | 17 ++ mcp/tool.go | 37 ++- mcp/tool_test.go | 2 +- 18 files changed, 666 insertions(+), 481 deletions(-) create mode 100644 mcp/content_nil_test.go diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go index bf0306cf..72cfc31d 100644 --- a/examples/server/custom-transport/main.go +++ b/examples/server/custom-transport/main.go @@ -85,12 +85,12 @@ type HiArgs struct { } // SayHi is a tool handler that responds with a greeting. -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - return &mcp.CallToolResultFor[struct{}]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, struct{}, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, struct{}{}, nil } func main() { diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index 8125441b..d0b20377 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -22,12 +22,12 @@ type HiArgs struct { Name string `json:"name" jsonschema:"the name to say hi to"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - return &mcp.CallToolResultFor[struct{}]{ +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, struct{}, error) { + return &mcp.CallToolResult{ Content: []mcp.Content{ - &mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}, + &mcp.TextContent{Text: "Hi " + args.Name}, }, - }, nil + }, struct{}{}, nil } func PromptHi(ctx context.Context, ss *mcp.ServerSession, params *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index f053bee5..b4a02cdc 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -431,152 +431,137 @@ func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { }, nil } -func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateEntitiesArgs]]) (*mcp.CallToolResultFor[CreateEntitiesResult], error) { - var res mcp.CallToolResultFor[CreateEntitiesResult] +func (k knowledgeBase) CreateEntities(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { + var res mcp.CallToolResult - entities, err := k.createEntities(req.Params.Arguments.Entities) + entities, err := k.createEntities(args.Entities) if err != nil { - return nil, err + return nil, CreateEntitiesResult{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Entities created successfully"}, } - res.StructuredContent = CreateEntitiesResult{ - Entities: entities, - } - - return &res, nil + return &res, CreateEntitiesResult{Entities: entities}, nil } -func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[CreateRelationsArgs]]) (*mcp.CallToolResultFor[CreateRelationsResult], error) { - var res mcp.CallToolResultFor[CreateRelationsResult] +func (k knowledgeBase) CreateRelations(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { + var res mcp.CallToolResult - relations, err := k.createRelations(req.Params.Arguments.Relations) + relations, err := k.createRelations(args.Relations) if err != nil { - return nil, err + return nil, CreateRelationsResult{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Relations created successfully"}, } - res.StructuredContent = CreateRelationsResult{ - Relations: relations, - } - - return &res, nil + return &res, CreateRelationsResult{Relations: relations}, nil } -func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[AddObservationsArgs]]) (*mcp.CallToolResultFor[AddObservationsResult], error) { - var res mcp.CallToolResultFor[AddObservationsResult] +func (k knowledgeBase) AddObservations(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { + var res mcp.CallToolResult - observations, err := k.addObservations(req.Params.Arguments.Observations) + observations, err := k.addObservations(args.Observations) if err != nil { - return nil, err + return nil, AddObservationsResult{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Observations added successfully"}, } - res.StructuredContent = AddObservationsResult{ + return &res, AddObservationsResult{ Observations: observations, - } - - return &res, nil + }, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteEntitiesArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteEntities(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { + var res mcp.CallToolResult - err := k.deleteEntities(req.Params.Arguments.EntityNames) + err := k.deleteEntities(args.EntityNames) if err != nil { - return nil, err + return nil, nil, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Entities deleted successfully"}, } - return &res, nil + return &res, nil, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteObservationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteObservations(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { + var res mcp.CallToolResult - err := k.deleteObservations(req.Params.Arguments.Deletions) + err := k.deleteObservations(args.Deletions) if err != nil { - return nil, err + return nil, nil, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Observations deleted successfully"}, } - return &res, nil + return &res, nil, nil } -func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[DeleteRelationsArgs]]) (*mcp.CallToolResultFor[struct{}], error) { - var res mcp.CallToolResultFor[struct{}] +func (k knowledgeBase) DeleteRelations(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args DeleteRelationsArgs) (*mcp.CallToolResult, any, error) { + var res mcp.CallToolResult - err := k.deleteRelations(req.Params.Arguments.Relations) + err := k.deleteRelations(args.Relations) if err != nil { - return nil, err + return nil, nil, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Relations deleted successfully"}, } - return &res, nil + return &res, nil, nil } -func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[struct{}]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) ReadGraph(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args struct{}) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult graph, err := k.loadGraph() if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Graph read successfully"}, } - res.StructuredContent = graph - return &res, nil + return &res, graph, nil } -func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SearchNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) SearchNodes(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult - graph, err := k.searchNodes(req.Params.Arguments.Query) + graph, err := k.searchNodes(args.Query) if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Nodes searched successfully"}, } - - res.StructuredContent = graph - return &res, nil + return &res, graph, nil } -func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[OpenNodesArgs]]) (*mcp.CallToolResultFor[KnowledgeGraph], error) { - var res mcp.CallToolResultFor[KnowledgeGraph] +func (k knowledgeBase) OpenNodes(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { + var res mcp.CallToolResult - graph, err := k.openNodes(req.Params.Arguments.Names) + graph, err := k.openNodes(args.Names) if err != nil { - return nil, err + return nil, KnowledgeGraph{}, err } res.Content = []mcp.Content{ &mcp.TextContent{Text: "Nodes opened successfully"}, } - - res.StructuredContent = graph - return &res, nil + return &res, graph, nil } diff --git a/examples/server/memory/kb_test.go b/examples/server/memory/kb_test.go index 6e29d5e4..8ba947dc 100644 --- a/examples/server/memory/kb_test.go +++ b/examples/server/memory/kb_test.go @@ -435,141 +435,153 @@ func TestMCPServerIntegration(t *testing.T) { // Create mock server session ctx := context.Background() - serverSession := &mcp.ServerSession{} // Test CreateEntities through MCP - createEntitiesParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ - Arguments: CreateEntitiesArgs{ - Entities: []Entity{ - { - Name: "TestPerson", - EntityType: "Person", - Observations: []string{"Likes testing"}, - }, + args := CreateEntitiesArgs{ + Entities: []Entity{ + { + Name: "TestPerson", + EntityType: "Person", + Observations: []string{"Likes testing"}, }, }, } - - createResult, err := kb.CreateEntities(ctx, requestFor(serverSession, createEntitiesParams)) + _, createResult, err := kb.CreateEntities(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, args) if err != nil { t.Fatalf("MCP CreateEntities failed: %v", err) } - if createResult.IsError { - t.Fatalf("MCP CreateEntities returned error: %v", createResult.Content) - } - if len(createResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity created, got %d", len(createResult.StructuredContent.Entities)) + if g := len(createResult.Entities); g != 1 { + t.Errorf("expected 1 entity created, got %d", g) } // Test ReadGraph through MCP - readParams := &mcp.CallToolParamsFor[struct{}]{} - readResult, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) + _, readResult, err := kb.ReadGraph(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, struct{}{}) if err != nil { t.Fatalf("MCP ReadGraph failed: %v", err) } - if readResult.IsError { - t.Fatalf("MCP ReadGraph returned error: %v", readResult.Content) - } - if len(readResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity in graph, got %d", len(readResult.StructuredContent.Entities)) + if len(readResult.Entities) != 1 { + t.Errorf("expected 1 entity in graph, got %d", len(readResult.Entities)) } // Test CreateRelations through MCP - createRelationsParams := &mcp.CallToolParamsFor[CreateRelationsArgs]{ - Arguments: CreateRelationsArgs{ - Relations: []Relation{ - { - From: "TestPerson", - To: "Testing", - RelationType: "likes", - }, + crargs := CreateRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", }, }, } - - relationsResult, err := kb.CreateRelations(ctx, requestFor(serverSession, createRelationsParams)) + _, relationsResult, err := kb.CreateRelations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, crargs) if err != nil { t.Fatalf("MCP CreateRelations failed: %v", err) } - if relationsResult.IsError { - t.Fatalf("MCP CreateRelations returned error: %v", relationsResult.Content) - } - if len(relationsResult.StructuredContent.Relations) != 1 { - t.Errorf("expected 1 relation created, got %d", len(relationsResult.StructuredContent.Relations)) + if len(relationsResult.Relations) != 1 { + t.Errorf("expected 1 relation created, got %d", len(relationsResult.Relations)) } // Test AddObservations through MCP - addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ - Arguments: AddObservationsArgs{ - Observations: []Observation{ - { - EntityName: "TestPerson", - Contents: []string{"Works remotely", "Drinks coffee"}, - }, + addObsArgs := AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "TestPerson", + Contents: []string{"Works remotely", "Drinks coffee"}, }, }, } - obsResult, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) + _, obsResult, err := kb.AddObservations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, addObsArgs) if err != nil { t.Fatalf("MCP AddObservations failed: %v", err) } - if obsResult.IsError { - t.Fatalf("MCP AddObservations returned error: %v", obsResult.Content) - } - if len(obsResult.StructuredContent.Observations) != 1 { - t.Errorf("expected 1 observation result, got %d", len(obsResult.StructuredContent.Observations)) + if len(obsResult.Observations) != 1 { + t.Errorf("expected 1 observation result, got %d", len(obsResult.Observations)) } // Test SearchNodes through MCP - searchParams := &mcp.CallToolParamsFor[SearchNodesArgs]{ - Arguments: SearchNodesArgs{ - Query: "coffee", - }, + searchArgs := SearchNodesArgs{ + Query: "coffee", } - - searchResult, err := kb.SearchNodes(ctx, requestFor(serverSession, searchParams)) + _, searchResult, err := kb.SearchNodes(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, searchArgs) if err != nil { t.Fatalf("MCP SearchNodes failed: %v", err) } - if searchResult.IsError { - t.Fatalf("MCP SearchNodes returned error: %v", searchResult.Content) - } - if len(searchResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity from search, got %d", len(searchResult.StructuredContent.Entities)) + if len(searchResult.Entities) != 1 { + t.Errorf("expected 1 entity from search, got %d", len(searchResult.Entities)) } // Test OpenNodes through MCP - openParams := &mcp.CallToolParamsFor[OpenNodesArgs]{ - Arguments: OpenNodesArgs{ - Names: []string{"TestPerson"}, - }, + openArgs := OpenNodesArgs{ + Names: []string{"TestPerson"}, } - openResult, err := kb.OpenNodes(ctx, requestFor(serverSession, openParams)) + _, openResult, err := kb.OpenNodes(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, openArgs) if err != nil { t.Fatalf("MCP OpenNodes failed: %v", err) } - if openResult.IsError { - t.Fatalf("MCP OpenNodes returned error: %v", openResult.Content) - } - if len(openResult.StructuredContent.Entities) != 1 { - t.Errorf("expected 1 entity from open, got %d", len(openResult.StructuredContent.Entities)) + if len(openResult.Entities) != 1 { + t.Errorf("expected 1 entity from open, got %d", len(openResult.Entities)) } // Test DeleteObservations through MCP - deleteObsParams := &mcp.CallToolParamsFor[DeleteObservationsArgs]{ - Arguments: DeleteObservationsArgs{ - Deletions: []Observation{ - { - EntityName: "TestPerson", - Observations: []string{"Works remotely"}, - }, + deleteObsArgs := DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, }, }, } - deleteObsResult, err := kb.DeleteObservations(ctx, requestFor(serverSession, deleteObsParams)) + _, _, err = kb.DeleteObservations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, deleteObsArgs) + if err != nil { + t.Fatalf("MCP DeleteObservations failed: %v", err) + } + + // Test DeleteRelations through MCP + deleteRelArgs := DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", + }, + }, + } + + _, _, err = kb.DeleteRelations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, deleteRelArgs) + if err != nil { + t.Fatalf("MCP DeleteRelations failed: %v", err) + } + + // Test DeleteEntities through MCP + deleteEntArgs := DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, + } + + _, _, err = kb.DeleteEntities(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, deleteEntArgs) + if err != nil { + t.Fatalf("MCP DeleteEntities failed: %v", err) + } + + // Verify final state + _, finalRead, err := kb.ReadGraph(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, struct{}{}) + if err != nil { + t.Fatalf("Final MCP ReadGraph failed: %v", err) + } + if len(finalRead.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.Entities)) + } + doargs := DeleteObservationsArgs{ + Deletions: []Observation{ + { + EntityName: "TestPerson", + Observations: []string{"Works remotely"}, + }, + }, + } + deleteObsResult, _, err := kb.DeleteObservations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, doargs) if err != nil { t.Fatalf("MCP DeleteObservations failed: %v", err) } @@ -578,19 +590,17 @@ func TestMCPServerIntegration(t *testing.T) { } // Test DeleteRelations through MCP - deleteRelParams := &mcp.CallToolParamsFor[DeleteRelationsArgs]{ - Arguments: DeleteRelationsArgs{ - Relations: []Relation{ - { - From: "TestPerson", - To: "Testing", - RelationType: "likes", - }, + drargs := DeleteRelationsArgs{ + Relations: []Relation{ + { + From: "TestPerson", + To: "Testing", + RelationType: "likes", }, }, } - deleteRelResult, err := kb.DeleteRelations(ctx, requestFor(serverSession, deleteRelParams)) + deleteRelResult, _, err := kb.DeleteRelations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, drargs) if err != nil { t.Fatalf("MCP DeleteRelations failed: %v", err) } @@ -599,13 +609,11 @@ func TestMCPServerIntegration(t *testing.T) { } // Test DeleteEntities through MCP - deleteEntParams := &mcp.CallToolParamsFor[DeleteEntitiesArgs]{ - Arguments: DeleteEntitiesArgs{ - EntityNames: []string{"TestPerson"}, - }, + deargs := DeleteEntitiesArgs{ + EntityNames: []string{"TestPerson"}, } - deleteEntResult, err := kb.DeleteEntities(ctx, requestFor(serverSession, deleteEntParams)) + deleteEntResult, _, err := kb.DeleteEntities(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, deargs) if err != nil { t.Fatalf("MCP DeleteEntities failed: %v", err) } @@ -614,12 +622,12 @@ func TestMCPServerIntegration(t *testing.T) { } // Verify final state - finalRead, err := kb.ReadGraph(ctx, requestFor(serverSession, readParams)) + _, graph, err := kb.ReadGraph(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, struct{}{}) if err != nil { t.Fatalf("Final MCP ReadGraph failed: %v", err) } - if len(finalRead.StructuredContent.Entities) != 0 { - t.Errorf("expected empty graph after deletion, got %d entities", len(finalRead.StructuredContent.Entities)) + if len(graph.Entities) != 0 { + t.Errorf("expected empty graph after deletion, got %d entities", len(graph.Entities)) } }) } @@ -633,21 +641,17 @@ func TestMCPErrorHandling(t *testing.T) { kb := knowledgeBase{s: s} ctx := context.Background() - serverSession := &mcp.ServerSession{} // Test adding observations to non-existent entity - addObsParams := &mcp.CallToolParamsFor[AddObservationsArgs]{ - Arguments: AddObservationsArgs{ - Observations: []Observation{ - { - EntityName: "NonExistentEntity", - Contents: []string{"This should fail"}, - }, + + _, _, err := kb.AddObservations(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, AddObservationsArgs{ + Observations: []Observation{ + { + EntityName: "NonExistentEntity", + Contents: []string{"This should fail"}, }, }, - } - - _, err := kb.AddObservations(ctx, requestFor(serverSession, addObsParams)) + }) if err == nil { t.Errorf("expected MCP AddObservations to return error for non-existent entity") } else { @@ -667,28 +671,25 @@ func TestMCPResponseFormat(t *testing.T) { kb := knowledgeBase{s: s} ctx := context.Background() - serverSession := &mcp.ServerSession{} // Test CreateEntities response format - createParams := &mcp.CallToolParamsFor[CreateEntitiesArgs]{ - Arguments: CreateEntitiesArgs{ - Entities: []Entity{ - {Name: "FormatTest", EntityType: "Test"}, - }, + args := CreateEntitiesArgs{ + Entities: []Entity{ + {Name: "FormatTest", EntityType: "Test"}, }, } - result, err := kb.CreateEntities(ctx, requestFor(serverSession, createParams)) + result, createResult, err := kb.CreateEntities(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, args) if err != nil { t.Fatalf("CreateEntities failed: %v", err) } - // Verify response has both Content and StructuredContent + // Verify response has both Content and a structured result if len(result.Content) == 0 { t.Errorf("expected Content field to be populated") } - if len(result.StructuredContent.Entities) == 0 { - t.Errorf("expected StructuredContent.Entities to be populated") + if len(createResult.Entities) == 0 { + t.Errorf("expected createResult.Entities to be populated") } // Verify Content contains simple success message @@ -701,7 +702,3 @@ func TestMCPResponseFormat(t *testing.T) { t.Errorf("expected Content[0] to be TextContent") } } - -func requestFor[P mcp.Params](ss *mcp.ServerSession, p P) *mcp.ServerRequest[P] { - return &mcp.ServerRequest[P]{Session: ss, Params: p} -} diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index 45a4fa6f..af16be06 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -231,9 +231,7 @@ func deepCopyThoughts(thoughts []*Thought) []*Thought { } // StartThinking begins a new sequential thinking session for a complex problem. -func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[StartThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args StartThinkingArgs) (*mcp.CallToolResult, any, error) { sessionID := args.SessionID if sessionID == "" { sessionID = randText() @@ -255,20 +253,18 @@ func StartThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPara store.SetSession(session) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Started thinking session '%s' for problem: %s\nEstimated steps: %d\nReady for your first thought.", sessionID, args.Problem, estimatedSteps), }, }, - }, nil + }, nil, nil } // ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. -func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ContinueThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { // Handle revision of existing thought if args.ReviseStep != nil { err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { @@ -283,17 +279,17 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Revised step %d in session '%s':\n%s", *args.ReviseStep, args.SessionID, args.Thought), }, }, - }, nil + }, nil, nil } // Handle branching @@ -322,20 +318,20 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } // Save the branch session store.SetSession(branchSession) - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Created branch '%s' from session '%s'. You can now continue thinking in either session.", branchID, args.SessionID), }, }, - }, nil + }, nil, nil } // Add new thought @@ -381,27 +377,25 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP return session, nil }) if err != nil { - return nil, err + return nil, nil, err } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: fmt.Sprintf("Session '%s' - %s:\n%s%s", args.SessionID, progress, args.Thought, statusMsg), }, }, - }, nil + }, nil, nil } // ReviewThinking provides a complete review of the thinking process for a session. -func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[ReviewThinkingArgs]]) (*mcp.CallToolResultFor[any], error) { - args := req.Params.Arguments - +func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { // Get a snapshot of the session to avoid race conditions sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) if !exists { - return nil, fmt.Errorf("session %s not found", args.SessionID) + return nil, nil, fmt.Errorf("session %s not found", args.SessionID) } var review strings.Builder @@ -424,13 +418,13 @@ func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPar fmt.Fprintf(&review, "%d. %s%s\n", i+1, thought.Content, status) } - return &mcp.CallToolResultFor[any]{ + return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{ Text: review.String(), }, }, - }, nil + }, nil, nil } // ThinkingHistory handles resource requests for thinking session data and history. diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go index c5e4a95a..9b445705 100644 --- a/examples/server/sequentialthinking/main_test.go +++ b/examples/server/sequentialthinking/main_test.go @@ -26,12 +26,7 @@ func TestStartThinking(t *testing.T) { EstimatedSteps: 5, } - params := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: args, - } - - result, err := StartThinking(ctx, requestFor(params)) + result, _, err := StartThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, args) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -84,12 +79,7 @@ func TestContinueThinking(t *testing.T) { EstimatedSteps: 3, } - startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: startArgs, - } - - _, err := StartThinking(ctx, requestFor(startParams)) + _, _, err := StartThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, startArgs) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -100,12 +90,7 @@ func TestContinueThinking(t *testing.T) { Thought: "First thought: I need to understand the problem", } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -153,12 +138,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { SessionID: "test_completion", } - startParams := &mcp.CallToolParamsFor[StartThinkingArgs]{ - Name: "start_thinking", - Arguments: startArgs, - } - - _, err := StartThinking(ctx, requestFor(startParams)) + _, _, err := StartThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, startArgs) if err != nil { t.Fatalf("StartThinking() error = %v", err) } @@ -171,12 +151,7 @@ func TestContinueThinkingWithCompletion(t *testing.T) { NextNeeded: &nextNeeded, } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -228,12 +203,7 @@ func TestContinueThinkingRevision(t *testing.T) { ReviseStep: &reviseStep, } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -284,12 +254,7 @@ func TestContinueThinkingBranching(t *testing.T) { CreateBranch: true, } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - result, err := ContinueThinking(ctx, requestFor(continueParams)) + result, _, err := ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, continueArgs) if err != nil { t.Fatalf("ContinueThinking() error = %v", err) } @@ -351,12 +316,7 @@ func TestReviewThinking(t *testing.T) { SessionID: "test_review", } - reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ - Name: "review_thinking", - Arguments: reviewArgs, - } - - result, err := ReviewThinking(ctx, requestFor(reviewParams)) + result, _, err := ReviewThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, reviewArgs) if err != nil { t.Fatalf("ReviewThinking() error = %v", err) } @@ -431,7 +391,7 @@ func TestThinkingHistory(t *testing.T) { URI: "thinking://sessions", } - result, err := ThinkingHistory(ctx, requestFor(listParams)) + result, err := ThinkingHistory(ctx, &mcp.ServerRequest[*mcp.ReadResourceParams]{Params: listParams}) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -461,7 +421,7 @@ func TestThinkingHistory(t *testing.T) { URI: "thinking://session1", } - result, err = ThinkingHistory(ctx, requestFor(sessionParams)) + result, err = ThinkingHistory(ctx, &mcp.ServerRequest[*mcp.ReadResourceParams]{Params: sessionParams}) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -491,12 +451,7 @@ func TestInvalidOperations(t *testing.T) { Thought: "Some thought", } - continueParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: continueArgs, - } - - _, err := ContinueThinking(ctx, requestFor(continueParams)) + _, _, err := ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, continueArgs) if err == nil { t.Error("Expected error for non-existent session") } @@ -506,12 +461,7 @@ func TestInvalidOperations(t *testing.T) { SessionID: "nonexistent", } - reviewParams := &mcp.CallToolParamsFor[ReviewThinkingArgs]{ - Name: "review_thinking", - Arguments: reviewArgs, - } - - _, err = ReviewThinking(ctx, requestFor(reviewParams)) + _, _, err = ReviewThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, reviewArgs) if err == nil { t.Error("Expected error for non-existent session in review") } @@ -536,17 +486,8 @@ func TestInvalidOperations(t *testing.T) { ReviseStep: &reviseStep, } - invalidReviseParams := &mcp.CallToolParamsFor[ContinueThinkingArgs]{ - Name: "continue_thinking", - Arguments: invalidReviseArgs, - } - - _, err = ContinueThinking(ctx, requestFor(invalidReviseParams)) + _, _, err = ContinueThinking(ctx, &mcp.ServerRequest[*mcp.CallToolParams]{}, invalidReviseArgs) if err == nil { t.Error("Expected error for invalid revision step") } } - -func requestFor[P mcp.Params](p P) *mcp.ServerRequest[P] { - return &mcp.ServerRequest[P]{Params: p} -} diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 3aa1037c..087992e8 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -16,10 +16,10 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[HiParams]]) (*mcp.CallToolResultFor[any], error) { - return &mcp.CallToolResultFor[any]{ - Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name}}, - }, nil +func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, + }, nil, nil } func main() { diff --git a/mcp/client.go b/mcp/client.go index b0db1d64..b6693056 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -103,8 +103,7 @@ func (e unsupportedProtocolVersionError) Error() string { } // ClientSessionOptions is reserved for future use. -type ClientSessionOptions struct { -} +type ClientSessionOptions struct{} // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. diff --git a/mcp/client_list_test.go b/mcp/client_list_test.go index 836d4803..8973749f 100644 --- a/mcp/client_list_test.go +++ b/mcp/client_list_test.go @@ -33,7 +33,7 @@ func TestList(t *testing.T) { if err != nil { t.Fatal("ListTools() failed:", err) } - if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(ignoreUnexp...)); diff != "" { t.Fatalf("ListTools() mismatch (-want +got):\n%s", diff) } }) @@ -55,7 +55,7 @@ func TestList(t *testing.T) { if err != nil { t.Fatal("ListResources() failed:", err) } - if diff := cmp.Diff(wantResources, res.Resources, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantResources, res.Resources, cmpopts.IgnoreUnexported(ignoreUnexp...)); diff != "" { t.Fatalf("ListResources() mismatch (-want +got):\n%s", diff) } }) @@ -76,7 +76,7 @@ func TestList(t *testing.T) { if err != nil { t.Fatal("ListResourceTemplates() failed:", err) } - if diff := cmp.Diff(wantResourceTemplates, res.ResourceTemplates, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantResourceTemplates, res.ResourceTemplates, cmpopts.IgnoreUnexported(ignoreUnexp...)); diff != "" { t.Fatalf("ListResourceTemplates() mismatch (-want +got):\n%s", diff) } }) @@ -97,7 +97,7 @@ func TestList(t *testing.T) { if err != nil { t.Fatal("ListPrompts() failed:", err) } - if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(wantPrompts, res.Prompts, cmpopts.IgnoreUnexported(ignoreUnexp...)); diff != "" { t.Fatalf("ListPrompts() mismatch (-want +got):\n%s", diff) } }) @@ -116,7 +116,7 @@ func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { } got = append(got, x) } - if diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + if diff := cmp.Diff(want, got, cmpopts.IgnoreUnexported(ignoreUnexp...)); diff != "" { t.Fatalf("mismatch (-want +got):\n%s", diff) } } @@ -124,3 +124,5 @@ func testIterator[T any](t *testing.T, seq iter.Seq2[*T, error], want []*T) { func testPromptHandler(context.Context, *mcp.ServerSession, *mcp.GetPromptParams) (*mcp.GetPromptResult, error) { panic("not implemented") } + +var ignoreUnexp = []any{jsonschema.Schema{}, mcp.Tool{}} diff --git a/mcp/content.go b/mcp/content.go index 8bf75f0f..f8777154 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -252,6 +252,9 @@ func contentsFromWire(wires []*wireContent, allow map[string]bool) ([]Content, e } func contentFromWire(wire *wireContent, allow map[string]bool) (Content, error) { + if wire == nil { + return nil, fmt.Errorf("content wire is nil") + } if allow != nil && !allow[wire.Type] { return nil, fmt.Errorf("invalid content type %q", wire.Type) } diff --git a/mcp/content_nil_test.go b/mcp/content_nil_test.go new file mode 100644 index 00000000..32e7e8cf --- /dev/null +++ b/mcp/content_nil_test.go @@ -0,0 +1,224 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file contains tests to verify that UnmarshalJSON methods for Content types +// don't panic when unmarshaling onto nil pointers, as requested in GitHub issue #205. +// +// NOTE: The contentFromWire function has been fixed to handle nil wire.Content +// gracefully by returning an error instead of panicking. + +package mcp_test + +import ( + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func TestContentUnmarshalNil(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + want interface{} + }{ + { + name: "CallToolResult nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + { + name: "CreateMessageResult nil Content", + json: `{"content":{"type":"text","text":"hello"},"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + want: &mcp.CreateMessageResult{Content: &mcp.TextContent{Text: "hello"}, Model: "test", Role: "user"}, + }, + { + name: "PromptMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.PromptMessage{}, + want: &mcp.PromptMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "SamplingMessage nil Content", + json: `{"content":{"type":"text","text":"hello"},"role":"user"}`, + content: &mcp.SamplingMessage{}, + want: &mcp.SamplingMessage{Content: &mcp.TextContent{Text: "hello"}, Role: "user"}, + }, + { + name: "CallToolResultFor nil Content", + json: `{"content":[{"type":"text","text":"hello"}]}`, + content: &mcp.CallToolResult{}, + want: &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: "hello"}}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if err != nil { + t.Errorf("UnmarshalJSON failed: %v", err) + } + + // Verify that the Content field was properly populated + if cmp.Diff(tt.want, tt.content) != "" { + t.Errorf("Content is not equal: %v", cmp.Diff(tt.content, tt.content)) + } + }) + } +} + +func TestContentUnmarshalNilWithDifferentTypes(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "ImageContent", + json: `{"content":{"type":"image","mimeType":"image/png","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "AudioContent", + json: `{"content":{"type":"audio","mimeType":"audio/wav","data":"YTFiMmMz"}}`, + content: &mcp.CreateMessageResult{}, + expectError: false, + }, + { + name: "ResourceLink", + json: `{"content":{"type":"resource_link","uri":"file:///test","name":"test"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + { + name: "EmbeddedResource", + json: `{"content":{"type":"resource","resource":{"uri":"file://test","text":"test"}}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // CreateMessageResult only allows text, image, audio + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + + // Verify that the Content field was properly populated for successful cases + if !tt.expectError { + if result, ok := tt.content.(*mcp.CreateMessageResult); ok { + if result.Content == nil { + t.Error("CreateMessageResult.Content was not populated") + } + } + } + }) + } +} + +func TestContentUnmarshalNilWithEmptyContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Empty Content array", + json: `{"content":[]}`, + content: &mcp.CallToolResult{}, + expectError: false, + }, + { + name: "Missing Content field", + json: `{"model":"test","role":"user"}`, + content: &mcp.CreateMessageResult{}, + expectError: true, // Content field is required for CreateMessageResult + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + // defer func() { + // if r := recover(); r != nil { + // t.Errorf("UnmarshalJSON panicked: %v", r) + // } + // }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} + +func TestContentUnmarshalNilWithInvalidContent(t *testing.T) { + tests := []struct { + name string + json string + content interface{} + expectError bool + }{ + { + name: "Invalid content type", + json: `{"content":{"type":"invalid","text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + { + name: "Missing type field", + json: `{"content":{"text":"hello"}}`, + content: &mcp.CreateMessageResult{}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test that unmarshaling doesn't panic on nil Content fields + defer func() { + if r := recover(); r != nil { + t.Errorf("UnmarshalJSON panicked: %v", r) + } + }() + + err := json.Unmarshal([]byte(tt.json), tt.content) + if tt.expectError && err == nil { + t.Error("Expected error but got none") + } + if !tt.expectError && err != nil { + t.Errorf("Unexpected error: %v", err) + } + }) + } +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 56a32ead..4c4fa708 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -646,7 +646,6 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, _ *ServerRequest[*CallToolParams], _ any) (*CallToolResult, error) { start <- struct{}{} select { @@ -663,8 +662,18 @@ func TestCancellation(t *testing.T) { defer cs.Close() ctx, cancel := context.WithCancel(context.Background()) - go cs.CallTool(ctx, &CallToolParams{Name: "slow"}) - <-start + errc := make(chan error, 1) + go func() { + _, err := cs.CallTool(ctx, &CallToolParams{Name: "slow"}) + if err != nil { + errc <- err + } + }() + select { + case err := <-errc: + t.Fatalf("CallTool returned %v", err) + case <-start: + } cancel() select { case <-cancelled: diff --git a/mcp/server.go b/mcp/server.go index 9d7ed9ed..3bbeadfb 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -147,18 +147,30 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) { s.addServerTool(newServerTool(t, h)) } -// AddTool adds a [Tool] to the server, or replaces one with the same name. -// If the tool's input schema is nil, it is set to the schema inferred from the In -// type parameter, using [jsonschema.For]. -// If the tool's output schema is nil and the Out type parameter is not the empty -// interface, then the output schema is set to the schema inferred from Out. -// The Tool argument must not be modified after this call. +// TypedTool returns a [Tool] and a [ToolHandler] from its arguments. +// The argument Tool must not have been used in a previous call to [AddTool] or TypedTool. +// It is returned with the following modifications: +// - If the tool doesn't have an input schema, it is inferred from In. +// - If the tool doesn't have an output schema and Out != any, it is inferred from Out. // -// The handler should return the result as the second return value. The first return value, -// a *CallToolResult, may be nil, or its fields other than StructuredContent may be -// populated. +// The returned tool must not be modified and should be used only with the returned ToolHandler. +// +// The argument handler should return the result as the second return value. The +// first return value, a *CallToolResult, may be nil, or its fields may be populated. +// TypedTool will populate the StructuredContent field with the second return value. +// It does not populate the Content field with the serialized JSON of StructuredContent, +// as suggested in the MCP specification. You can do so by wrapping the returned ToolHandler. +func TypedTool[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (*Tool, ToolHandler) { + th, err := newTypedToolHandler(t, h) + if err != nil { + panic(fmt.Sprintf("TypedTool for %q: %v", t.Name, err)) + } + return t, th +} + +// AddTool is a convenience for s.AddTool(TypedTool(t, h)). func AddTool[In, Out any](s *Server, t *Tool, h TypedToolHandler[In, Out]) { - s.addServerTool(newTypedServerTool(t, h)) + s.AddTool(TypedTool(t, h)) } func (s *Server) addServerTool(st *serverTool, err error) { diff --git a/mcp/shared.go b/mcp/shared.go index ea4975ef..518f41d3 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -252,26 +252,8 @@ func newServerMethodInfo[P paramsPtr[T], R Result, T any](d typedServerMethodHan // notification. func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInfo { return methodInfo{ - flags: flags, - unmarshalParams: func(m json.RawMessage) (Params, error) { - var p P - if m != nil { - if err := json.Unmarshal(m, &p); err != nil { - return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) - } - } - // We must check missingParamsOK here, in addition to checkRequest, to - // catch the edge cases where "params" is set to JSON null. - // See also https://go.dev/issue/33835. - // - // We need to ensure that p is non-null to guard against crashes, as our - // internal code or externally provided handlers may assume that params - // is non-null. - if flags&missingParamsOK == 0 && p == nil { - return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) - } - return orZero[Params](p), nil - }, + flags: flags, + unmarshalParams: unmarshalParamsFunc[P](flags), // newResult is used on the send side, to construct the value to unmarshal the result into. // R is a pointer to a result struct. There is no way to "unpointer" it without reflection. // TODO(jba): explore generic approaches to this, perhaps by treating R in @@ -280,6 +262,28 @@ func newMethodInfo[P paramsPtr[T], R Result, T any](flags methodFlags) methodInf } } +func unmarshalParamsFunc[P paramsPtr[T], T any](flags methodFlags) func(m json.RawMessage) (Params, error) { + return func(m json.RawMessage) (Params, error) { + var p P + if m != nil { + if err := json.Unmarshal(m, &p); err != nil { + return nil, fmt.Errorf("unmarshaling %q into a %T: %w", m, p, err) + } + } + // We must check missingParamsOK here, in addition to checkRequest, to + // catch the edge cases where "params" is set to JSON null. + // See also https://go.dev/issue/33835. + // + // We need to ensure that p is non-null to guard against crashes, as our + // internal code or externally provided handlers may assume that params + // is non-null. + if flags&missingParamsOK == 0 && p == nil { + return nil, fmt.Errorf("%w: missing required \"params\"", jsonrpc2.ErrInvalidRequest) + } + return orZero[Params](p), nil + } +} + // serverMethod is glue for creating a typedMethodHandler from a method on Server. func serverMethod[P Params, R Result]( f func(*Server, context.Context, *ServerRequest[P]) (R, error), diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 2bec742f..de5fe7de 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -4,7 +4,14 @@ package mcp -// TODO(jba): this shouldn't be in this file, but tool_test.go doesn't have access to unexported symbols. +import ( + "encoding/json" + "fmt" + "strings" + "testing" +) + +// TODO(jba): rewrite to use public API. // func TestToolValidate(t *testing.T) { // // Check that the tool returned from NewServerTool properly validates its input schema. @@ -15,7 +22,7 @@ package mcp // P *int `json:",omitempty"` // } -// dummyHandler := func(context.Context, *ServerRequest[*CallToolParamsFor[req]]) (*CallToolResultFor[any], error) { +// dummyHandler := func(context.Context, *ServerRequest[*CallToolParams], req) (*CallToolResultFor[any], error) { // return nil, nil // } @@ -85,140 +92,122 @@ package mcp // TestNilParamsHandling tests that nil parameters don't cause panic in unmarshalParams. // This addresses a vulnerability where missing or null parameters could crash the server. -// func TestNilParamsHandling(t *testing.T) { -// // Define test types for clarity -// type TestArgs struct { -// Name string `json:"name"` -// Value int `json:"value"` -// } - -// // Simple test handler -// testHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args TestArgs) (*CallToolResult, string, error) { -// result := "processed: " + args.Name -// return nil, result, nil -// } - -// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - -// // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully -// mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { -// t.Helper() - -// defer func() { -// if r := recover(); r != nil { -// t.Fatalf("unmarshalParams panicked: %v", r) -// } -// }() - -// params, err := methodInfo.unmarshalParams(rawMsg) -// if err != nil { -// t.Fatalf("unmarshalParams failed: %v", err) -// } - -// if expectNil { -// if params != nil { -// t.Fatalf("Expected nil params, got %v", params) -// } -// return params -// } - -// if params == nil { -// t.Fatal("unmarshalParams returned unexpected nil") -// } - -// // Verify the result can be used safely -// typedParams := params.(TestParams) -// _ = typedParams.Name -// _ = typedParams.Arguments.Name -// _ = typedParams.Arguments.Value - -// return params -// } - -// // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil -// t.Run("missing_params", func(t *testing.T) { -// mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag -// }) - -// t.Run("explicit_null", func(t *testing.T) { -// mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag -// }) - -// t.Run("empty_object", func(t *testing.T) { -// mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params -// }) - -// t.Run("valid_params", func(t *testing.T) { -// rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","value":42}}`) -// params := mustNotPanic(t, rawMsg, false) - -// // For valid params, also verify the values are parsed correctly -// typedParams := params.(TestParams) -// if typedParams.Name != "test" { -// t.Errorf("Expected name 'test', got %q", typedParams.Name) -// } -// if typedParams.Arguments.Name != "hello" { -// t.Errorf("Expected argument name 'hello', got %q", typedParams.Arguments.Name) -// } -// if typedParams.Arguments.Value != 42 { -// t.Errorf("Expected argument value 42, got %d", typedParams.Arguments.Value) -// } -// }) -// } +func TestNilParamsHandling(t *testing.T) { + unmarshalParams := unmarshalParamsFunc[*GetPromptParams](missingParamsOK) + + // Helper function to test that unmarshalParams doesn't panic and handles nil gracefully + mustNotPanic := func(t *testing.T, rawMsg json.RawMessage, expectNil bool) Params { + t.Helper() + + defer func() { + if r := recover(); r != nil { + t.Fatalf("unmarshalParams panicked: %v", r) + } + }() + + params, err := unmarshalParams(rawMsg) + if err != nil { + t.Fatalf("unmarshalParams failed: %v", err) + } + + if expectNil { + if params != nil { + t.Fatalf("Expected nil params, got %v", params) + } + return params + } + + if params == nil { + t.Fatal("unmarshalParams returned unexpected nil") + } + + // Verify the result can be used safely + typedParams := params.(*GetPromptParams) + _ = typedParams.Meta + _ = typedParams.Arguments + _ = typedParams.Name + + return params + } + + // Test different nil parameter scenarios - with missingParamsOK flag, nil/null should return nil + t.Run("missing_params", func(t *testing.T) { + mustNotPanic(t, nil, true) // Expect nil with missingParamsOK flag + }) + + t.Run("explicit_null", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`null`), true) // Expect nil with missingParamsOK flag + }) + + t.Run("empty_object", func(t *testing.T) { + mustNotPanic(t, json.RawMessage(`{}`), false) // Empty object should create valid params + }) + + t.Run("valid_params", func(t *testing.T) { + rawMsg := json.RawMessage(`{"name":"test","arguments":{"name":"hello","v":"x"}}`) + params := mustNotPanic(t, rawMsg, false) + + // For valid params, also verify the values are parsed correctly + typedParams := params.(*GetPromptParams) + if typedParams.Name != "test" { + t.Errorf("Expected name 'test', got %q", typedParams.Name) + } + if g, w := typedParams.Name, "test"; g != w { + t.Errorf("got %v, want %v", g, w) + } + if g, w := typedParams.Arguments["name"], "hello"; g != w { + t.Errorf("got %v, want %v", g, w) + } + if g, w := typedParams.Arguments["v"], "x"; g != w { + t.Errorf("got %v, want %v", g, w) + } + }) +} // TestNilParamsEdgeCases tests edge cases to ensure we don't over-fix -// func TestNilParamsEdgeCases(t *testing.T) { -// type TestArgs struct { -// Name string `json:"name"` -// Value int `json:"value"` -// } - -// testHandler := func(context.Context, *ServerRequest[*CallToolParams], TestArgs) (*CallToolResult, string, error) { -// return nil, "test", nil -// } - -// methodInfo := newServerMethodInfo(testHandler, missingParamsOK) - -// // These should fail normally, not be treated as nil params -// invalidCases := []json.RawMessage{ -// json.RawMessage(""), // empty string - should error -// json.RawMessage("[]"), // array - should error -// json.RawMessage(`"null"`), // string "null" - should error -// json.RawMessage("0"), // number - should error -// json.RawMessage("false"), // boolean - should error -// } - -// for i, rawMsg := range invalidCases { -// t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { -// params, err := methodInfo.unmarshalParams(rawMsg) -// if err == nil && params == nil { -// t.Error("Should not return nil params without error") -// } -// }) -// } - -// // Test that methods without missingParamsOK flag properly reject nil params -// t.Run("reject_when_params_required", func(t *testing.T) { -// methodInfoStrict := newServerMethodInfo(testHandler, 0) // No missingParamsOK flag - -// testCases := []struct { -// name string -// params json.RawMessage -// }{ -// {"nil_params", nil}, -// {"null_params", json.RawMessage(`null`)}, -// } - -// for _, tc := range testCases { -// t.Run(tc.name, func(t *testing.T) { -// _, err := methodInfoStrict.unmarshalParams(tc.params) -// if err == nil { -// t.Error("Expected error for required params, got nil") -// } -// if !strings.Contains(err.Error(), "missing required \"params\"") { -// t.Errorf("Expected 'missing required params' error, got: %v", err) -// } -// }) -// } -// }) -// } +func TestNilParamsEdgeCases(t *testing.T) { + unmarshalParams := unmarshalParamsFunc[*GetPromptParams](missingParamsOK) + + // These should fail normally, not be treated as nil params + invalidCases := []json.RawMessage{ + json.RawMessage(""), // empty string - should error + json.RawMessage("[]"), // array - should error + json.RawMessage(`"null"`), // string "null" - should error + json.RawMessage("0"), // number - should error + json.RawMessage("false"), // boolean - should error + } + + for i, rawMsg := range invalidCases { + t.Run(fmt.Sprintf("invalid_case_%d", i), func(t *testing.T) { + params, err := unmarshalParams(rawMsg) + if err == nil && params == nil { + t.Error("Should not return nil params without error") + } + }) + } + + // Test that methods without missingParamsOK flag properly reject nil params + t.Run("reject_when_params_required", func(t *testing.T) { + unmarshalParams := unmarshalParamsFunc[*GetPromptParams](0) // No missingParamsOK flag + + testCases := []struct { + name string + params json.RawMessage + }{ + {"nil_params", nil}, + {"null_params", json.RawMessage(`null`)}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, err := unmarshalParams(tc.params) + if err == nil { + t.Error("Expected error for required params, got nil") + } + if !strings.Contains(err.Error(), "missing required \"params\"") { + t.Errorf("Expected 'missing required params' error, got: %v", err) + } + }) + } + }) +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index c5cf669b..6e8db096 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -40,6 +40,23 @@ func TestStreamableTransports(t *testing.T) { // 1. Create a server with a simple "greet" tool. server := NewServer(testImpl, nil) AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi) + // The "hang" tool checks that context cancellation is propagated. + // It hangs until the context is cancelled. + var ( + start = make(chan struct{}) + cancelled = make(chan struct{}, 1) // don't block the request + ) + hang := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + start <- struct{}{} + select { + case <-ctx.Done(): + cancelled <- struct{}{} + case <-time.After(5 * time.Second): + return nil, nil, nil + } + return nil, nil, nil + } + AddTool(server, &Tool{Name: "hang"}, hang) AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // diff --git a/mcp/tool.go b/mcp/tool.go index b09eccb3..2b5881cf 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -63,11 +63,15 @@ func newServerTool(t *Tool, h ToolHandler) (*serverTool, error) { } // Ignore output schema. st.handler = func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + argsp := t.newArgs() rawArgs := req.Params.Arguments.(json.RawMessage) - args := t.newArgs() - if err := unmarshalSchema(rawArgs, st.inputResolved, args); err != nil { - return nil, err + if rawArgs != nil { + if err := unmarshalSchema(rawArgs, st.inputResolved, argsp); err != nil { + return nil, err + } } + // Dereference argsp. + args := reflect.ValueOf(argsp).Elem().Interface() res, err := h(ctx, req, args) // TODO(rfindley): investigate why server errors are embedded in this strange way, // rather than returned as jsonrpc2 server errors. @@ -83,19 +87,19 @@ func newServerTool(t *Tool, h ToolHandler) (*serverTool, error) { return st, nil } -// newTypedServerTool creates a serverTool from a tool and a handler. -// If the tool doesn't have an input schema, it is inferred from In. -// If the tool doesn't have an output schema and Out != any, it is inferred from Out. -func newTypedServerTool[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (*serverTool, error) { +// newTypedToolHandler is a helper for [TypedTool]. +func newTypedToolHandler[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (ToolHandler, error) { assert(t.newArgs == nil, "newArgs is nil") t.newArgs = func() any { var x In; return &x } var err error - t.InputSchema, err = jsonschema.For[In](nil) - if err != nil { - return nil, err + if t.InputSchema == nil { + t.InputSchema, err = jsonschema.For[In](nil) + if err != nil { + return nil, err + } } - if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + if t.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() { t.OutputSchema, err = jsonschema.For[Out](nil) } if err != nil { @@ -103,7 +107,11 @@ func newTypedServerTool[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (*ser } toolHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, error) { - res, out, err := h(ctx, req, *args.(*In)) + var inArg In + if args != nil { + inArg = args.(In) + } + res, out, err := h(ctx, req, inArg) if err != nil { return nil, err } @@ -112,10 +120,11 @@ func newTypedServerTool[In, Out any](t *Tool, h TypedToolHandler[In, Out]) (*ser } // TODO: return the serialized JSON in a TextContent block, as per spec? // https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content + // But people may use res.Content for other things. res.StructuredContent = out return res, nil } - return newServerTool(t, toolHandler) + return toolHandler, nil } // unmarshalSchema unmarshals data into v and validates the result according to @@ -131,7 +140,7 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any) dec := json.NewDecoder(bytes.NewReader(data)) dec.DisallowUnknownFields() if err := dec.Decode(v); err != nil { - return fmt.Errorf("unmarshaling: %w", err) + return fmt.Errorf("unmarshaling tool args %q into %T: %w", data, v, err) } // TODO: test with nil args. diff --git a/mcp/tool_test.go b/mcp/tool_test.go index e98410bb..dbae7b38 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -22,7 +22,7 @@ func testToolHandler[In, Out any](context.Context, *ServerRequest[*CallToolParam func srvTool[In, Out any](t *testing.T, tool *Tool, handler TypedToolHandler[In, Out]) *serverTool { t.Helper() - st, err := newTypedServerTool(tool, handler) + st, err := newServerTool(TypedTool(tool, handler)) if err != nil { t.Fatal(err) }