From 58fc9cf4afdd27d344c83db99c090d9bc05846af Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 07:39:07 -0400 Subject: [PATCH 1/2] all: remove ServerRequest[T] for concrete T Replace all occurrences of ServerRequest[*CallToolParams] and other concrete instantiations with CallToolRequest and the like. Make the XXXRequest types aliases, to preserve the convenience of generics for the internal machinery (see shared.go, for example.) TODO: either expand the aliases or unexport ServerRequest. (The latter will be problematic for docs.) --- examples/server/completion/main.go | 2 +- examples/server/custom-transport/main.go | 2 +- examples/server/hello/main.go | 4 +- examples/server/memory/kb.go | 18 ++++----- examples/server/sequentialthinking/main.go | 8 ++-- .../server/sequentialthinking/main_test.go | 22 ++++------- examples/server/sse/main.go | 2 +- internal/readme/server/server.go | 2 +- mcp/example_middleware_test.go | 2 +- mcp/mcp_test.go | 28 +++++++------- mcp/requests.go | 24 ++++++++++++ mcp/resource.go | 2 +- mcp/server.go | 38 +++++++++---------- mcp/server_example_test.go | 2 +- mcp/server_test.go | 12 +++--- mcp/shared_test.go | 2 +- mcp/sse_example_test.go | 2 +- mcp/streamable_test.go | 14 +++---- mcp/tool.go | 4 +- mcp/tool_test.go | 4 +- 20 files changed, 106 insertions(+), 88 deletions(-) create mode 100644 mcp/requests.go diff --git a/examples/server/completion/main.go b/examples/server/completion/main.go index f05b2721..b0a991fd 100644 --- a/examples/server/completion/main.go +++ b/examples/server/completion/main.go @@ -16,7 +16,7 @@ import ( // a CompletionHandler to an MCP Server's options. func main() { // Define your custom CompletionHandler logic. - myCompletionHandler := func(_ context.Context, req *mcp.ServerRequest[*mcp.CompleteParams]) (*mcp.CompleteResult, error) { + myCompletionHandler := func(_ context.Context, req *mcp.CompleteRequest) (*mcp.CompleteResult, error) { // In a real application, you'd implement actual completion logic here. // For this example, we return a fixed set of suggestions. var suggestions []string diff --git a/examples/server/custom-transport/main.go b/examples/server/custom-transport/main.go index 72cfc31d..c367cb62 100644 --- a/examples/server/custom-transport/main.go +++ b/examples/server/custom-transport/main.go @@ -85,7 +85,7 @@ type HiArgs struct { } // SayHi is a tool handler that responds with a greeting. -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, struct{}, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, struct{}, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/examples/server/hello/main.go b/examples/server/hello/main.go index f71b0a78..04c0e0b4 100644 --- a/examples/server/hello/main.go +++ b/examples/server/hello/main.go @@ -22,7 +22,7 @@ type HiArgs struct { Name string `json:"name" jsonschema:"the name to say hi to"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiArgs) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiArgs) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, @@ -69,7 +69,7 @@ var embeddedResources = map[string]string{ "info": "This is the hello example server.", } -func handleEmbeddedResource(_ context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { +func handleEmbeddedResource(_ context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { u, err := url.Parse(req.Params.URI) if err != nil { return nil, err diff --git a/examples/server/memory/kb.go b/examples/server/memory/kb.go index e28a4909..c6a59ec0 100644 --- a/examples/server/memory/kb.go +++ b/examples/server/memory/kb.go @@ -431,7 +431,7 @@ func (k knowledgeBase) openNodes(names []string) (KnowledgeGraph, error) { }, nil } -func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { +func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.CallToolRequest, args CreateEntitiesArgs) (*mcp.CallToolResult, CreateEntitiesResult, error) { var res mcp.CallToolResult entities, err := k.createEntities(args.Entities) @@ -450,7 +450,7 @@ func (k knowledgeBase) CreateEntities(ctx context.Context, req *mcp.ServerReques return &res, CreateEntitiesResult{Entities: entities}, nil } -func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { +func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.CallToolRequest, args CreateRelationsArgs) (*mcp.CallToolResult, CreateRelationsResult, error) { var res mcp.CallToolResult relations, err := k.createRelations(args.Relations) @@ -465,7 +465,7 @@ func (k knowledgeBase) CreateRelations(ctx context.Context, req *mcp.ServerReque return &res, CreateRelationsResult{Relations: relations}, nil } -func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { +func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.CallToolRequest, args AddObservationsArgs) (*mcp.CallToolResult, AddObservationsResult, error) { var res mcp.CallToolResult observations, err := k.addObservations(args.Observations) @@ -482,7 +482,7 @@ func (k knowledgeBase) AddObservations(ctx context.Context, req *mcp.ServerReque }, nil } -func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { +func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.CallToolRequest, args DeleteEntitiesArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteEntities(args.EntityNames) @@ -497,7 +497,7 @@ func (k knowledgeBase) DeleteEntities(ctx context.Context, req *mcp.ServerReques return &res, nil, nil } -func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { +func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.CallToolRequest, args DeleteObservationsArgs) (*mcp.CallToolResult, any, error) { var res mcp.CallToolResult err := k.deleteObservations(args.Deletions) @@ -512,7 +512,7 @@ func (k knowledgeBase) DeleteObservations(ctx context.Context, req *mcp.ServerRe return &res, nil, nil } -func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { +func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.CallToolRequest, args DeleteRelationsArgs) (*mcp.CallToolResult, struct{}, error) { var res mcp.CallToolResult err := k.deleteRelations(args.Relations) @@ -527,7 +527,7 @@ func (k knowledgeBase) DeleteRelations(ctx context.Context, req *mcp.ServerReque return &res, struct{}{}, nil } -func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args any) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.CallToolRequest, args any) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.loadGraph() @@ -542,7 +542,7 @@ func (k knowledgeBase) ReadGraph(ctx context.Context, req *mcp.ServerRequest[*mc return &res, graph, nil } -func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.CallToolRequest, args SearchNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.searchNodes(args.Query) @@ -557,7 +557,7 @@ func (k knowledgeBase) SearchNodes(ctx context.Context, req *mcp.ServerRequest[* return &res, graph, nil } -func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { +func (k knowledgeBase) OpenNodes(ctx context.Context, req *mcp.CallToolRequest, args OpenNodesArgs) (*mcp.CallToolResult, KnowledgeGraph, error) { var res mcp.CallToolResult graph, err := k.openNodes(args.Names) diff --git a/examples/server/sequentialthinking/main.go b/examples/server/sequentialthinking/main.go index 100e1167..e0ae5219 100644 --- a/examples/server/sequentialthinking/main.go +++ b/examples/server/sequentialthinking/main.go @@ -231,7 +231,7 @@ func deepCopyThoughts(thoughts []*Thought) []*Thought { } // StartThinking begins a new sequential thinking session for a complex problem. -func StartThinking(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams], args StartThinkingArgs) (*mcp.CallToolResult, any, error) { +func StartThinking(ctx context.Context, _ *mcp.CallToolRequest, args StartThinkingArgs) (*mcp.CallToolResult, any, error) { sessionID := args.SessionID if sessionID == "" { sessionID = randText() @@ -264,7 +264,7 @@ func StartThinking(ctx context.Context, _ *mcp.ServerRequest[*mcp.CallToolParams } // ContinueThinking adds the next thought step, revises a previous step, or creates a branch in the thinking process. -func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { +func ContinueThinking(ctx context.Context, req *mcp.CallToolRequest, args ContinueThinkingArgs) (*mcp.CallToolResult, any, error) { // Handle revision of existing thought if args.ReviseStep != nil { err := store.CompareAndSwap(args.SessionID, func(session *ThinkingSession) (*ThinkingSession, error) { @@ -391,7 +391,7 @@ func ContinueThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolP } // ReviewThinking provides a complete review of the thinking process for a session. -func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { +func ReviewThinking(ctx context.Context, req *mcp.CallToolRequest, args ReviewThinkingArgs) (*mcp.CallToolResult, any, error) { // Get a snapshot of the session to avoid race conditions sessionSnapshot, exists := store.SessionSnapshot(args.SessionID) if !exists { @@ -428,7 +428,7 @@ func ReviewThinking(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolPar } // ThinkingHistory handles resource requests for thinking session data and history. -func ThinkingHistory(ctx context.Context, req *mcp.ServerRequest[*mcp.ReadResourceParams]) (*mcp.ReadResourceResult, error) { +func ThinkingHistory(ctx context.Context, req *mcp.ReadResourceRequest) (*mcp.ReadResourceResult, error) { // Extract session ID from URI (e.g., "thinking://session_123") u, err := url.Parse(req.Params.URI) if err != nil { diff --git a/examples/server/sequentialthinking/main_test.go b/examples/server/sequentialthinking/main_test.go index 8889db7d..2655114c 100644 --- a/examples/server/sequentialthinking/main_test.go +++ b/examples/server/sequentialthinking/main_test.go @@ -387,11 +387,11 @@ func TestThinkingHistory(t *testing.T) { ctx := context.Background() // Test listing all sessions - listParams := &mcp.ReadResourceParams{ - URI: "thinking://sessions", - } - - result, err := ThinkingHistory(ctx, requestFor(listParams)) + result, err := ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{ + URI: "thinking://sessions", + }, + }) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -417,11 +417,9 @@ func TestThinkingHistory(t *testing.T) { } // Test getting specific session - sessionParams := &mcp.ReadResourceParams{ - URI: "thinking://session1", - } - - result, err = ThinkingHistory(ctx, requestFor(sessionParams)) + result, err = ThinkingHistory(ctx, &mcp.ReadResourceRequest{ + Params: &mcp.ReadResourceParams{URI: "thinking://session1"}, + }) if err != nil { t.Fatalf("ThinkingHistory() error = %v", err) } @@ -491,7 +489,3 @@ func TestInvalidOperations(t *testing.T) { t.Error("Expected error for invalid revision step") } } - -func requestFor[P mcp.Params](p P) *mcp.ServerRequest[P] { - return &mcp.ServerRequest[P]{Params: p} -} diff --git a/examples/server/sse/main.go b/examples/server/sse/main.go index c2603b41..27f9caed 100644 --- a/examples/server/sse/main.go +++ b/examples/server/sse/main.go @@ -24,7 +24,7 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/internal/readme/server/server.go b/internal/readme/server/server.go index 087992e8..aff5fcd0 100644 --- a/internal/readme/server/server.go +++ b/internal/readme/server/server.go @@ -16,7 +16,7 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, }, nil, nil diff --git a/mcp/example_middleware_test.go b/mcp/example_middleware_test.go index 0f6d540e..10dda0fa 100644 --- a/mcp/example_middleware_test.go +++ b/mcp/example_middleware_test.go @@ -89,7 +89,7 @@ func Example_loggingMiddleware() { }, func( ctx context.Context, - req *mcp.ServerRequest[*mcp.CallToolParams], args map[string]any, + req *mcp.CallToolRequest, args map[string]any, ) (*mcp.CallToolResult, any, error) { name, ok := args["name"].(string) if !ok { diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 44dd76d2..9f73f2ca 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -33,7 +33,7 @@ type hiParams struct { // TODO(jba): after schemas are stateless (WIP), this can be a variable. func greetTool() *Tool { return &Tool{Name: "greet", Description: "say hi"} } -func sayHi(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { +func sayHi(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err != nil { return nil, nil, fmt.Errorf("ping failed: %v", err) } @@ -74,20 +74,20 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *ServerRequest[*InitializedParams]) { + InitializedHandler: func(context.Context, *InitializedRequest) { notificationChans["initialized"] <- 0 }, - RootsListChangedHandler: func(context.Context, *ServerRequest[*RootsListChangedParams]) { + RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ServerRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) { notificationChans["progress_server"] <- 0 }, - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { notificationChans["subscribe"] <- 0 return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { notificationChans["unsubscribe"] <- 0 return nil }, @@ -98,7 +98,7 @@ func TestEndToEnd(t *testing.T) { Description: "say hi", }, sayHi) AddTool(s, &Tool{Name: "fail", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return nil, nil, errTestFailure }) s.AddPrompt(codeReviewPrompt, codReviewPromptHandler) @@ -511,7 +511,7 @@ var embeddedResources = map[string]string{ "info": "This is the MCP test server.", } -func handleEmbeddedResource(_ context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { +func handleEmbeddedResource(_ context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { u, err := url.Parse(req.Params.URI) if err != nil { return nil, err @@ -663,7 +663,7 @@ func TestCancellation(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - slowRequest := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + slowRequest := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -852,7 +852,7 @@ func traceCalls[S Session](w io.Writer, prefix string) Middleware { } } -func nopHandler(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) { +func nopHandler(context.Context, *CallToolRequest) (*CallToolResult, error) { return nil, nil } @@ -1009,13 +1009,13 @@ func TestSynchronousNotifications(t *testing.T) { var rootsChanged atomic.Bool serverOpts := &ServerOptions{ - RootsListChangedHandler: func(_ context.Context, req *ServerRequest[*RootsListChangedParams]) { + RootsListChangedHandler: func(_ context.Context, req *RootsListChangedRequest) { rootsChanged.Store(true) }, } server := NewServer(testImpl, serverOpts) cs, ss := basicClientServerConnection(t, client, server, func(s *Server) { - AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { if !rootsChanged.Load() { return nil, nil, fmt.Errorf("didn't get root change notification") } @@ -1064,11 +1064,11 @@ func TestNoDistributedDeadlock(t *testing.T) { } client := NewClient(testImpl, clientOpts) cs, _ := basicClientServerConnection(t, client, nil, func(s *Server) { - AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool1"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { req.Session.CreateMessage(ctx, new(CreateMessageParams)) return new(CallToolResult), nil, nil }) - AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(s, &Tool{Name: "tool2"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { req.Session.Ping(ctx, nil) return new(CallToolResult), nil, nil }) diff --git a/mcp/requests.go b/mcp/requests.go new file mode 100644 index 00000000..ceed5026 --- /dev/null +++ b/mcp/requests.go @@ -0,0 +1,24 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// This file holds the request types. + +package mcp + +// TODO: expand the aliases +type ( + CallToolRequest = ServerRequest[*CallToolParams] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) diff --git a/mcp/resource.go b/mcp/resource.go index 5445715b..0658c661 100644 --- a/mcp/resource.go +++ b/mcp/resource.go @@ -35,7 +35,7 @@ type serverResourceTemplate struct { // A ResourceHandler is a function that reads a resource. // It will be called when the client calls [ClientSession.ReadResource]. // If it cannot find the resource, it should return the result of calling [ResourceNotFoundError]. -type ResourceHandler func(context.Context, *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) +type ResourceHandler func(context.Context, *ReadResourceRequest) (*ReadResourceResult, error) // ResourceNotFoundError returns an error indicating that a resource being read could // not be found. diff --git a/mcp/server.go b/mcp/server.go index b8e72907..4c13ca09 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -54,24 +54,24 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *ServerRequest[*InitializedParams]) + InitializedHandler func(context.Context, *InitializedRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. - RootsListChangedHandler func(context.Context, *ServerRequest[*RootsListChangedParams]) + RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ServerRequest[*ProgressNotificationParams]) + ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest) // If non-nil, called when "completion/complete" is received. - CompletionHandler func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) + CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. KeepAlive time.Duration // Function called when a client session subscribes to a resource. - SubscribeHandler func(context.Context, *ServerRequest[*SubscribeParams]) error + SubscribeHandler func(context.Context, *SubscribeRequest) error // Function called when a client session unsubscribes from a resource. - UnsubscribeHandler func(context.Context, *ServerRequest[*UnsubscribeParams]) error + UnsubscribeHandler func(context.Context, *UnsubscribeRequest) error // If true, advertises the prompts capability during initialization, // even if no prompts have been registered. HasPrompts bool @@ -217,7 +217,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan } } - th := func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + th := func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { // Unmarshal and validate args. rawArgs := req.Params.Arguments.(json.RawMessage) var in In @@ -358,7 +358,7 @@ func (s *Server) capabilities() *ServerCapabilities { return caps } -func (s *Server) complete(ctx context.Context, req *ServerRequest[*CompleteParams]) (Result, error) { +func (s *Server) complete(ctx context.Context, req *CompleteRequest) (Result, error) { if s.opts.CompletionHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } @@ -387,7 +387,7 @@ func (s *Server) Sessions() iter.Seq[*ServerSession] { return slices.Values(clients) } -func (s *Server) listPrompts(_ context.Context, req *ServerRequest[*ListPromptsParams]) (*ListPromptsResult, error) { +func (s *Server) listPrompts(_ context.Context, req *ListPromptsRequest) (*ListPromptsResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -401,7 +401,7 @@ func (s *Server) listPrompts(_ context.Context, req *ServerRequest[*ListPromptsP }) } -func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptParams]) (*GetPromptResult, error) { +func (s *Server) getPrompt(ctx context.Context, req *GetPromptRequest) (*GetPromptResult, error) { s.mu.Lock() prompt, ok := s.prompts.get(req.Params.Name) s.mu.Unlock() @@ -415,7 +415,7 @@ func (s *Server) getPrompt(ctx context.Context, req *ServerRequest[*GetPromptPar return prompt.handler(ctx, req.Session, req.Params) } -func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParams]) (*ListToolsResult, error) { +func (s *Server) listTools(_ context.Context, req *ListToolsRequest) (*ListToolsResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -429,7 +429,7 @@ func (s *Server) listTools(_ context.Context, req *ServerRequest[*ListToolsParam }) } -func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { +func (s *Server) callTool(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { s.mu.Lock() st, ok := s.tools.get(req.Params.Name) s.mu.Unlock() @@ -444,7 +444,7 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam return st.handler(ctx, req) } -func (s *Server) listResources(_ context.Context, req *ServerRequest[*ListResourcesParams]) (*ListResourcesResult, error) { +func (s *Server) listResources(_ context.Context, req *ListResourcesRequest) (*ListResourcesResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -458,7 +458,7 @@ func (s *Server) listResources(_ context.Context, req *ServerRequest[*ListResour }) } -func (s *Server) listResourceTemplates(_ context.Context, req *ServerRequest[*ListResourceTemplatesParams]) (*ListResourceTemplatesResult, error) { +func (s *Server) listResourceTemplates(_ context.Context, req *ListResourceTemplatesRequest) (*ListResourceTemplatesResult, error) { s.mu.Lock() defer s.mu.Unlock() if req.Params == nil { @@ -473,7 +473,7 @@ func (s *Server) listResourceTemplates(_ context.Context, req *ServerRequest[*Li }) } -func (s *Server) readResource(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (*ReadResourceResult, error) { +func (s *Server) readResource(ctx context.Context, req *ReadResourceRequest) (*ReadResourceResult, error) { uri := req.Params.URI // Look up the resource URI in the lists of resources and resource templates. // This is a security check as well as an information lookup. @@ -538,7 +538,7 @@ func fileResourceHandler(dir string) ResourceHandler { if err != nil { panic(err) } - return func(ctx context.Context, req *ServerRequest[*ReadResourceParams]) (_ *ReadResourceResult, err error) { + return func(ctx context.Context, req *ReadResourceRequest) (_ *ReadResourceResult, err error) { defer util.Wrapf(&err, "reading resource %s", req.Params.URI) // TODO(#25): use a memoizing API here. @@ -573,7 +573,7 @@ func (s *Server) ResourceUpdated(ctx context.Context, params *ResourceUpdatedNot return nil } -func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribeParams]) (*emptyResult, error) { +func (s *Server) subscribe(ctx context.Context, req *SubscribeRequest) (*emptyResult, error) { if s.opts.SubscribeHandler == nil { return nil, fmt.Errorf("%w: server does not support resource subscriptions", jsonrpc2.ErrMethodNotFound) } @@ -591,7 +591,7 @@ func (s *Server) subscribe(ctx context.Context, req *ServerRequest[*SubscribePar return &emptyResult{}, nil } -func (s *Server) unsubscribe(ctx context.Context, req *ServerRequest[*UnsubscribeParams]) (*emptyResult, error) { +func (s *Server) unsubscribe(ctx context.Context, req *UnsubscribeRequest) (*emptyResult, error) { if s.opts.UnsubscribeHandler == nil { return nil, jsonrpc2.ErrMethodNotFound } @@ -725,7 +725,7 @@ func (ss *ServerSession) initialized(ctx context.Context, params *InitializedPar return nil, nil } -func (s *Server) callRootsListChangedHandler(ctx context.Context, req *ServerRequest[*RootsListChangedParams]) (Result, error) { +func (s *Server) callRootsListChangedHandler(ctx context.Context, req *RootsListChangedRequest) (Result, error) { if h := s.opts.RootsListChangedHandler; h != nil { h(ctx, req) } diff --git a/mcp/server_example_test.go b/mcp/server_example_test.go index 2b4a0bf1..e68dc308 100644 --- a/mcp/server_example_test.go +++ b/mcp/server_example_test.go @@ -16,7 +16,7 @@ type SayHiParams struct { Name string `json:"name"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args SayHiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args SayHiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: "Hi " + args.Name}, diff --git a/mcp/server_test.go b/mcp/server_test.go index 39a4cdb4..1ed4c3cc 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -282,10 +282,10 @@ func TestServerCapabilities(t *testing.T) { s.AddResourceTemplate(&ResourceTemplate{URITemplate: "file:///rt"}, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { return nil }, }, @@ -308,7 +308,7 @@ func TestServerCapabilities(t *testing.T) { name: "With completions", configureServer: func(s *Server) {}, serverOpts: ServerOptions{ - CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { return nil, nil }, }, @@ -326,13 +326,13 @@ func TestServerCapabilities(t *testing.T) { s.AddTool(tool, nil) }, serverOpts: ServerOptions{ - SubscribeHandler: func(context.Context, *ServerRequest[*SubscribeParams]) error { + SubscribeHandler: func(context.Context, *SubscribeRequest) error { return nil }, - UnsubscribeHandler: func(context.Context, *ServerRequest[*UnsubscribeParams]) error { + UnsubscribeHandler: func(context.Context, *UnsubscribeRequest) error { return nil }, - CompletionHandler: func(context.Context, *ServerRequest[*CompleteParams]) (*CompleteResult, error) { + CompletionHandler: func(context.Context, *CompleteRequest) (*CompleteResult, error) { return nil, nil }, }, diff --git a/mcp/shared_test.go b/mcp/shared_test.go index 4d0859ac..23818f87 100644 --- a/mcp/shared_test.go +++ b/mcp/shared_test.go @@ -15,7 +15,7 @@ package mcp // P *int `json:",omitempty"` // } -// dummyHandler := func(context.Context, *ServerRequest[*CallToolParams], req) (*CallToolResultFor[any], error) { +// dummyHandler := func(context.Context, *CallToolRequest, req) (*CallToolResultFor[any], error) { // return nil, nil // } diff --git a/mcp/sse_example_test.go b/mcp/sse_example_test.go index 93ccf788..7d777114 100644 --- a/mcp/sse_example_test.go +++ b/mcp/sse_example_test.go @@ -18,7 +18,7 @@ type AddParams struct { X, Y int } -func Add(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args AddParams) (*mcp.CallToolResult, any, error) { +func Add(ctx context.Context, req *mcp.CallToolRequest, args AddParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: fmt.Sprintf("%d", args.X+args.Y)}, diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 0e9cf455..5cd04eca 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -49,7 +49,7 @@ func TestStreamableTransports(t *testing.T) { start = make(chan struct{}) cancelled = make(chan struct{}, 1) // don't block the request ) - hang := func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + hang := func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { start <- struct{}{} select { case <-ctx.Done(): @@ -60,7 +60,7 @@ func TestStreamableTransports(t *testing.T) { return nil, nil, nil } AddTool(server, &Tool{Name: "hang"}, hang) - AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *ServerRequest[*CallToolParams], args any) (*CallToolResult, any, error) { + AddTool(server, &Tool{Name: "sample"}, func(ctx context.Context, req *CallToolRequest, args any) (*CallToolResult, any, error) { // Test that we can make sampling requests during tool handling. // // Try this on both the request context and a background context, so @@ -220,7 +220,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { serverReadyToKillProxy := make(chan struct{}) serverClosed := make(chan struct{}) AddTool(server, &Tool{Name: "multiMessageTool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { // Send one message to the request context, and another to a background // context (which will end up on the hanging GET). @@ -354,7 +354,7 @@ func TestServerInitiatedSSE(t *testing.T) { } defer clientSession.Close() AddTool(server, &Tool{Name: "testTool", InputSchema: &jsonschema.Schema{}}, - func(context.Context, *ServerRequest[*CallToolParams], map[string]any) (*CallToolResult, any, error) { + func(context.Context, *CallToolRequest, map[string]any) (*CallToolResult, any, error) { return &CallToolResult{}, nil, nil }) receivedNotifications := readNotifications(t, ctx, notifications, 1) @@ -659,7 +659,7 @@ func TestStreamableServerTransport(t *testing.T) { server := NewServer(&Implementation{Name: "testServer", Version: "v1.0.0"}, nil) server.AddTool( &Tool{Name: "tool", InputSchema: &jsonschema.Schema{}}, - func(ctx context.Context, req *ServerRequest[*CallToolParams]) (*CallToolResult, error) { + func(ctx context.Context, req *CallToolRequest) (*CallToolResult, error) { if test.tool != nil { test.tool(t, ctx, req.Session) } @@ -1072,7 +1072,7 @@ func TestEventID(t *testing.T) { func TestStreamableStateless(t *testing.T) { // This version of sayHi expects // that request from our client). - sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParams], args hiParams) (*CallToolResult, any, error) { + sayHi := func(ctx context.Context, req *CallToolRequest, args hiParams) (*CallToolResult, any, error) { if err := req.Session.Ping(ctx, nil); err == nil { // ping should fail, but not break the connection t.Errorf("ping succeeded unexpectedly") @@ -1179,7 +1179,7 @@ func TestTokenInfo(t *testing.T) { ctx := context.Background() // Create a server with a tool that returns TokenInfo. - tokenInfo := func(ctx context.Context, req *ServerRequest[*CallToolParams], _ struct{}) (*CallToolResult, any, error) { + tokenInfo := func(ctx context.Context, req *CallToolRequest, _ struct{}) (*CallToolResult, any, error) { return &CallToolResult{Content: []Content{&TextContent{Text: fmt.Sprintf("%v", req.Extra.TokenInfo)}}}, nil, nil } server := NewServer(testImpl, nil) diff --git a/mcp/tool.go b/mcp/tool.go index f0178c23..bd10a07c 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -16,10 +16,10 @@ import ( // A ToolHandler handles a call to tools/call. // [CallToolParams.Arguments] will contain a map[string]any that has been validated // against the input schema. -type ToolHandler func(context.Context, *ServerRequest[*CallToolParams]) (*CallToolResult, error) +type ToolHandler func(context.Context, *CallToolRequest) (*CallToolResult, error) // A ToolHandlerFor handles a call to tools/call with typed arguments and results. -type ToolHandlerFor[In, Out any] func(context.Context, *ServerRequest[*CallToolParams], In) (*CallToolResult, Out, error) +type ToolHandlerFor[In, Out any] func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) // A serverTool is a tool definition that is bound to a tool handler. type serverTool struct { diff --git a/mcp/tool_test.go b/mcp/tool_test.go index 756d6aa4..2722a9ac 100644 --- a/mcp/tool_test.go +++ b/mcp/tool_test.go @@ -61,7 +61,7 @@ func TestToolErrorHandling(t *testing.T) { server := NewServer(testImpl, nil) // Create a tool that returns a structured error - structuredErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + structuredErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { return nil, nil, &jsonrpc2.WireError{ Code: CodeInvalidParams, Message: "internal server error", @@ -69,7 +69,7 @@ func TestToolErrorHandling(t *testing.T) { } // Create a tool that returns a regular error - regularErrorHandler := func(ctx context.Context, req *ServerRequest[*CallToolParams], args map[string]any) (*CallToolResult, any, error) { + regularErrorHandler := func(ctx context.Context, req *CallToolRequest, args map[string]any) (*CallToolResult, any, error) { return nil, nil, fmt.Errorf("tool execution failed") } From c179b538257ed9b6680fb4beea905a65a5fc3d11 Mon Sep 17 00:00:00 2001 From: Jonathan Amsterdam Date: Thu, 21 Aug 2025 10:58:34 -0400 Subject: [PATCH 2/2] fix readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index b46724b7..4da0ac61 100644 --- a/README.md +++ b/README.md @@ -115,7 +115,7 @@ type HiParams struct { Name string `json:"name" jsonschema:"the name of the person to greet"` } -func SayHi(ctx context.Context, req *mcp.ServerRequest[*mcp.CallToolParams], args HiParams) (*mcp.CallToolResult, any, error) { +func SayHi(ctx context.Context, req *mcp.CallToolRequest, args HiParams) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{&mcp.TextContent{Text: "Hi " + args.Name}}, }, nil, nil