Skip to content

Commit 5237cd1

Browse files
committed
mcp: change tool handler design
1 parent 52734fd commit 5237cd1

File tree

9 files changed

+214
-307
lines changed

9 files changed

+214
-307
lines changed

mcp/features_test.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
package mcp
66

77
import (
8-
"context"
98
"slices"
109
"testing"
1110

@@ -18,14 +17,6 @@ type SayHiParams struct {
1817
Name string `json:"name"`
1918
}
2019

21-
func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) {
22-
return &CallToolResultFor[any]{
23-
Content: []Content{
24-
&TextContent{Text: "Hi " + params.Name},
25-
},
26-
}, nil
27-
}
28-
2920
func TestFeatureSetOrder(t *testing.T) {
3021
toolA := &Tool{Name: "apple", Description: "apple tool"}
3122
toolB := &Tool{Name: "banana", Description: "banana tool"}

mcp/mcp_test.go

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@ type hiParams struct {
3232
// TODO(jba): after schemas are stateless (WIP), this can be a variable.
3333
func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} }
3434

35-
func sayHi(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) {
35+
func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) {
3636
if err := req.Session.Ping(ctx, nil); err != nil {
37-
return nil, fmt.Errorf("ping failed: %v", err)
37+
return nil, nil, fmt.Errorf("ping failed: %v", err)
3838
}
39-
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil
39+
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil
4040
}
4141

4242
var codeReviewPrompt = &Prompt{
@@ -96,9 +96,9 @@ func TestEndToEnd(t *testing.T) {
9696
Name: "greet",
9797
Description: "say hi",
9898
}, sayHi)
99-
s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}},
100-
func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
101-
return nil, errTestFailure
99+
AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{}},
100+
func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) {
101+
return nil, nil, errTestFailure
102102
})
103103
s.AddPrompt(codeReviewPrompt, codReviewPromptHandler)
104104
s.AddPrompt(&Prompt{Name: "fail"}, func(_ context.Context, _ *ServerSession, _ *GetPromptParams) (*GetPromptResult, error) {
@@ -246,7 +246,7 @@ func TestEndToEnd(t *testing.T) {
246246
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
247247
}
248248

249-
s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler)
249+
s.AddRawTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler)
250250
waitForNotification(t, "tools")
251251
s.RemoveTools("T")
252252
waitForNotification(t, "tools")
@@ -657,7 +657,7 @@ func TestCancellation(t *testing.T) {
657657
return nil, nil
658658
}
659659
_, cs := basicConnection(t, func(s *Server) {
660-
AddTool(s, &Tool{Name: "slow"}, slowRequest)
660+
s.AddRawTool(&Tool{Name: "slow"}, slowRequest)
661661
})
662662
defer cs.Close()
663663

@@ -835,7 +835,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware {
835835
}
836836
}
837837

838-
func nopHandler(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
838+
func nopHandler(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
839839
return nil, nil
840840
}
841841

@@ -946,8 +946,8 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
946946
// This case was written specifically to reproduce a bug where duplicate tools where causing jsonschema errors
947947
t1 := &Tool{Name: "dup", Description: "first", InputSchema: &jsonschema.Schema{}}
948948
t2 := &Tool{Name: "dup", Description: "second", InputSchema: &jsonschema.Schema{}}
949-
s.AddTool(t1, nopHandler)
950-
s.AddTool(t2, nopHandler)
949+
s.AddRawTool(t1, nopHandler)
950+
s.AddRawTool(t2, nopHandler)
951951
})
952952
defer cs.Close()
953953

mcp/protocol.go

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,20 +40,32 @@ type Annotations struct {
4040
Priority float64 `json:"priority,omitempty"`
4141
}
4242

