Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/server/sse/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ type SayHiParams struct {
Name string `json:"name"`
}

func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) {
return &mcp.CallToolResultFor[any]{
func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name},
&mcp.TextContent{Text: "Hi " + args.Name},
},
}, nil
}, nil, nil
}

func main() {
Expand Down
9 changes: 5 additions & 4 deletions mcp/example_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,16 @@ func Example_loggingMiddleware() {
},
func(
ctx context.Context,
req *mcp.ServerRequest[*mcp.CallToolParamsFor[map[string]any]],
) (*mcp.CallToolResultFor[any], error) {
name, ok := req.Params.Arguments["name"].(string)
req *mcp.ServerRequest[*mcp.CallToolParams],
args any,
) (*mcp.CallToolResult, error) {
name, ok := args.(map[string]any)["name"].(string)
if !ok {
return nil, fmt.Errorf("name parameter is required and must be a string")
}

message := fmt.Sprintf("Hello, %s!", name)
return &mcp.CallToolResultFor[any]{
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: message},
},
Expand Down
12 changes: 6 additions & 6 deletions mcp/features_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ type SayHiParams struct {
Name string `json:"name"`
}

func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[SayHiParams]) (*CallToolResultFor[any], error) {
return &CallToolResultFor[any]{
func SayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args SayHiParams) (*CallToolResult, any, error) {
return &CallToolResult{
Content: []Content{
&TextContent{Text: "Hi " + params.Name},
&TextContent{Text: "Hi " + args.Name},
},
}, nil
}, nil, nil
}

func TestFeatureSetOrder(t *testing.T) {
Expand All @@ -45,7 +45,7 @@ func TestFeatureSetOrder(t *testing.T) {
fs := newFeatureSet(func(t *Tool) string { return t.Name })
fs.add(tc.tools...)
got := slices.Collect(fs.all())
if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{}, Tool{})); diff != "" {
t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff)
}
}
Expand All @@ -69,7 +69,7 @@ func TestFeatureSetAbove(t *testing.T) {
fs := newFeatureSet(func(t *Tool) string { return t.Name })
fs.add(tc.tools...)
got := slices.Collect(fs.above(tc.above))
if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
if diff := cmp.Diff(got, tc.want, cmpopts.IgnoreUnexported(jsonschema.Schema{}, Tool{})); diff != "" {
t.Errorf("expected %v, got %v, (-want +got):\n%s", tc.want, got, diff)
}
}
Expand Down
12 changes: 6 additions & 6 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ type hiParams struct {
// TODO(jba): after schemas are stateless (WIP), this can be a variable.
func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} }

func sayHi(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResultFor[any], error) {
func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) {
if err := req.Session.Ping(ctx, nil); err != nil {
return nil, fmt.Errorf("ping failed: %v", err)
return nil, nil, fmt.Errorf("ping failed: %v", err)
}
return &CallToolResultFor[any]{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + args.Name}}}, nil, nil
}

var codeReviewPrompt = &Prompt{
Expand Down Expand Up @@ -97,7 +97,7 @@ func TestEndToEnd(t *testing.T) {
Description: "say hi",
}, sayHi)
s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}},
func(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
func(context.Context, *ServerRequest[*CallToolParams], any) (*CallToolResult, error) {
return nil, errTestFailure
})
s.AddPrompt(codeReviewPrompt, codReviewPromptHandler)
Expand Down Expand Up @@ -647,7 +647,7 @@ func TestCancellation(t *testing.T) {
cancelled = make(chan struct{}, 1) // don't block the request
)

slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
slowRequest := func(ctx context.Context, _ *ServerRequest[*CallToolParams], _ any) (*CallToolResult, error) {
start <- struct{}{}
select {
case <-ctx.Done():
Expand Down Expand Up @@ -836,7 +836,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware {
}
}

func nopHandler(context.Context, *ServerRequest[*CallToolParamsFor[map[string]any]]) (*CallToolResult, error) {
func nopHandler(context.Context, *ServerRequest[*CallToolParams], any) (*CallToolResult, error) {
return nil, nil
}

Expand Down
44 changes: 29 additions & 15 deletions mcp/protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,32 @@ type Annotations struct {
Priority float64 `json:"priority,omitempty"`
}

type CallToolParams = CallToolParamsFor[any]

type CallToolParamsFor[In any] struct {
type CallToolParams struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
Name string `json:"name"`
Arguments In `json:"arguments,omitempty"`
Arguments any `json:"arguments,omitempty"`
}

// The server's response to a tool call.
type CallToolResult = CallToolResultFor[any]
// When unmarshalling CallToolParams on the server side, we need to delay unmarshaling of the arguments.
func (c *CallToolParams) UnmarshalJSON(data []byte) error {
var raw struct {
Meta `json:"_meta,omitempty"`
Name string `json:"name"`
RawArguments json.RawMessage `json:"arguments,omitempty"`
}
if err := json.Unmarshal(data, &raw); err != nil {
return err
}
c.Meta = raw.Meta
c.Name = raw.Name
c.Arguments = raw.RawArguments
return nil
}

type CallToolResultFor[Out any] struct {
// The server's response to a tool call.
type CallToolResult struct {
// This property is reserved by the protocol to allow clients and servers to
// attach additional metadata to their responses.
Meta `json:"_meta,omitempty"`
Expand All @@ -62,7 +74,7 @@ type CallToolResultFor[Out any] struct {
Content []Content `json:"content"`
// An optional JSON object that represents the structured result of the tool
// call.
StructuredContent Out `json:"structuredContent,omitempty"`
StructuredContent any `json:"structuredContent,omitempty"`
// Whether the tool call ended in an error.
//
// If not set, this is assumed to be false (the call was successful).
Expand All @@ -78,12 +90,12 @@ type CallToolResultFor[Out any] struct {
IsError bool `json:"isError,omitempty"`
}

func (*CallToolResultFor[Out]) isResult() {}
func (*CallToolResult) isResult() {}

// UnmarshalJSON handles the unmarshalling of content into the Content
// interface.
func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error {
type res CallToolResultFor[Out] // avoid recursion
func (x *CallToolResult) UnmarshalJSON(data []byte) error {
type res CallToolResult // avoid recursion
var wire struct {
res
Content []*wireContent `json:"content"`
Expand All @@ -95,13 +107,13 @@ func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error {
if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil {
return err
}
*x = CallToolResultFor[Out](wire.res)
*x = CallToolResult(wire.res)
return nil
}

func (x *CallToolParamsFor[Out]) isParams() {}
func (x *CallToolParamsFor[Out]) GetProgressToken() any { return getProgressToken(x) }
func (x *CallToolParamsFor[Out]) SetProgressToken(t any) { setProgressToken(x, t) }
func (x *CallToolParams) isParams() {}
func (x *CallToolParams) GetProgressToken() any { return getProgressToken(x) }
func (x *CallToolParams) SetProgressToken(t any) { setProgressToken(x, t) }

type CancelledParams struct {
// This property is reserved by the protocol to allow clients and servers to
Expand Down Expand Up @@ -867,6 +879,8 @@ type Tool struct {
// If not provided, Annotations.Title should be used for display if present,
// otherwise Name.
Title string `json:"title,omitempty"`

newArgs func() any
}

// Additional properties describing a Tool to clients.
Expand Down
15 changes: 9 additions & 6 deletions mcp/protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ func TestCompleteReference(t *testing.T) {
})
}
}

func TestCompleteParams(t *testing.T) {
// Define test cases specifically for Marshalling
marshalTests := []struct {
Expand Down Expand Up @@ -514,13 +515,15 @@ func TestContentUnmarshal(t *testing.T) {
var got CallToolResult
roundtrip(ctr, &got)

ctrf := &CallToolResultFor[int]{
Meta: Meta{"m": true},
Content: content,
IsError: true,
StructuredContent: 3,
ctrf := &CallToolResult{
Meta: Meta{"m": true},
Content: content,
IsError: true,
// Ints become floats with zero fractional part when unmarshaled.
// The jsoncschema package will validate these against a schema with type "integer".
StructuredContent: float64(3),
}
var gotf CallToolResultFor[int]
var gotf CallToolResult
roundtrip(ctrf, &gotf)

pm := &PromptMessage{
Expand Down
35 changes: 12 additions & 23 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"context"
"encoding/base64"
"encoding/gob"
"encoding/json"
"fmt"
"iter"
"maps"
Expand Down Expand Up @@ -145,16 +144,7 @@ func (s *Server) RemovePrompts(names ...string) {
// or one where any input is valid, set [Tool.InputSchema] to the empty schema,
// &jsonschema.Schema{}.
func (s *Server) AddTool(t *Tool, h ToolHandler) {
if t.InputSchema == nil {
// This prevents the tool author from forgetting to write a schema where
// one should be provided. If we papered over this by supplying the empty
// schema, then every input would be validated and the problem wouldn't be
// discovered until runtime, when the LLM sent bad data.
panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name))
}
if err := addToolErr(s, t, h); err != nil {
panic(err)
}
s.addServerTool(newServerTool(t, h))
}

// AddTool adds a [Tool] to the server, or replaces one with the same name.
Expand All @@ -163,25 +153,24 @@ func (s *Server) AddTool(t *Tool, h ToolHandler) {
// If the tool's output schema is nil and the Out type parameter is not the empty
// interface, then the output schema is set to the schema inferred from Out.
// The Tool argument must not be modified after this call.
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
if err := addToolErr(s, t, h); err != nil {
panic(err)
}
//
// The handler should return the result as the second return value. The first return value,
// a *CallToolResult, may be nil, or its fields other than StructuredContent may be
// populated.
func AddTool[In, Out any](s *Server, t *Tool, h TypedToolHandler[In, Out]) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This still doesn't solve the problem of easily wrapping all tool handlers.

Based on feedback, we need a way to access the underlying tool handler that is created by AddTool.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what the use case is here.
If you know you want to wrap handlers before you add tools, you would call TypedTool separately.
If you want to do so after adding tools, well, we have no server-side way of enumerating tools anyway, nor has anyone ever asked for one.

s.addServerTool(newTypedServerTool(t, h))
}

func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) {
defer util.Wrapf(&err, "adding tool %q", t.Name)
st, err := newServerTool(t, h)
func (s *Server) addServerTool(st *serverTool, err error) {
if err != nil {
return err
panic(fmt.Sprintf("adding tool %q: %v", st.tool.Name, err))
}
// Assume there was a change, since add replaces existing tools.
// (It's possible a tool was replaced with an identical one, but not worth checking.)
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
// TODO: Surface notify error here? best not, in case we need to batch.
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
func() bool { s.tools.add(st); return true })
return nil
}

// RemoveTools removes the tools with the given names.
Expand Down Expand Up @@ -326,7 +315,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam
})
}

func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParamsFor[json.RawMessage]]) (*CallToolResult, error) {
func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) {
s.mu.Lock()
st, ok := s.tools.get(req.Params.Name)
s.mu.Unlock()
Expand Down Expand Up @@ -612,7 +601,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar
return nil, fmt.Errorf("duplicate %q received", notificationInitialized)
}
if h := ss.server.opts.InitializedHandler; h != nil {
h(ctx, serverRequestFor(ss, params))
h(ctx, newServerRequest(ss, params))
}
return nil, nil
}
Expand All @@ -626,7 +615,7 @@ func (s *Server) callRootsListChangedHandler(ctx context.Context, req *ServerReq

func (ss *ServerSession) callProgressNotificationHandler(ctx context.Context, p *ProgressNotificationParams) (Result, error) {
if h := ss.server.opts.ProgressNotificationHandler; h != nil {
h(ctx, serverRequestFor(ss, p))
h(ctx, newServerRequest(ss, p))
}
return nil, nil
}
Expand Down
8 changes: 4 additions & 4 deletions mcp/server_example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ type SayHiParams struct {
Name string `json:"name"`
}

func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParamsFor[SayHiParams]]) (*mcp.CallToolResultFor[any], error) {
return &mcp.CallToolResultFor[any]{
func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) {
return &mcp.CallToolResult{
Content: []mcp.Content{
&mcp.TextContent{Text: "Hi " + req.Params.Arguments.Name},
&mcp.TextContent{Text: "Hi " + args.Name},
},
}, nil
}, nil, nil
}

func ExampleServer() {
Expand Down
4 changes: 0 additions & 4 deletions mcp/shared.go
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,6 @@ func (r *ServerRequest[P]) GetSession() Session { return r.Session }
func (r *ClientRequest[P]) GetParams() Params { return r.Params }
func (r *ServerRequest[P]) GetParams() Params { return r.Params }

func serverRequestFor[P Params](s *ServerSession, p P) *ServerRequest[P] {
return &ServerRequest[P]{Session: s, Params: p}
}

func clientRequestFor[P Params](s *ClientSession, p P) *ClientRequest[P] {
return &ClientRequest[P]{Session: s, Params: p}
}
Expand Down
Loading
Loading