diff --git a/mcp/client.go b/mcp/client.go index 2511c05b..ec1dc456 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -55,14 +55,14 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client { type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. - CreateMessageHandler func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) + CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) // Handlers for notifications from the server. - ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams]) - PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams]) - ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams]) - ResourceUpdatedHandler func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) - LoggingMessageHandler func(context.Context, *ClientRequest[*LoggingMessageParams]) - ProgressNotificationHandler func(context.Context, *ClientRequest[*ProgressNotificationParams]) + ToolListChangedHandler func(context.Context, *ToolListChangedRequest) + PromptListChangedHandler func(context.Context, *PromptListChangedRequest) + ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest) + ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest) + LoggingMessageHandler func(context.Context, *LoggingMessageRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest) // If non-zero, defines an interval for regular "ping" requests. // If the peer fails to respond to pings originating from the keepalive check, // the session is automatically closed. @@ -132,7 +132,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio ClientInfo: c.impl, Capabilities: c.capabilities(), } - req := &ClientRequest[*InitializeParams]{Session: cs, Params: params} + req := &InitializeRequest{Session: cs, Params: params} res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) if err != nil { _ = cs.Close() @@ -145,7 +145,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio if hc, ok := cs.mcpConn.(clientConnection); ok { hc.sessionUpdated(cs.state) } - req2 := &ClientRequest[*InitializedParams]{Session: cs, Params: &InitializedParams{}} + req2 := &InitializedClientRequest{Session: cs, Params: &InitializedParams{}} if err := handleNotify(ctx, notificationInitialized, req2); err != nil { _ = cs.Close() return nil, err @@ -248,7 +248,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change notifySessions(sessions, notification, params) } -func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParams]) (*ListRootsResult, error) { +func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) { c.mu.Lock() defer c.mu.Unlock() roots := slices.Collect(c.roots.all()) @@ -260,7 +260,7 @@ func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParam }, nil } -func (c *Client) createMessage(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { +func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if c.opts.CreateMessageHandler == nil { // TODO: wrap or annotate this error? Pick a standard code? return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage") @@ -436,35 +436,35 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar return err } -func (c *Client) callToolChangedHandler(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) (Result, error) { +func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) { if h := c.opts.ToolListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callPromptChangedHandler(ctx context.Context, req *ClientRequest[*PromptListChangedParams]) (Result, error) { +func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) { if h := c.opts.PromptListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callResourceChangedHandler(ctx context.Context, req *ClientRequest[*ResourceListChangedParams]) (Result, error) { +func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) { if h := c.opts.ResourceListChangedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ClientRequest[*ResourceUpdatedNotificationParams]) (Result, error) { +func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) { if h := c.opts.ResourceUpdatedHandler; h != nil { h(ctx, req) } return nil, nil } -func (c *Client) callLoggingHandler(ctx context.Context, req *ClientRequest[*LoggingMessageParams]) (Result, error) { +func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) { if h := c.opts.LoggingMessageHandler; h != nil { h(ctx, req) } diff --git a/mcp/client_test.go b/mcp/client_test.go index 469fa3fb..eaeedc81 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -211,7 +211,7 @@ func TestClientCapabilities(t *testing.T) { name: "With sampling", configureClient: func(s *Client) {}, clientOpts: ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return nil, nil }, }, diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 42aa06af..9c578392 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -74,13 +74,13 @@ func TestEndToEnd(t *testing.T) { } sopts := &ServerOptions{ - InitializedHandler: func(context.Context, *InitializedRequest) { + InitializedHandler: func(context.Context, *InitializedServerRequest) { notificationChans["initialized"] <- 0 }, RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) { notificationChans["roots"] <- 0 }, - ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationServerRequest) { notificationChans["progress_server"] <- 0 }, SubscribeHandler: func(context.Context, *SubscribeRequest) error { @@ -129,25 +129,25 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, - ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notificationChans["tools"] <- 0 }, - PromptListChangedHandler: func(context.Context, *ClientRequest[*PromptListChangedParams]) { + PromptListChangedHandler: func(context.Context, *PromptListChangedRequest) { notificationChans["prompts"] <- 0 }, - ResourceListChangedHandler: func(context.Context, *ClientRequest[*ResourceListChangedParams]) { + ResourceListChangedHandler: func(context.Context, *ResourceListChangedRequest) { notificationChans["resources"] <- 0 }, - LoggingMessageHandler: func(_ context.Context, req *ClientRequest[*LoggingMessageParams]) { + LoggingMessageHandler: func(_ context.Context, req *LoggingMessageRequest) { loggingMessages <- req.Params }, - ProgressNotificationHandler: func(context.Context, *ClientRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(context.Context, *ProgressNotificationClientRequest) { notificationChans["progress_client"] <- 0 }, - ResourceUpdatedHandler: func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) { + ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) { notificationChans["resource_updated"] <- 0 }, } @@ -992,10 +992,10 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) { func TestSynchronousNotifications(t *testing.T) { var toolsChanged atomic.Bool clientOpts := &ClientOptions{ - ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) { toolsChanged.Store(true) }, - CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { if !toolsChanged.Load() { return nil, fmt.Errorf("didn't get a tools changed notification") } @@ -1057,7 +1057,7 @@ func TestNoDistributedDeadlock(t *testing.T) { // possible, and in any case making tool calls asynchronous by default // delegates synchronization to the user. clientOpts := &ClientOptions{ - CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) { req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"}) return &CreateMessageResult{Content: &TextContent{}}, nil }, diff --git a/mcp/requests.go b/mcp/requests.go index ceed5026..46ff4f8d 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -8,17 +8,30 @@ package mcp // TODO: expand the aliases type ( - CallToolRequest = ServerRequest[*CallToolParams] - CompleteRequest = ServerRequest[*CompleteParams] - GetPromptRequest = ServerRequest[*GetPromptParams] - InitializedRequest = ServerRequest[*InitializedParams] - ListPromptsRequest = ServerRequest[*ListPromptsParams] - ListResourcesRequest = ServerRequest[*ListResourcesParams] - ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] - ListToolsRequest = ServerRequest[*ListToolsParams] - ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams] - ReadResourceRequest = ServerRequest[*ReadResourceParams] - RootsListChangedRequest = ServerRequest[*RootsListChangedParams] - SubscribeRequest = ServerRequest[*SubscribeParams] - UnsubscribeRequest = ServerRequest[*UnsubscribeParams] + CallToolRequest = ServerRequest[*CallToolParams] + CompleteRequest = ServerRequest[*CompleteParams] + GetPromptRequest = ServerRequest[*GetPromptParams] + InitializedServerRequest = ServerRequest[*InitializedParams] + ListPromptsRequest = ServerRequest[*ListPromptsParams] + ListResourcesRequest = ServerRequest[*ListResourcesParams] + ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams] + ListToolsRequest = ServerRequest[*ListToolsParams] + ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams] + ReadResourceRequest = ServerRequest[*ReadResourceParams] + RootsListChangedRequest = ServerRequest[*RootsListChangedParams] + SubscribeRequest = ServerRequest[*SubscribeParams] + UnsubscribeRequest = ServerRequest[*UnsubscribeParams] +) + +type ( + CreateMessageRequest = ClientRequest[*CreateMessageParams] + InitializedClientRequest = ClientRequest[*InitializedParams] + InitializeRequest = ClientRequest[*InitializeParams] + ListRootsRequest = ClientRequest[*ListRootsParams] + LoggingMessageRequest = ClientRequest[*LoggingMessageParams] + ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams] + PromptListChangedRequest = ClientRequest[*PromptListChangedParams] + ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams] + ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams] + ToolListChangedRequest = ClientRequest[*ToolListChangedParams] ) diff --git a/mcp/server.go b/mcp/server.go index 7af83824..1cceacbe 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -54,14 +54,14 @@ type ServerOptions struct { // Optional instructions for connected clients. Instructions string // If non-nil, called when "notifications/initialized" is received. - InitializedHandler func(context.Context, *InitializedRequest) + InitializedHandler func(context.Context, *InitializedServerRequest) // PageSize is the maximum number of items to return in a single page for // list methods (e.g. ListTools). PageSize int // If non-nil, called when "notifications/roots/list_changed" is received. RootsListChangedHandler func(context.Context, *RootsListChangedRequest) // If non-nil, called when "notifications/progress" is received. - ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest) + ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest) // If non-nil, called when "completion/complete" is received. CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error) // If non-zero, defines an interval for regular "ping" requests. diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 5cd04eca..603be473 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -121,7 +121,7 @@ func TestStreamableTransports(t *testing.T) { HTTPClient: httpClient, } client := NewClient(testImpl, &ClientOptions{ - CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, }) @@ -255,7 +255,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ - ProgressNotificationHandler: func(ctx context.Context, req *ClientRequest[*ProgressNotificationParams]) { + ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) { notifications <- req.Params.Message }, }) @@ -344,7 +344,7 @@ func TestServerInitiatedSSE(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() client := NewClient(testImpl, &ClientOptions{ - ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) { + ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notifications <- "toolListChanged" }, })