43-
type CallToolParams = CallToolParamsFor[any]
44-
45-
type CallToolParamsFor[In any] struct {
43+
type CallToolParams struct {
4644
// This property is reserved by the protocol to allow clients and servers to
4745
// attach additional metadata to their responses.
4846
Meta `json:"_meta,omitempty"`
4947
Name string `json:"name"`
50-
Arguments In `json:"arguments,omitempty"`
48+
Arguments any `json:"arguments,omitempty"`
5149
}
5250

53-
// The server's response to a tool call.
54-
type CallToolResult = CallToolResultFor[any]
51+
// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments.
52+
func (c *CallToolParams) UnmarshalJSON(data []byte) error {
53+
var raw struct {
54+
Meta `json:"_meta,omitempty"`
55+
Name string `json:"name"`
56+
RawArguments json.RawMessage `json:"arguments,omitempty"`
57+
}
58+
if err := json.Unmarshal(data, &raw); err != nil {
59+
return err
60+
}
61+
c.Meta = raw.Meta
62+
c.Name = raw.Name
63+
c.Arguments = raw.RawArguments
64+
return nil
65+
}
5566

56-
type CallToolResultFor[Out any] struct {
67+
// The server's response to a tool call.
68+
type CallToolResult struct {
5769
// This property is reserved by the protocol to allow clients and servers to
5870
// attach additional metadata to their responses.
5971
Meta `json:"_meta,omitempty"`
@@ -62,7 +74,7 @@ type CallToolResultFor[Out any] struct {
6274
Content []Content `json:"content"`
6375
// An optional JSON object that represents the structured result of the tool
6476
// call.
65-
StructuredContent Out `json:"structuredContent,omitempty"`
77+
StructuredContent any `json:"structuredContent,omitempty"`
6678
// Whether the tool call ended in an error.
6779
//
6880
// If not set, this is assumed to be false (the call was successful).
@@ -78,12 +90,12 @@ type CallToolResultFor[Out any] struct {
7890
IsError bool `json:"isError,omitempty"`
7991
}
8092

81-
func (*CallToolResultFor[Out]) isResult() {}
93+
func (*CallToolResult) isResult() {}
8294

8395
// UnmarshalJSON handles the unmarshalling of content into the Content
8496
// interface.
85-
func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error {
86-
type res CallToolResultFor[Out] // avoid recursion
97+
func (x *CallToolResult) UnmarshalJSON(data []byte) error {
98+
type res CallToolResult // avoid recursion
8799
var wire struct {
88100
res
89101
Content []*wireContent `json:"content"`
@@ -95,13 +107,13 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error {
95107
if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil {
96108
return err
97109
}
98-
*x = CallToolResultFor[Out](wire.res)
110+
*x = CallToolResult(wire.res)
99111
return nil
100112
}
101113

102-
func (x *CallToolParamsFor[Out]) isParams() {}
103-
func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) }
104-
func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) }
114+
func (x *CallToolParams) isParams() {}
115+
func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) }
116+
func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) }
105117

