diff --git a/mcp/features.go b/mcp/features.go index 1777b33f..43c58854 100644 --- a/mcp/features.go +++ b/mcp/features.go @@ -66,6 +66,9 @@ func (s *featureSet[T]) get(uid string) (T, bool) { return t, ok } +// len returns the number of features in the set. +func (s *featureSet[T]) len() int { return len(s.features) } + // all returns an iterator over of all the features in the set // sorted by unique ID. func (s *featureSet[T]) all() iter.Seq[T] { diff --git a/mcp/server.go b/mcp/server.go index f9b76539..16cc8d6b 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -204,6 +204,26 @@ func (s *Server) RemoveResourceTemplates(uriTemplates ...string) { func() bool { return s.resourceTemplates.remove(uriTemplates...) }) } +func (s *Server) capabilities() *serverCapabilities { + s.mu.Lock() + defer s.mu.Unlock() + + caps := &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + } + if s.tools.len() > 0 { + caps.Tools = &toolCapabilities{ListChanged: true} + } + if s.prompts.len() > 0 { + caps.Prompts = &promptCapabilities{ListChanged: true} + } + if s.resources.len() > 0 || s.resourceTemplates.len() > 0 { + caps.Resources = &resourceCapabilities{ListChanged: true} + } + return caps +} + func (s *Server) complete(ctx context.Context, ss *ServerSession, params *CompleteParams) (Result, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound @@ -407,6 +427,11 @@ func fileResourceHandler(dir string) ResourceHandler { // // Run blocks until the client terminates the connection or the provided // context is cancelled. If the context is cancelled, Run closes the connection. +// +// If tools have been added to the server before this call, then the server will +// advertise the capability for tools, including the ability to send list-changed notifications. +// If no tools have been added, the server will not have the tool capability. +// The same goes for other features like prompts and resources. func (s *Server) Run(ctx context.Context, t Transport) error { ss, err := s.Connect(ctx, t) if err != nil { @@ -659,20 +684,8 @@ func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParam // TODO(rfindley): alter behavior when falling back to an older version: // reject unsupported features. ProtocolVersion: version, - Capabilities: &serverCapabilities{ - Completions: &completionCapabilities{}, - Prompts: &promptCapabilities{ - ListChanged: true, - }, - Tools: &toolCapabilities{ - ListChanged: true, - }, - Resources: &resourceCapabilities{ - ListChanged: true, - }, - Logging: &loggingCapabilities{}, - }, - Instructions: ss.server.opts.Instructions, + Capabilities: ss.server.capabilities(), + Instructions: ss.server.opts.Instructions, ServerInfo: &implementation{ Name: ss.server.name, Version: ss.server.version, diff --git a/mcp/server_test.go b/mcp/server_test.go index 19701f39..cc94003c 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -227,3 +227,91 @@ func TestServerPaginateVariousPageSizes(t *testing.T) { } } } + +func TestServerCapabilities(t *testing.T) { + testCases := []struct { + name string + configureServer func(s *Server) + wantCapabilities *serverCapabilities + }{ + { + name: "No capabilities", + configureServer: func(s *Server) {}, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + }, + }, + { + name: "With prompts", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + }, + }, + { + name: "With resources", + configureServer: func(s *Server) { + s.AddResource(&Resource{URI: "file:///r"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With resource templates", + configureServer: func(s *Server) { + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Resources: &resourceCapabilities{ListChanged: true}, + }, + }, + { + name: "With tools", + configureServer: func(s *Server) { + s.AddTool(&Tool{Name: "t"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, + { + name: "With all capabilities", + configureServer: func(s *Server) { + s.AddPrompt(&Prompt{Name: "p"}, nil) + s.AddResource(&Resource{URI: "file:///r"}, nil) + s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) + s.AddTool(&Tool{Name: "t"}, nil) + }, + wantCapabilities: &serverCapabilities{ + Completions: &completionCapabilities{}, + Logging: &loggingCapabilities{}, + Prompts: &promptCapabilities{ListChanged: true}, + Resources: &resourceCapabilities{ListChanged: true}, + Tools: &toolCapabilities{ListChanged: true}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + server := NewServer("", "", nil) + tc.configureServer(server) + gotCapabilities := server.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 412d2e1d..da9c4285 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -159,8 +159,6 @@ func TestStreamableServerTransport(t *testing.T) { Capabilities: &serverCapabilities{ Completions: &completionCapabilities{}, Logging: &loggingCapabilities{}, - Prompts: &promptCapabilities{ListChanged: true}, - Resources: &resourceCapabilities{ListChanged: true}, Tools: &toolCapabilities{ListChanged: true}, }, ProtocolVersion: latestProtocolVersion,