Skip to content

Commit 1aaa3ec

Browse files
committed
mcp: new tool API
1 parent 9b6327b commit 1aaa3ec

File tree

12 files changed

+219
-361
lines changed

12 files changed

+219
-361
lines changed

mcp/client_list_test.go

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,14 @@ func TestList(t *testing.T) {
2222
defer serverSession.Close()
2323

2424
t.Run("tools", func(t *testing.T) {
25-
toolA := mcp.NewServerTool("apple", "apple tool", SayHi)
26-
toolB := mcp.NewServerTool("banana", "banana tool", SayHi)
27-
toolC := mcp.NewServerTool("cherry", "cherry tool", SayHi)
28-
tools := []*mcp.ServerTool{toolA, toolB, toolC}
29-
wantTools := []*mcp.Tool{toolA.Tool, toolB.Tool, toolC.Tool}
30-
server.AddTools(tools...)
25+
var wantTools []*mcp.Tool
26+
for _, name := range []string{"apple", "banana", "cherry"} {
27+
wantTools = append(wantTools, &mcp.Tool{Name: name, Description: name + " tool"})
28+
}
29+
30+
for _, t := range wantTools {
31+
mcp.AddTool(server, t, SayHi)
32+
}
3133
t.Run("list", func(t *testing.T) {
3234
res, err := clientSession.ListTools(ctx, nil)
3335
if err != nil {

mcp/cmd_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@ func runServer() {
3131
ctx := context.Background()
3232

3333
server := mcp.NewServer("greeter", "v0.0.1", nil)
34-
server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi))
35-
34+
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
3635
if err := server.Run(ctx, mcp.NewStdioTransport()); err != nil {
3736
log.Fatal(err)
3837
}

mcp/features_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@ func SayHi(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[Say
2727
}
2828

2929
func TestFeatureSetOrder(t *testing.T) {
30-
toolA := NewServerTool("apple", "apple tool", SayHi).Tool
31-
toolB := NewServerTool("banana", "banana tool", SayHi).Tool
32-
toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool
30+
toolA := &Tool{Name: "apple", Description: "apple tool"}
31+
toolB := &Tool{Name: "banana", Description: "banana tool"}
32+
toolC := &Tool{Name: "cherry", Description: "cherry tool"}
3333

3434
testCases := []struct {
3535
tools []*Tool
@@ -52,9 +52,9 @@ func TestFeatureSetOrder(t *testing.T) {
5252
}
5353

5454
func TestFeatureSetAbove(t *testing.T) {
55-
toolA := NewServerTool("apple", "apple tool", SayHi).Tool
56-
toolB := NewServerTool("banana", "banana tool", SayHi).Tool
57-
toolC := NewServerTool("cherry", "cherry tool", SayHi).Tool
55+
toolA := &Tool{Name: "apple", Description: "apple tool"}
56+
toolB := &Tool{Name: "banana", Description: "banana tool"}
57+
toolC := &Tool{Name: "cherry", Description: "cherry tool"}
5858

5959
testCases := []struct {
6060
tools []*Tool

mcp/mcp_test.go

Lines changed: 27 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ import (
2121
"time"
2222

2323
"github.com/google/go-cmp/cmp"
24-
"github.com/google/go-cmp/cmp/cmpopts"
2524
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2625
"github.com/modelcontextprotocol/go-sdk/jsonschema"
2726
)
@@ -30,6 +29,9 @@ type hiParams struct {
3029
Name string
3130
}
3231

32+
// TODO(jba): after schemas are stateless (WIP), this can be a variable.
33+
func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} }
34+
3335
func sayHi(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[hiParams]) (*CallToolResultFor[any], error) {
3436
if err := ss.Ping(ctx, nil); err != nil {
3537
return nil, fmt.Errorf("ping failed: %v", err)
@@ -63,7 +65,14 @@ func TestEndToEnd(t *testing.T) {
6365
},
6466
}
6567
s := NewServer("testServer", "v1.0.0", sopts)
66-
add(tools, s.AddTools, "greet", "fail")
68+
AddTool(s, &Tool{
69+
Name: "greet",
70+
Description: "say hi",
71+
}, sayHi)
72+
s.AddTool(&Tool{Name: "fail", InputSchema: &jsonschema.Schema{}},
73+
func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
74+
return nil, errTestFailure
75+
})
6776
add(prompts, s.AddPrompts, "code_review", "fail")
6877
add(resources, s.AddResources, "info.txt", "fail.txt")
6978

@@ -161,32 +170,7 @@ func TestEndToEnd(t *testing.T) {
161170
})
162171

163172
t.Run("tools", func(t *testing.T) {
164-
res, err := cs.ListTools(ctx, nil)
165-
if err != nil {
166-
t.Errorf("tools/list failed: %v", err)
167-
}
168-
wantTools := []*Tool{
169-
{
170-
Name: "fail",
171-
InputSchema: nil,
172-
},
173-
{
174-
Name: "greet",
175-
Description: "say hi",
176-
InputSchema: &jsonschema.Schema{
177-
Type: "object",
178-
Required: []string{"Name"},
179-
Properties: map[string]*jsonschema.Schema{
180-
"Name": {Type: "string"},
181-
},
182-
AdditionalProperties: falseSchema(),
183-
},
184-
},
185-
}
186-
if diff := cmp.Diff(wantTools, res.Tools, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" {
187-
t.Fatalf("tools/list mismatch (-want +got):\n%s", diff)
188-
}
189-
173+
// ListTools is tested in client_list_test.go.
190174
gotHi, err := cs.CallTool(ctx, &CallToolParams{
191175
Name: "greet",
192176
Arguments: map[string]any{"name": "user"},
@@ -222,7 +206,7 @@ func TestEndToEnd(t *testing.T) {
222206
t.Errorf("tools/call 'fail' mismatch (-want +got):\n%s", diff)
223207
}
224208

225-
s.AddTools(&ServerTool{Tool: &Tool{Name: "T"}, Handler: nopHandler})
209+
s.AddTool(&Tool{Name: "T", InputSchema: &jsonschema.Schema{}}, nopHandler)
226210
waitForNotification(t, "tools")
227211
s.RemoveTools("T")
228212
waitForNotification(t, "tools")
@@ -434,16 +418,6 @@ func TestEndToEnd(t *testing.T) {
434418
var (
435419
errTestFailure = errors.New("mcp failure")
436420

437-
tools = map[string]*ServerTool{
438-
"greet": NewServerTool("greet", "say hi", sayHi),
439-
"fail": {
440-
Tool: &Tool{Name: "fail"},
441-
Handler: func(context.Context, *ServerSession, *CallToolParamsFor[map[string]any]) (*CallToolResult, error) {
442-
return nil, errTestFailure
443-
},
444-
},
445-
}
446-
447421
prompts = map[string]*ServerPrompt{
448422
"code_review": {
449423
Prompt: &Prompt{
@@ -540,21 +514,21 @@ func errorCode(err error) int64 {
540514
return -1
541515
}
542516

543-
// basicConnection returns a new basic client-server connection configured with
544-
// the provided tools.
517+
// basicConnection returns a new basic client-server connection, with the server
518+
// configured via the provided function.
545519
//
546520
// The caller should cancel either the client connection or server connection
547521
// when the connections are no longer needed.
548-
func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *ClientSession) {
522+
func basicConnection(t *testing.T, config func(*Server)) (*ServerSession, *ClientSession) {
549523
t.Helper()
550524

551525
ctx := context.Background()
552526
ct, st := NewInMemoryTransports()
553527

554528
s := NewServer("testServer", "v1.0.0", nil)
555-
556-
// The 'greet' tool says hi.
557-
s.AddTools(tools...)
529+
if config != nil {
530+
config(s)
531+
}
558532
ss, err := s.Connect(ctx, st)
559533
if err != nil {
560534
t.Fatal(err)
@@ -569,7 +543,9 @@ func basicConnection(t *testing.T, tools ...*ServerTool) (*ServerSession, *Clien
569543
}
570544

571545
func TestServerClosing(t *testing.T) {
572-
cc, cs := basicConnection(t, NewServerTool("greet", "say hi", sayHi))
546+
cc, cs := basicConnection(t, func(s *Server) {
547+
AddTool(s, greetTool(), sayHi)
548+
})
573549
defer cs.Close()
574550

575551
ctx := context.Background()
@@ -651,11 +627,9 @@ func TestCancellation(t *testing.T) {
651627
}
652628
return nil, nil
653629
}
654-
st := &ServerTool{
655-
Tool: &Tool{Name: "slow"},
656-
Handler: slowRequest,
657-
}
658-
_, cs := basicConnection(t, st)
630+
_, cs := basicConnection(t, func(s *Server) {
631+
s.AddTool(&Tool{Name: "slow"}, slowRequest)
632+
})
659633
defer cs.Close()
660634

661635
ctx, cancel := context.WithCancel(context.Background())
@@ -852,7 +826,7 @@ func TestKeepAlive(t *testing.T) {
852826
KeepAlive: 100 * time.Millisecond,
853827
}
854828
s := NewServer("testServer", "v1.0.0", serverOpts)
855-
s.AddTools(NewServerTool("greet", "say hi", sayHi))
829+
AddTool(s, greetTool(), sayHi)
856830

857831
ss, err := s.Connect(ctx, st)
858832
if err != nil {
@@ -897,7 +871,7 @@ func TestKeepAliveFailure(t *testing.T) {
897871

898872
// Server without keepalive (to test one-sided keepalive)
899873
s := NewServer("testServer", "v1.0.0", nil)
900-
s.AddTools(NewServerTool("greet", "say hi", sayHi))
874+
AddTool(s, greetTool(), sayHi)
901875
ss, err := s.Connect(ctx, st)
902876
if err != nil {
903877
t.Fatal(err)

mcp/server.go

Lines changed: 35 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"encoding/json"
1313
"fmt"
1414
"iter"
15+
"log"
1516
"net/url"
1617
"path/filepath"
1718
"slices"
@@ -20,7 +21,6 @@ import (
2021

2122
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
2223
"github.com/modelcontextprotocol/go-sdk/internal/util"
23-
"github.com/modelcontextprotocol/go-sdk/jsonschema"
2424
)
2525

2626
const DefaultPageSize = 1000
@@ -37,7 +37,7 @@ type Server struct {
3737

3838
mu sync.Mutex
3939
prompts *featureSet[*ServerPrompt]
40-
tools *featureSet[*ServerTool]
40+
tools *featureSet[*serverTool]
4141
resources *featureSet[*ServerResource]
4242
resourceTemplates *featureSet[*ServerResourceTemplate]
4343
sessions []*ServerSession
@@ -88,7 +88,7 @@ func NewServer(name, version string, opts *ServerOptions) *Server {
8888
version: version,
8989
opts: *opts,
9090
prompts: newFeatureSet(func(p *ServerPrompt) string { return p.Prompt.Name }),
91-
tools: newFeatureSet(func(t *ServerTool) string { return t.Tool.Name }),
91+
tools: newFeatureSet(func(t *serverTool) string { return t.tool.Name }),
9292
resources: newFeatureSet(func(r *ServerResource) string { return r.Resource.URI }),
9393
resourceTemplates: newFeatureSet(func(t *ServerResourceTemplate) string { return t.ResourceTemplate.URITemplate }),
9494
sendingMethodHandler_: defaultSendingMethodHandler[*ServerSession],
@@ -118,55 +118,44 @@ func (s *Server) RemovePrompts(names ...string) {
118118
func() bool { return s.prompts.remove(names...) })
119119
}
120120

121-
// AddTools adds the given tools to the server,
122-
// replacing any with the same names.
123-
// The arguments must not be modified after this call.
124-
//
125-
// AddTools panics if errors are detected.
126-
func (s *Server) AddTools(tools ...*ServerTool) {
127-
if err := s.addToolsErr(tools...); err != nil {
121+
// AddTool adds a [Tool] to the server, or replaces one with the same name.
122+
// The tool's input schema must be non-nil.
123+
// The Tool argument must not be modified after this call.
124+
func (s *Server) AddTool(t *Tool, h ToolHandler) {
125+
// TODO(jba): This is a breaking behavior change. Add before v0.2.0?
126+
if t.InputSchema == nil {
127+
log.Printf("mcp: tool %q has a nil input schema. This will panic in a future release.", t.Name)
128+
// panic(fmt.Sprintf("adding tool %q: nil input schema", t.Name))
129+
}
130+
if err := addToolErr(s, t, h); err != nil {
128131
panic(err)
129132
}
130133
}
131134

132-
// addToolsErr is like [AddTools], but returns an error instead of panicking.
133-
func (s *Server) addToolsErr(tools ...*ServerTool) error {
134-
// Only notify if something could change.
135-
if len(tools) == 0 {
136-
return nil
135+
// AddTool adds a [Tool] to the server, or replaces one with the same name.
136+
// If the tool's input schema is nil, it is set to the schema inferred from the In
137+
// type parameter, using [jsonschema.For].
138+
// If the tool's output schema is nil and the Out type parameter is not the empty
139+
// interface, then the output schema is set to the schema inferred from Out.
140+
// The Tool argument must not be modified after this call.
141+
func AddTool[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) {
142+
if err := addToolErr(s, t, h); err != nil {
143+
panic(err)
137144
}
138-
// Wrap the user's Handlers with rawHandlers that take a json.RawMessage.
139-
for _, st := range tools {
140-
if st.rawHandler == nil {
141-
// This ServerTool was not created with NewServerTool.
142-
if st.Handler == nil {
143-
return fmt.Errorf("AddTools: tool %q has no handler", st.Tool.Name)
144-
}
145-
st.rawHandler = newRawHandler(st)
146-
// Resolve the schemas, with no base URI. We don't expect tool schemas to
147-
// refer outside of themselves.
148-
if st.Tool.InputSchema != nil {
149-
r, err := st.Tool.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
150-
if err != nil {
151-
return err
152-
}
153-
st.inputResolved = r
154-
}
145+
}
155146

156-
// if st.Tool.OutputSchema != nil {
157-
// st.outputResolved, err := st.Tool.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
158-
// if err != nil {
159-
// return err
160-
// }
161-
// }
162-
}
147+
func addToolErr[In, Out any](s *Server, t *Tool, h ToolHandlerFor[In, Out]) (err error) {
148+
defer util.Wrapf(&err, "adding tool %q", t.Name)
149+
st, err := newServerTool(t, h)
150+
if err != nil {
151+
return err
163152
}
164-
165153
// Assume there was a change, since add replaces existing tools.
166154
// (It's possible a tool was replaced with an identical one, but not worth checking.)
167-
// TODO: surface notify error here?
155+
// TODO: Batch these changes by size and time? The typescript SDK doesn't.
156+
// TODO: Surface notify error here? best not, in case we need to batch.
168157
s.changeAndNotify(notificationToolListChanged, &ToolListChangedParams{},
169-
func() bool { s.tools.add(tools...); return true })
158+
func() bool { s.tools.add(st); return true })
170159
return nil
171160
}
172161

@@ -293,22 +282,22 @@ func (s *Server) listTools(_ context.Context, _ *ServerSession, params *ListTool
293282
if params == nil {
294283
params = &ListToolsParams{}
295284
}
296-
return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*ServerTool) {
285+
return paginateList(s.tools, s.opts.PageSize, params, &ListToolsResult{}, func(res *ListToolsResult, tools []*serverTool) {
297286
res.Tools = []*Tool{} // avoid JSON null
298287
for _, t := range tools {
299-
res.Tools = append(res.Tools, t.Tool)
288+
res.Tools = append(res.Tools, t.tool)
300289
}
301290
})
302291
}
303292

304293
func (s *Server) callTool(ctx context.Context, cc *ServerSession, params *CallToolParamsFor[json.RawMessage]) (*CallToolResult, error) {
305294
s.mu.Lock()
306-
tool, ok := s.tools.get(params.Name)
295+
st, ok := s.tools.get(params.Name)
307296
s.mu.Unlock()
308297
if !ok {
309298
return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, params.Name)
310299
}
311-
return tool.rawHandler(ctx, cc, params)
300+
return st.handler(ctx, cc, params)
312301
}
313302

314303
func (s *Server) listResources(_ context.Context, _ *ServerSession, params *ListResourcesParams) (*ListResourcesResult, error) {

mcp/server_example_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func ExampleServer() {
2929
clientTransport, serverTransport := mcp.NewInMemoryTransports()
3030

3131
server := mcp.NewServer("greeter", "v0.0.1", nil)
32-
server.AddTools(mcp.NewServerTool("greet", "say hi", SayHi))
32+
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
3333

3434
serverSession, err := server.Connect(ctx, serverTransport)
3535
if err != nil {

0 commit comments

Comments
 (0)