106118
type CancelledParams struct {
107119
// This property is reserved by the protocol to allow clients and servers to
@@ -867,6 +879,8 @@ type Tool struct {
867879
// If not provided, Annotations.Title should be used for display if present,
868880
// otherwise Name.
869881
Title string `json:"title,omitempty"`
882+
883+
newArgs func() any
870884
}
871885

872886
// Additional properties describing a Tool to clients.

mcp/protocol_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ func TestCompleteReference(t *testing.T) {
208208
})
209209
}
210210
}
211+
211212
func TestCompleteParams(t *testing.T) {
212213
// Define test cases specifically for Marshalling
213214
marshalTests := []struct {
@@ -514,13 +515,13 @@ func TestContentUnmarshal(t *testing.T) {
514515
var got CallToolResult
515516
roundtrip(ctr, &got)
516517

517-
ctrf := &CallToolResultFor[int]{
518+
ctrf := &CallToolResult{
518519
Meta: Meta{"m": true},
519520
Content: content,
520521
IsError: true,
521522
StructuredContent: 3,
522523
}
523-
var gotf CallToolResultFor[int]
524+
var gotf CallToolResult
524525
roundtrip(ctrf, &gotf)
525526

526527
pm := &PromptMessage{

mcp/server.go

Lines changed: 80 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,12 @@ import (
1515
"maps"
1616
"net/url"
1717
"path/filepath"
18+
"reflect"
1819
"slices"
1920
"sync"
2021
"time"
2122

23+
"github.com/google/jsonschema-go/jsonschema"
2224
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2325
"github.com/modelcontextprotocol/go-sdk/internal/util"
2426
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
@@ -138,59 +140,100 @@ func (s *Server) RemovePrompts(names ...string) {
138140
func() bool { return s.prompts.remove(names...) })
139141
}
140142

141-
// AddTool adds a [Tool] to the server, or replaces one with the same name.
143+
// AddRawTool adds a [Tool] to the server, or replaces one with the same name.
142144
// The Tool argument must not be modified after this call.
143145
//
144146
// The tool's input schema must be non-nil. For a tool that takes no input,
145147
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
146148
// &jsonschema.Schema{}.
147-
func (s *Server) AddTool(t *Tool, h ToolHandler) {
148-
if t.InputSchema == nil {
149-
// This prevents the tool author from forgetting to write a schema where
150-
// one should be provided. If we papered over this by supplying the empty
151-
// schema, then every input would be validated and the problem wouldn't be
152-
// discovered until runtime, when the LLM sent bad data.
153-
panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name))
154-
}
155-
if err := addToolErr(s, t, h); err != nil {
156-
panic(err)
157-
}
149+
//
150+
// When the handler is invoked as part of a CallTool request, req.Params.Arguments
151+
// will be a json.RawMessage. Unmarshaling the arguments and validating them against the
152+
// input schema are the handler author's responsibility.
153+
func (s *Server) AddRawTool(t *Tool, h RawToolHandler) {
154+
st := &serverTool{tool: t, handler: h}
155+
// Assume there was a change, since add replaces existing tools.
156+
// (It's possible a tool was replaced with an identical one, but not worth checking.)
157+
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
158+
// TODO: Surface notify error here? best not, in case we need to batch.
159+
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
160+
func() bool { s.tools.add(st); return true })
158161
}
159162

160-
// AddTool adds a [Tool] to the server, or replaces one with the same name.
161163
// If the tool's input schema is nil, it is set to the schema inferred from the In
162164
// type parameter, using [jsonschema.For].
163165
// If the tool's output schema is nil and the Out type parameter is not the empty
164166
// interface, then the output schema is set to the schema inferred from Out.
165-
// The Tool argument must not be modified after this call.
166-
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
167-
if err := addToolErr(s, t, h); err != nil {
168-
panic(err)
167+
func RawToolHandlerFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) RawToolHandler {
168+
hh, err := toolForErr(t, h)
169+
if err != nil {
170+
panic(fmt.Sprintf("ToolFor: tool %q: %v", t.Name, err))
169171
}
172+
return hh
170173
}
171174

172-
func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) {
173-
defer util.Wrapf(&err, "adding tool %q", t.Name)
174-
// If the exact same Tool pointer has already been registered under this name,
175-
// avoid rebuilding schemas and re-registering. This prevents duplicate
176-
// registration from causing errors (and unnecessary work).
177-
s.mu.Lock()
178-
if existing, ok := s.tools.get(t.Name); ok && existing.tool == t {
179-
s.mu.Unlock()
180-
return nil
175+
// TODO(v0.3.0): test
176+
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (RawToolHandler, error) {
177+
var err error
178+
inputSchema := t.InputSchema
179+
if inputSchema == nil {
180+
inputSchema, err = jsonschema.For[In](nil)
181+
if err != nil {
182+
return nil, fmt.Errorf("input schema: %w", err)
183+
}
181184
}
182-
s.mu.Unlock()
183-
st, err := newServerTool(t, h)
185+
inputResolved, err := inputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
184186
if err != nil {
185-
return err
187+
return nil, fmt.Errorf("resolving input schema: %w", err)
186188
}
187-
// Assume there was a change, since add replaces existing tools.
188-
// (It's possible a tool was replaced with an identical one, but not worth checking.)
189-
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
190-
// TODO: Surface notify error here? best not, in case we need to batch.
191-
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
192-
func() bool { s.tools.add(st); return true })
193-
return nil
189+
190+
outputSchema := t.OutputSchema
191+
if outputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
192+
outputSchema, err = jsonschema.For[Out](nil)
193+
}
194+
if err != nil {
195+
return nil, fmt.Errorf("output schema: %w", err)
196+
}
197+
outputResolved, err := outputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
198+
if err != nil {
199+
return nil, fmt.Errorf("resolving output schema: %w", err)
200+
}
201+
202+
th := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
203+
// Unmarshal and validate args.
204+
rawArgs := req.Params.Arguments.(json.RawMessage)
205+
var in In
206+
if rawArgs != nil {
207+
if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil {
208+
return nil, err
209+
}
210+
}
211+
212+
// Call typed handler.
213+
res, out, err := h(ctx, req, in)
214+
if err != nil {
215+
return nil, err
216+
}
217+
218+
// TODO(v0.3.0): Validate out.
219+
_ = outputResolved
220+
221+
// TODO: return the serialized JSON in a TextContent block, as per spec?
222+
// https://modelcontextprotocol.io/specification/2025-06-18/server/tools#structured-content
223+
// But people may use res.Content for other things.
224+
if res == nil {
225+
res = &CallToolResult{}
226+
}
227+
res.StructuredContent = out
228+
return res, nil
229+
}
230+
231+
return th, nil
232+
}
233+
234+
// AddTool is a convenience for s.AddRawTool(t, RawToolHandler(t, h)).
235+
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
236+
s.AddRawTool(t, RawToolHandlerFor(t, h))
194237
}
195238

196239
// RemoveTools removes the tools with the given names.
@@ -335,7 +378,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam
335378
})
336379
}
337380

338-
func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParamsFor[json.RawMessage]]) (*CallToolResult, error) {
381+
func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
339382
s.mu.Lock()
340383
st, ok := s.tools.get(req.Params.Name)
341384
s.mu.Unlock()

mcp/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ func TestServerCapabilities(t *testing.T) {
296296
{
297297
name: "With tools",
298298
configureServer: func(s *Server) {
299-
s.AddTool(tool, nil)
299+
s.AddRawTool(tool, nil)
300300
},
301301
wantCapabilities: &ServerCapabilities{
302302
Logging: &LoggingCapabilities{},
@@ -322,7 +322,7 @@ func TestServerCapabilities(t *testing.T) {
322322
s.AddPrompt(&Prompt{Name: "p"}, nil)
323323
s.AddResource(&Resource{URI: "file:///r"}, nil)
324324
s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil)
325-
s.AddTool(tool, nil)
325+
s.AddRawTool(tool, nil)
326326
},
327327
serverOpts: ServerOptions{
328328
SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error {

0 commit comments

Comments
 (0)