Skip to content

Commit 48abccb

Browse files
mcp/server: expose InitializeParams to ServerSession (#336)
This CL enables the server session to see what capabilities the client has by introducing the InitializeParams() function. This CL also adds a test to ensure ClientCapabilities is accurate. Fixes #141
1 parent 62d8159 commit 48abccb

File tree

3 files changed

+57
-7
lines changed

3 files changed

+57
-7
lines changed

mcp/client.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,15 @@ func (e unsupportedProtocolVersionError) Error() string {
105105
// ClientSessionOptions is reserved for future use.
106106
type ClientSessionOptions struct{}
107107

108+
func (c *Client) capabilities() *ClientCapabilities {
109+
caps := &ClientCapabilities{}
110+
caps.Roots.ListChanged = true
111+
if c.opts.CreateMessageHandler != nil {
112+
caps.Sampling = &SamplingCapabilities{}
113+
}
114+
return caps
115+
}
116+
108117
// Connect begins an MCP session by connecting to a server over the given
109118
// transport, and initializing the session.
110119
//
@@ -118,16 +127,10 @@ func (c *Client) Connect(ctx context.Context, t Transport, _ *ClientSessionOptio
118127
return nil, err
119128
}
120129

121-
caps := &ClientCapabilities{}
122-
caps.Roots.ListChanged = true
123-
if c.opts.CreateMessageHandler != nil {
124-
caps.Sampling = &SamplingCapabilities{}
125-
}
126-
127130
params := &InitializeParams{
128131
ProtocolVersion: latestProtocolVersion,
129132
ClientInfo: c.impl,
130-
Capabilities: caps,
133+
Capabilities: c.capabilities(),
131134
}
132135
req := &ClientRequest[*InitializeParams]{Session: cs, Params: params}
133136
res, err := handleSend[*InitializeResult](ctx, methodInitialize, req)

mcp/client_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,48 @@ func TestClientPaginateVariousPageSizes(t *testing.T) {
190190
})
191191
}
192192
}
193+
194+
func TestClientCapabilities(t *testing.T) {
195+
testCases := []struct {
196+
name string
197+
configureClient func(s *Client)
198+
clientOpts ClientOptions
199+
wantCapabilities *ClientCapabilities
200+
}{
201+
{
202+
name: "With initial capabilities",
203+
configureClient: func(s *Client) {},
204+
wantCapabilities: &ClientCapabilities{
205+
Roots: struct {
206+
ListChanged bool "json:\"listChanged,omitempty\""
207+
}{ListChanged: true},
208+
},
209+
},
210+
{
211+
name: "With sampling",
212+
configureClient: func(s *Client) {},
213+
clientOpts: ClientOptions{
214+
CreateMessageHandler: func(context.Context, *ClientRequest[*CreateMessageParams]) (*CreateMessageResult, error) {
215+
return nil, nil
216+
},
217+
},
218+
wantCapabilities: &ClientCapabilities{
219+
Roots: struct {
220+
ListChanged bool "json:\"listChanged,omitempty\""
221+
}{ListChanged: true},
222+
Sampling: &SamplingCapabilities{},
223+
},
224+
},
225+
}
226+
227+
for _, tc := range testCases {
228+
t.Run(tc.name, func(t *testing.T) {
229+
client := NewClient(testImpl, &tc.clientOpts)
230+
tc.configureClient(client)
231+
gotCapabilities := client.capabilities()
232+
if diff := cmp.Diff(tc.wantCapabilities, gotCapabilities); diff != "" {
233+
t.Errorf("capabilities() mismatch (-want +got):\n%s", diff)
234+
}
235+
})
236+
}
237+
}

mcp/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,8 @@ func (ss *ServerSession) handle(ctx context.Context, req *jsonrpc.Request) (any,
849849
return handleReceive(ctx, ss, req)
850850
}
851851

852+
func (ss *ServerSession) InitializeParams() *InitializeParams { return ss.state.InitializeParams }
853+
852854
func (ss *ServerSession) initialize(ctx context.Context, params *InitializeParams) (*InitializeResult, error) {
853855
if params == nil {
854856
return nil, fmt.Errorf("%w: \"params\" must be be provided", jsonrpc2.ErrInvalidParams)

0 commit comments

Comments
 (0)