From d36849a1e56795438e29bc744088af05e60c08a1 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Wed, 20 Aug 2025 17:25:31 +0000 Subject: [PATCH] mcp/server: expose InitializeParams to ServerSession This CL enables the server session to see what capabilities the client has by introducing the InitializeParams() function. Fixes #141 --- mcp/client.go | 17 ++++++++++------- mcp/client_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 2 ++ 3 files changed, 57 insertions(+), 7 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index d1d17502..0151c942 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -105,6 +105,15 @@ func (e unsupportedProtocolVersionError) Error() string { // ClientSessionOptions is reserved for future use. type ClientSessionOptions struct{} +func (c *Client) capabilities() *ClientCapabilities { + caps := &ClientCapabilities{} + caps.Roots.ListChanged = true + if c.opts.CreateMessageHandler != nil { + caps.Sampling = &SamplingCapabilities{} + } + return caps +} + // Connect begins an MCP session by connecting to a server over the given // transport, and initializing the session. // @@ -118,16 +127,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio return nil, err } - caps := &ClientCapabilities{} - caps.Roots.ListChanged = true - if c.opts.CreateMessageHandler != nil { - caps.Sampling = &SamplingCapabilities{} - } - params := &InitializeParams{ ProtocolVersion: latestProtocolVersion, ClientInfo: c.impl, - Capabilities: caps, + Capabilities: c.capabilities(), } req := &ClientRequest[*InitializeParams]{Session: cs, Params: params} res, err := handleSend[*InitializeResult](ctx, methodInitialize, req) diff --git a/mcp/client_test.go b/mcp/client_test.go index 7920c55c..469fa3fb 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -190,3 +190,48 @@ func TestClientPaginateVariousPageSizes(t *testing.T) { }) } } + +func TestClientCapabilities(t *testing.T) { + testCases := []struct { + name string + configureClient func(s *Client) + clientOpts ClientOptions + wantCapabilities *ClientCapabilities + }{ + { + name: "With initial capabilities", + configureClient: func(s *Client) {}, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + }, + }, + { + name: "With sampling", + configureClient: func(s *Client) {}, + clientOpts: ClientOptions{ + CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) { + return nil, nil + }, + }, + wantCapabilities: &ClientCapabilities{ + Roots: struct { + ListChanged bool "json:\"listChanged,omitempty\"" + }{ListChanged: true}, + Sampling: &SamplingCapabilities{}, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + client := NewClient(testImpl, &tc.clientOpts) + tc.configureClient(client) + gotCapabilities := client.capabilities() + if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" { + t.Errorf("capabilities() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index 65115afa..48092e7d 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -848,6 +848,8 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any, return handleReceive(ctx, ss, req) } +func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams } + func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) { if params == nil { return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)