Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions mcp/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
type ClientOptions struct {
// Handler for sampling.
// Called when a server calls CreateMessage.
CreateMessageHandler func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error)
CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error)
// Handlers for notifications from the server.
ToolListChangedHandler func(context.Context, *ClientRequest[*ToolListChangedParams])
PromptListChangedHandler func(context.Context, *ClientRequest[*PromptListChangedParams])
ResourceListChangedHandler func(context.Context, *ClientRequest[*ResourceListChangedParams])
ResourceUpdatedHandler func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams])
LoggingMessageHandler func(context.Context, *ClientRequest[*LoggingMessageParams])
ProgressNotificationHandler func(context.Context, *ClientRequest[*ProgressNotificationParams])
ToolListChangedHandler func(context.Context, *ToolListChangedRequest)
PromptListChangedHandler func(context.Context, *PromptListChangedRequest)
ResourceListChangedHandler func(context.Context, *ResourceListChangedRequest)
ResourceUpdatedHandler func(context.Context, *ResourceUpdatedNotificationRequest)
LoggingMessageHandler func(context.Context, *LoggingMessageRequest)
ProgressNotificationHandler func(context.Context, *ProgressNotificationClientRequest)
// If non-zero, defines an interval for regular "ping" requests.
// If the peer fails to respond to pings originating from the keepalive check,
// the session is automatically closed.
Expand Down Expand Up @@ -132,7 +132,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
ClientInfo: c.impl,
Capabilities: c.capabilities(),
}
req := &ClientRequest[*InitializeParams]{Session: cs, Params: params}
req := &InitializeRequest{Session: cs, Params: params}
res, err := handleSend[*InitializeResult](ctx, methodInitialize, req)
if err != nil {
_ = cs.Close()
Expand All @@ -145,7 +145,7 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
if hc, ok := cs.mcpConn.(clientConnection); ok {
hc.sessionUpdated(cs.state)
}
req2 := &ClientRequest[*InitializedParams]{Session: cs, Params: &InitializedParams{}}
req2 := &InitializedClientRequest{Session: cs, Params: &InitializedParams{}}
if err := handleNotify(ctx, notificationInitialized, req2); err != nil {
_ = cs.Close()
return nil, err
Expand Down Expand Up @@ -248,7 +248,7 @@ func changeAndNotify[P Params](c *Client, notification string, params P, change
notifySessions(sessions, notification, params)
}

func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParams]) (*ListRootsResult, error) {
func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRootsResult, error) {
c.mu.Lock()
defer c.mu.Unlock()
roots := slices.Collect(c.roots.all())
Expand All @@ -260,7 +260,7 @@ func (c *Client) listRoots(_ context.Context, req *ClientRequest[*ListRootsParam
}, nil
}

func (c *Client) createMessage(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
if c.opts.CreateMessageHandler == nil {
// TODO: wrap or annotate this error? Pick a standard code?
return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support CreateMessage")
Expand Down Expand Up @@ -436,35 +436,35 @@ func (cs *ClientSession) Unsubscribe(ctx context.Context, params *UnsubscribePar
return err
}

func (c *Client) callToolChangedHandler(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) (Result, error) {
func (c *Client) callToolChangedHandler(ctx context.Context, req *ToolListChangedRequest) (Result, error) {
if h := c.opts.ToolListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callPromptChangedHandler(ctx context.Context, req *ClientRequest[*PromptListChangedParams]) (Result, error) {
func (c *Client) callPromptChangedHandler(ctx context.Context, req *PromptListChangedRequest) (Result, error) {
if h := c.opts.PromptListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callResourceChangedHandler(ctx context.Context, req *ClientRequest[*ResourceListChangedParams]) (Result, error) {
func (c *Client) callResourceChangedHandler(ctx context.Context, req *ResourceListChangedRequest) (Result, error) {
if h := c.opts.ResourceListChangedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ClientRequest[*ResourceUpdatedNotificationParams]) (Result, error) {
func (c *Client) callResourceUpdatedHandler(ctx context.Context, req *ResourceUpdatedNotificationRequest) (Result, error) {
if h := c.opts.ResourceUpdatedHandler; h != nil {
h(ctx, req)
}
return nil, nil
}

func (c *Client) callLoggingHandler(ctx context.Context, req *ClientRequest[*LoggingMessageParams]) (Result, error) {
func (c *Client) callLoggingHandler(ctx context.Context, req *LoggingMessageRequest) (Result, error) {
if h := c.opts.LoggingMessageHandler; h != nil {
h(ctx, req)
}
Expand Down
2 changes: 1 addition & 1 deletion mcp/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func TestClientCapabilities(t *testing.T) {
name: "With sampling",
configureClient: func(s *Client) {},
clientOpts: ClientOptions{
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
return nil, nil
},
},
Expand Down
24 changes: 12 additions & 12 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ func TestEndToEnd(t *testing.T) {
}

sopts := &ServerOptions{
InitializedHandler: func(context.Context, *InitializedRequest) {
InitializedHandler: func(context.Context, *InitializedServerRequest) {
notificationChans["initialized"] <- 0
},
RootsListChangedHandler: func(context.Context, *RootsListChangedRequest) {
notificationChans["roots"] <- 0
},
ProgressNotificationHandler: func(context.Context, *ProgressNotificationRequest) {
ProgressNotificationHandler: func(context.Context, *ProgressNotificationServerRequest) {
notificationChans["progress_server"] <- 0
},
SubscribeHandler: func(context.Context, *SubscribeRequest) error {
Expand Down Expand Up @@ -129,25 +129,25 @@ func TestEndToEnd(t *testing.T) {

loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
opts := &ClientOptions{
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
},
ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) {
ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) {
notificationChans["tools"] <- 0
},
PromptListChangedHandler: func(context.Context, *ClientRequest[*PromptListChangedParams]) {
PromptListChangedHandler: func(context.Context, *PromptListChangedRequest) {
notificationChans["prompts"] <- 0
},
ResourceListChangedHandler: func(context.Context, *ClientRequest[*ResourceListChangedParams]) {
ResourceListChangedHandler: func(context.Context, *ResourceListChangedRequest) {
notificationChans["resources"] <- 0
},
LoggingMessageHandler: func(_ context.Context, req *ClientRequest[*LoggingMessageParams]) {
LoggingMessageHandler: func(_ context.Context, req *LoggingMessageRequest) {
loggingMessages <- req.Params
},
ProgressNotificationHandler: func(context.Context, *ClientRequest[*ProgressNotificationParams]) {
ProgressNotificationHandler: func(context.Context, *ProgressNotificationClientRequest) {
notificationChans["progress_client"] <- 0
},
ResourceUpdatedHandler: func(context.Context, *ClientRequest[*ResourceUpdatedNotificationParams]) {
ResourceUpdatedHandler: func(context.Context, *ResourceUpdatedNotificationRequest) {
notificationChans["resource_updated"] <- 0
},
}
Expand Down Expand Up @@ -992,10 +992,10 @@ func TestAddTool_DuplicateNoPanicAndNoDuplicate(t *testing.T) {
func TestSynchronousNotifications(t *testing.T) {
var toolsChanged atomic.Bool
clientOpts := &ClientOptions{
ToolListChangedHandler: func(ctx context.Context, req *ClientRequest[*ToolListChangedParams]) {
ToolListChangedHandler: func(ctx context.Context, req *ToolListChangedRequest) {
toolsChanged.Store(true)
},
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
if !toolsChanged.Load() {
return nil, fmt.Errorf("didn't get a tools changed notification")
}
Expand Down Expand Up @@ -1057,7 +1057,7 @@ func TestNoDistributedDeadlock(t *testing.T) {
// possible, and in any case making tool calls asynchronous by default
// delegates synchronization to the user.
clientOpts := &ClientOptions{
CreateMessageHandler: func(ctx context.Context, req *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
CreateMessageHandler: func(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
req.Session.CallTool(ctx, &CallToolParams{Name: "tool2"})
return &CreateMessageResult{Content: &TextContent{}}, nil
},
Expand Down
39 changes: 26 additions & 13 deletions mcp/requests.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,30 @@ package mcp

// TODO: expand the aliases
type (
CallToolRequest = ServerRequest[*CallToolParams]
CompleteRequest = ServerRequest[*CompleteParams]
GetPromptRequest = ServerRequest[*GetPromptParams]
InitializedRequest = ServerRequest[*InitializedParams]
ListPromptsRequest = ServerRequest[*ListPromptsParams]
ListResourcesRequest = ServerRequest[*ListResourcesParams]
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
ListToolsRequest = ServerRequest[*ListToolsParams]
ProgressNotificationRequest = ServerRequest[*ProgressNotificationParams]
ReadResourceRequest = ServerRequest[*ReadResourceParams]
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
SubscribeRequest = ServerRequest[*SubscribeParams]
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
CallToolRequest = ServerRequest[*CallToolParams]
CompleteRequest = ServerRequest[*CompleteParams]
GetPromptRequest = ServerRequest[*GetPromptParams]
InitializedServerRequest = ServerRequest[*InitializedParams]
ListPromptsRequest = ServerRequest[*ListPromptsParams]
ListResourcesRequest = ServerRequest[*ListResourcesParams]
ListResourceTemplatesRequest = ServerRequest[*ListResourceTemplatesParams]
ListToolsRequest = ServerRequest[*ListToolsParams]
ProgressNotificationServerRequest = ServerRequest[*ProgressNotificationParams]
ReadResourceRequest = ServerRequest[*ReadResourceParams]
RootsListChangedRequest = ServerRequest[*RootsListChangedParams]
SubscribeRequest = ServerRequest[*SubscribeParams]
UnsubscribeRequest = ServerRequest[*UnsubscribeParams]
)

type (
CreateMessageRequest = ClientRequest[*CreateMessageParams]
InitializedClientRequest = ClientRequest[*InitializedParams]
InitializeRequest = ClientRequest[*InitializeParams]
ListRootsRequest = ClientRequest[*ListRootsParams]
LoggingMessageRequest = ClientRequest[*LoggingMessageParams]
ProgressNotificationClientRequest = ClientRequest[*ProgressNotificationParams]
PromptListChangedRequest = ClientRequest[*PromptListChangedParams]
ResourceListChangedRequest = ClientRequest[*ResourceListChangedParams]
ResourceUpdatedNotificationRequest = ClientRequest[*ResourceUpdatedNotificationParams]
ToolListChangedRequest = ClientRequest[*ToolListChangedParams]
)
4 changes: 2 additions & 2 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,14 @@ type ServerOptions struct {
// Optional instructions for connected clients.
Instructions string
// If non-nil, called when "notifications/initialized" is received.
InitializedHandler func(context.Context, *InitializedRequest)
InitializedHandler func(context.Context, *InitializedServerRequest)
// PageSize is the maximum number of items to return in a single page for
// list methods (e.g. ListTools).
PageSize int
// If non-nil, called when "notifications/roots/list_changed" is received.
RootsListChangedHandler func(context.Context, *RootsListChangedRequest)
// If non-nil, called when "notifications/progress" is received.
ProgressNotificationHandler func(context.Context, *ProgressNotificationRequest)
ProgressNotificationHandler func(context.Context, *ProgressNotificationServerRequest)
// If non-nil, called when "completion/complete" is received.
CompletionHandler func(context.Context, *CompleteRequest) (*CompleteResult, error)
// If non-zero, defines an interval for regular "ping" requests.
Expand Down
6 changes: 3 additions & 3 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func TestStreamableTransports(t *testing.T) {
HTTPClient: httpClient,
}
client := NewClient(testImpl, &ClientOptions{
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) {
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
},
})
Expand Down Expand Up @@ -255,7 +255,7 @@ func testClientReplay(t *testing.T, test clientReplayTest) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := NewClient(testImpl, &ClientOptions{
ProgressNotificationHandler: func(ctx context.Context, req *ClientRequest[*ProgressNotificationParams]) {
ProgressNotificationHandler: func(ctx context.Context, req *ProgressNotificationClientRequest) {
notifications <- req.Params.Message
},
})
Expand Down Expand Up @@ -344,7 +344,7 @@ func TestServerInitiatedSSE(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
client := NewClient(testImpl, &ClientOptions{
ToolListChangedHandler: func(context.Context, *ClientRequest[*ToolListChangedParams]) {
ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) {
notifications <- "toolListChanged"
},
})
Expand Down