Skip to content

Commit 6a54215

Browse files
Implement sampling in Stdio (#461)
* Implement sampling in Stdio * Update docs * Update server/sampling_test.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * use proper context * Update www/docs/pages/clients/advanced-sampling.mdx Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * enhance test * tweak example server * fix test * fix test * fixes * fix --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
1 parent 656a7b4 commit 6a54215

File tree

23 files changed

+2327
-20
lines changed

23 files changed

+2327
-20
lines changed

client/client.go

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ type Client struct {
2222
requestID atomic.Int64
2323
clientCapabilities mcp.ClientCapabilities
2424
serverCapabilities mcp.ServerCapabilities
25+
samplingHandler SamplingHandler
2526
}
2627

2728
type ClientOption func(*Client)
@@ -33,6 +34,14 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
3334
}
3435
}
3536

37+
// WithSamplingHandler sets the sampling handler for the client.
38+
// When set, the client will declare sampling capability during initialization.
39+
func WithSamplingHandler(handler SamplingHandler) ClientOption {
40+
return func(c *Client) {
41+
c.samplingHandler = handler
42+
}
43+
}
44+
3645
// WithSession assumes a MCP Session has already been initialized
3746
func WithSession() ClientOption {
3847
return func(c *Client) {
@@ -78,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
7887
handler(notification)
7988
}
8089
})
90+
91+
// Set up request handler for bidirectional communication (e.g., sampling)
92+
if bidirectional, ok := c.transport.(transport.BidirectionalInterface); ok {
93+
bidirectional.SetRequestHandler(c.handleIncomingRequest)
94+
}
95+
8196
return nil
8297
}
8398

@@ -134,6 +149,12 @@ func (c *Client) Initialize(
134149
ctx context.Context,
135150
request mcp.InitializeRequest,
136151
) (*mcp.InitializeResult, error) {
152+
// Merge client capabilities with sampling capability if handler is configured
153+
capabilities := request.Params.Capabilities
154+
if c.samplingHandler != nil {
155+
capabilities.Sampling = &struct{}{}
156+
}
157+
137158
// Ensure we send a params object with all required fields
138159
params := struct {
139160
ProtocolVersion string `json:"protocolVersion"`
@@ -142,7 +163,7 @@ func (c *Client) Initialize(
142163
}{
143164
ProtocolVersion: request.Params.ProtocolVersion,
144165
ClientInfo: request.Params.ClientInfo,
145-
Capabilities: request.Params.Capabilities, // Will be empty struct if not set
166+
Capabilities: capabilities,
146167
}
147168

148169
response, err := c.sendRequest(ctx, "initialize", params)
@@ -405,6 +426,64 @@ func (c *Client) Complete(
405426
return &result, nil
406427
}
407428

429+
// handleIncomingRequest processes incoming requests from the server.
430+
// This is the main entry point for server-to-client requests like sampling.
431+
func (c *Client) handleIncomingRequest(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
432+
switch request.Method {
433+
case string(mcp.MethodSamplingCreateMessage):
434+
return c.handleSamplingRequestTransport(ctx, request)
435+
default:
436+
return nil, fmt.Errorf("unsupported request method: %s", request.Method)
437+
}
438+
}
439+
440+
// handleSamplingRequestTransport handles sampling requests at the transport level.
441+
func (c *Client) handleSamplingRequestTransport(ctx context.Context, request transport.JSONRPCRequest) (*transport.JSONRPCResponse, error) {
442+
if c.samplingHandler == nil {
443+
return nil, fmt.Errorf("no sampling handler configured")
444+
}
445+
446+
// Parse the request parameters
447+
var params mcp.CreateMessageParams
448+
if request.Params != nil {
449+
paramsBytes, err := json.Marshal(request.Params)
450+
if err != nil {
451+
return nil, fmt.Errorf("failed to marshal params: %w", err)
452+
}
453+
if err := json.Unmarshal(paramsBytes, &params); err != nil {
454+
return nil, fmt.Errorf("failed to unmarshal params: %w", err)
455+
}
456+
}
457+
458+
// Create the MCP request
459+
mcpRequest := mcp.CreateMessageRequest{
460+
Request: mcp.Request{
461+
Method: string(mcp.MethodSamplingCreateMessage),
462+
},
463+
CreateMessageParams: params,
464+
}
465+
466+
// Call the sampling handler
467+
result, err := c.samplingHandler.CreateMessage(ctx, mcpRequest)
468+
if err != nil {
469+
return nil, err
470+
}
471+
472+
// Marshal the result
473+
resultBytes, err := json.Marshal(result)
474+
if err != nil {
475+
return nil, fmt.Errorf("failed to marshal result: %w", err)
476+
}
477+
478+
// Create the transport response
479+
response := &transport.JSONRPCResponse{
480+
JSONRPC: mcp.JSONRPC_VERSION,
481+
ID: request.ID,
482+
Result: json.RawMessage(resultBytes),
483+
}
484+
485+
return response, nil
486+
}
408487
func listByPage[T any](
409488
ctx context.Context,
410489
client *Client,

client/sampling.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
package client
2+
3+
import (
4+
"context"
5+
6+
"github.com/mark3labs/mcp-go/mcp"
7+
)
8+
9+
// SamplingHandler defines the interface for handling sampling requests from servers.
10+
// Clients can implement this interface to provide LLM sampling capabilities to servers.
11+
type SamplingHandler interface {
12+
// CreateMessage handles a sampling request from the server and returns the generated message.
13+
// The implementation should:
14+
// 1. Validate the request parameters
15+
// 2. Optionally prompt the user for approval (human-in-the-loop)
16+
// 3. Select an appropriate model based on preferences
17+
// 4. Generate the response using the selected model
18+
// 5. Return the result with model information and stop reason
19+
CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error)
20+
}

0 commit comments

Comments
 (0)