Skip to content

Commit 84008d2

Browse files
committed
add test
1 parent 5a2bc75 commit 84008d2

File tree

2 files changed

+114
-20
lines changed

2 files changed

+114
-20
lines changed

mcp/server.go

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,26 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) {
203203
func() bool { return s.resourceTemplates.remove(uriTemplates...) })
204204
}
205205

206+
func (s *Server) capabilities() *serverCapabilities {
207+
s.mu.Lock()
208+
defer s.mu.Unlock()
209+
210+
caps := &serverCapabilities{
211+
Completions: &completionCapabilities{},
212+
Logging: &loggingCapabilities{},
213+
}
214+
if s.tools.len() > 0 {
215+
caps.Tools = &toolCapabilities{ListChanged: true}
216+
}
217+
if s.prompts.len() > 0 {
218+
caps.Prompts = &promptCapabilities{ListChanged: true}
219+
}
220+
if s.resources.len() > 0 || s.resourceTemplates.len() > 0 {
221+
caps.Resources = &resourceCapabilities{ListChanged: true}
222+
}
223+
return caps
224+
}
225+
206226
func (s *Server) complete(ctx context.Context, ss *ServerSession, params *CompleteParams) (Result, error) {
207227
if s.opts.CompletionHandler == nil {
208228
return nil, jsonrpc2.ErrMethodNotFound
@@ -645,30 +665,11 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam
645665
version = v
646666
}
647667

648-
caps := &serverCapabilities{
649-
Completions: &completionCapabilities{},
650-
Logging: &loggingCapabilities{},
651-
}
652-
ss.server.mu.Lock()
653-
hasTools := ss.server.tools.len() > 0
654-
hasPrompts := ss.server.prompts.len() > 0
655-
hasResources := ss.server.resources.len() > 0
656-
ss.server.mu.Unlock()
657-
if hasTools {
658-
caps.Tools = &toolCapabilities{ListChanged: true}
659-
}
660-
if hasPrompts {
661-
caps.Prompts = &promptCapabilities{ListChanged: true}
662-
}
663-
if hasResources {
664-
caps.Resources = &resourceCapabilities{ListChanged: true}
665-
}
666-
667668
return &InitializeResult{
668669
// TODO(rfindley): alter behavior when falling back to an older version:
669670
// reject unsupported features.
670671
ProtocolVersion: version,
671-
Capabilities: caps,
672+
Capabilities: ss.server.capabilities(),
672673
Instructions: ss.server.opts.Instructions,
673674
ServerInfo: &implementation{
674675
Name: ss.server.name,

mcp/server_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,3 +227,96 @@ func TestServerPaginateVariousPageSizes(t *testing.T) {
227227
}
228228
}
229229
}
230+
231+
func TestServerCapabilities(t *testing.T) {
232+
// An empty handler that can be used for tools, prompts, and resources.
233+
emptyHandler := func(context.Context, *ServerSession, any) (any, error) {
234+
return &emptyResult{}, nil
235+
}
236+
237+
testCases := []struct {
238+
name string
239+
configureServer func(s *Server)
240+
wantCapabilities *serverCapabilities
241+
}{
242+
{
243+
name: "No capabilities",
244+
configureServer: func(s *Server) {},
245+
wantCapabilities: &serverCapabilities{
246+
Completions: &completionCapabilities{},
247+
Logging: &loggingCapabilities{},
248+
},
249+
},
250+
{
251+
name: "With prompts",
252+
configureServer: func(s *Server) {
253+
s.AddPrompts(NewServerPrompt(&Prompt{Name: "p"}, emptyHandler))
254+
},
255+
wantCapabilities: &serverCapabilities{
256+
Completions: &completionCapabilities{},
257+
Logging: &loggingCapabilities{},
258+
Prompts: &promptCapabilities{ListChanged: true},
259+
},
260+
},
261+
{
262+
name: "With resources",
263+
configureServer: func(s *Server) {
264+
s.AddResources(NewServerResource(&Resource{URI: "file:///r"}, emptyHandler))
265+
},
266+
wantCapabilities: &serverCapabilities{
267+
Completions: &completionCapabilities{},
268+
Logging: &loggingCapabilities{},
269+
Resources: &resourceCapabilities{ListChanged: true},
270+
},
271+
},
272+
{
273+
name: "With resource templates",
274+
configureServer: func(s *Server) {
275+
s.AddResourceTemplates(NewServerResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, emptyHandler))
276+
},
277+
wantCapabilities: &serverCapabilities{
278+
Completions: &completionCapabilities{},
279+
Logging: &loggingCapabilities{},
280+
Resources: &resourceCapabilities{ListChanged: true},
281+
},
282+
},
283+
{
284+
name: "With tools",
285+
configureServer: func(s *Server) {
286+
s.AddTools(NewServerTool(&Tool{Name: "t"}, emptyHandler))
287+
},
288+
wantCapabilities: &serverCapabilities{
289+
Completions: &completionCapabilities{},
290+
Logging: &loggingCapabilities{},
291+
Tools: &toolCapabilities{ListChanged: true},
292+
},
293+
},
294+
{
295+
name: "With all capabilities",
296+
configureServer: func(s *Server) {
297+
s.AddPrompts(NewServerPrompt(&Prompt{Name: "p"}, emptyHandler))
298+
s.AddResources(NewServerResource(&Resource{URI: "file:///r"}, emptyHandler))
299+
s.AddResourceTemplates(NewServerResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, emptyHandler))
300+
s.AddTools(NewServerTool(&Tool{Name: "t"}, emptyHandler))
301+
},
302+
wantCapabilities: &serverCapabilities{
303+
Completions: &completionCapabilities{},
304+
Logging: &loggingCapabilities{},
305+
Prompts: &promptCapabilities{ListChanged: true},
306+
Resources: &resourceCapabilities{ListChanged: true},
307+
Tools: &toolCapabilities{ListChanged: true},
308+
},
309+
},
310+
}
311+
312+
for _, tc := range testCases {
313+
t.Run(tc.name, func(t *testing.T) {
314+
server := NewServer(nil, nil, nil)
315+
tc.configureServer(server)
316+
gotCapabilities := server.capabilities()
317+
if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" {
318+
t.Errorf("capabilities() mismatch (-want +got):\n%s", diff)
319+
}
320+
})
321+
}
322+
}

0 commit comments

Comments
 (0)