Skip to content

Commit 7bfde44

Browse files
authored
mcp: statically type the server-side tool params (#378)
Introduce CallToolParamsRaw on the server side, so that tool handlers see in the type system that a tool's arguments are a json.RawMessage. Fixes #377.
1 parent 8f11a86 commit 7bfde44

File tree

4 files changed

+17
-22
lines changed

4 files changed

+17
-22
lines changed

mcp/mcp_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,6 @@ func TestEndToEnd(t *testing.T) {
494494
if result.Action != "accept" {
495495
t.Errorf("got action %q, want %q", result.Action, "accept")
496496
}
497-
498497
})
499498

500499
// Disconnect.
@@ -1638,7 +1637,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16381637
//
16391638
// We handle a few different types of results, to assert they behave the
16401639
// same in all cases.
1641-
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) {
1640+
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *CallToolRequest, in *input) (*CallToolResult, *output, error) {
16421641
switch in.In {
16431642
case "":
16441643
return nil, nil, fmt.Errorf("must provide input")
@@ -1652,7 +1651,7 @@ func TestPointerArgEquivalence(t *testing.T) {
16521651
panic("unreachable")
16531652
}
16541653
})
1655-
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) {
1654+
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *CallToolRequest, in input) (*CallToolResult, output, error) {
16561655
switch in.In {
16571656
case "":
16581657
return nil, output{}, fmt.Errorf("must provide input")

mcp/protocol.go

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type Annotations struct {
4040
Priority float64 `json:"priority,omitempty"`
4141
}
4242

43+
// CallToolParams is used by clients to call a tool.
4344
type CallToolParams struct {
4445
// This property is reserved by the protocol to allow clients and servers to
4546
// attach additional metadata to their responses.
@@ -48,20 +49,13 @@ type CallToolParams struct {
4849
Arguments any `json:"arguments,omitempty"`
4950
}
5051

51-
// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments.
52-
func (c *CallToolParams) UnmarshalJSON(data []byte) error {
53-
var raw struct {
54-
Meta `json:"_meta,omitempty"`
55-
Name string `json:"name"`
56-
RawArguments json.RawMessage `json:"arguments,omitempty"`
57-
}
58-
if err := json.Unmarshal(data, &raw); err != nil {
59-
return err
60-
}
61-
c.Meta = raw.Meta
62-
c.Name = raw.Name
63-
c.Arguments = raw.RawArguments
64-
return nil
52+
// CallToolParamsRaw is passed to tool handlers on the server.
53+
type CallToolParamsRaw struct {
54+
// This property is reserved by the protocol to allow clients and servers to
55+
// attach additional metadata to their responses.
56+
Meta `json:"_meta,omitempty"`
57+
Name string `json:"name"`
58+
Arguments json.RawMessage `json:"arguments,omitempty"`
6559
}
6660

6761
// The server's response to a tool call.
@@ -115,6 +109,10 @@ func (x *CallToolParams) isParams() {}
115109
func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) }
116110
func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) }
117111

112+
func (x *CallToolParamsRaw) isParams() {}
113+
func (x *CallToolParamsRaw) GetProgressToken() any { return getProgressToken(x) }
114+
func (x *CallToolParamsRaw) SetProgressToken(t any) { setProgressToken(x, t) }
115+
118116
type CancelledParams struct {
119117
// This property is reserved by the protocol to allow clients and servers to
120118
// attach additional metadata to their responses.

mcp/requests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
package mcp
88

99
type (
10-
CallToolRequest = ServerRequest[*CallToolParams]
10+
CallToolRequest = ServerRequest[*CallToolParamsRaw]
1111
CompleteRequest = ServerRequest[*CompleteParams]
1212
GetPromptRequest = ServerRequest[*GetPromptParams]
1313
InitializedRequest = ServerRequest[*InitializedParams]

mcp/server.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"context"
1010
"encoding/base64"
1111
"encoding/gob"
12-
"encoding/json"
1312
"fmt"
1413
"iter"
1514
"maps"
@@ -234,10 +233,9 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
234233

235234
th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) {
236235
// Unmarshal and validate args.
237-
rawArgs := req.Params.Arguments.(json.RawMessage)
238236
var in In
239-
if rawArgs != nil {
240-
if err := unmarshalSchema(rawArgs, inputResolved, &in); err != nil {
237+
if req.Params.Arguments != nil {
238+
if err := unmarshalSchema(req.Params.Arguments, inputResolved, &in); err != nil {
241239
return nil, err
242240
}
243241
}

0 commit comments

Comments
 (0)