@@ -22,6 +22,7 @@ type Client struct {
2222 requestID atomic.Int64
2323 clientCapabilities mcp.ClientCapabilities
2424 serverCapabilities mcp.ServerCapabilities
25+ samplingHandler SamplingHandler
2526}
2627
2728type ClientOption func (* Client )
@@ -33,6 +34,14 @@ func WithClientCapabilities(capabilities mcp.ClientCapabilities) ClientOption {
3334 }
3435}
3536
37+ // WithSamplingHandler sets the sampling handler for the client.
38+ // When set, the client will declare sampling capability during initialization.
39+ func WithSamplingHandler (handler SamplingHandler ) ClientOption {
40+ return func (c * Client ) {
41+ c .samplingHandler = handler
42+ }
43+ }
44+
3645// WithSession assumes a MCP Session has already been initialized
3746func WithSession () ClientOption {
3847 return func (c * Client ) {
@@ -78,6 +87,12 @@ func (c *Client) Start(ctx context.Context) error {
7887 handler (notification )
7988 }
8089 })
90+
91+ // Set up request handler for bidirectional communication (e.g., sampling)
92+ if bidirectional , ok := c .transport .(transport.BidirectionalInterface ); ok {
93+ bidirectional .SetRequestHandler (c .handleIncomingRequest )
94+ }
95+
8196 return nil
8297}
8398
@@ -134,6 +149,12 @@ func (c *Client) Initialize(
134149 ctx context.Context ,
135150 request mcp.InitializeRequest ,
136151) (* mcp.InitializeResult , error ) {
152+ // Merge client capabilities with sampling capability if handler is configured
153+ capabilities := request .Params .Capabilities
154+ if c .samplingHandler != nil {
155+ capabilities .Sampling = & struct {}{}
156+ }
157+
137158 // Ensure we send a params object with all required fields
138159 params := struct {
139160 ProtocolVersion string `json:"protocolVersion"`
@@ -142,7 +163,7 @@ func (c *Client) Initialize(
142163 }{
143164 ProtocolVersion : request .Params .ProtocolVersion ,
144165 ClientInfo : request .Params .ClientInfo ,
145- Capabilities : request . Params . Capabilities , // Will be empty struct if not set
166+ Capabilities : capabilities ,
146167 }
147168
148169 response , err := c .sendRequest (ctx , "initialize" , params )
@@ -405,6 +426,64 @@ func (c *Client) Complete(
405426 return & result , nil
406427}
407428
429+ // handleIncomingRequest processes incoming requests from the server.
430+ // This is the main entry point for server-to-client requests like sampling.
431+ func (c * Client ) handleIncomingRequest (ctx context.Context , request transport.JSONRPCRequest ) (* transport.JSONRPCResponse , error ) {
432+ switch request .Method {
433+ case string (mcp .MethodSamplingCreateMessage ):
434+ return c .handleSamplingRequestTransport (ctx , request )
435+ default :
436+ return nil , fmt .Errorf ("unsupported request method: %s" , request .Method )
437+ }
438+ }
439+
440+ // handleSamplingRequestTransport handles sampling requests at the transport level.
441+ func (c * Client ) handleSamplingRequestTransport (ctx context.Context , request transport.JSONRPCRequest ) (* transport.JSONRPCResponse , error ) {
442+ if c .samplingHandler == nil {
443+ return nil , fmt .Errorf ("no sampling handler configured" )
444+ }
445+
446+ // Parse the request parameters
447+ var params mcp.CreateMessageParams
448+ if request .Params != nil {
449+ paramsBytes , err := json .Marshal (request .Params )
450+ if err != nil {
451+ return nil , fmt .Errorf ("failed to marshal params: %w" , err )
452+ }
453+ if err := json .Unmarshal (paramsBytes , & params ); err != nil {
454+ return nil , fmt .Errorf ("failed to unmarshal params: %w" , err )
455+ }
456+ }
457+
458+ // Create the MCP request
459+ mcpRequest := mcp.CreateMessageRequest {
460+ Request : mcp.Request {
461+ Method : string (mcp .MethodSamplingCreateMessage ),
462+ },
463+ CreateMessageParams : params ,
464+ }
465+
466+ // Call the sampling handler
467+ result , err := c .samplingHandler .CreateMessage (ctx , mcpRequest )
468+ if err != nil {
469+ return nil , err
470+ }
471+
472+ // Marshal the result
473+ resultBytes , err := json .Marshal (result )
474+ if err != nil {
475+ return nil , fmt .Errorf ("failed to marshal result: %w" , err )
476+ }
477+
478+ // Create the transport response
479+ response := & transport.JSONRPCResponse {
480+ JSONRPC : mcp .JSONRPC_VERSION ,
481+ ID : request .ID ,
482+ Result : json .RawMessage (resultBytes ),
483+ }
484+
485+ return response , nil
486+ }
408487func listByPage [T any ](
409488 ctx context.Context ,
410489 client * Client ,
0 commit comments