Skip to content

Commit 2ce44d4

Browse files
authored
all: expose JSON-RPC error codes, and support URLElicitationRequired errors (#655)
Expose the jsonrpc.Error type to allow access to underlying JSON-RPC error codes. Also, expose common JSON-RPC error codes in the jsonrpc package, and MCP error codes in the mcp package. Also, implement support for the MCP URL elicitation required error (-32042), enabling servers to request out-of-band user authorization and clients to automatically handle these requests with retry logic. Server-side: - Add CodeURLElicitationRequired constant - Add URLElicitationRequiredError() constructor. Client-side: - Add urlElicitationMiddleware to intercept URLElicitationRequired errors and retry calls. - Update callElicitationCompleteHandler to signal waiting operations. - Fix capability advertisement to support both form and URL modes. Fixes #452 Fixes #623
1 parent 96385fd commit 2ce44d4

File tree

12 files changed

+670
-44
lines changed

12 files changed

+670
-44
lines changed

jsonrpc/jsonrpc.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ type (
1717
Request = jsonrpc2.Request
1818
// Response is a JSON-RPC response.
1919
Response = jsonrpc2.Response
20+
// Error is a structured error in a JSON-RPC response.
21+
Error = jsonrpc2.WireError
2022
)
2123

2224
// MakeID coerces the given Go value to an ID. The value should be the
@@ -37,3 +39,18 @@ func EncodeMessage(msg Message) ([]byte, error) {
3739
func DecodeMessage(data []byte) (Message, error) {
3840
return jsonrpc2.DecodeMessage(data)
3941
}
42+
43+
// Standard JSON-RPC 2.0 error codes.
44+
// See https://www.jsonrpc.org/specification#error_object
45+
const (
46+
// CodeParseError indicates invalid JSON was received by the server.
47+
CodeParseError = -32700
48+
// CodeInvalidRequest indicates the JSON sent is not a valid Request object.
49+
CodeInvalidRequest = -32600
50+
// CodeMethodNotFound indicates the method does not exist or is not available.
51+
CodeMethodNotFound = -32601
52+
// CodeInvalidParams indicates invalid method parameter(s).
53+
CodeInvalidParams = -32602
54+
// CodeInternalError indicates an internal JSON-RPC error.
55+
CodeInternalError = -32603
56+
)

mcp/client.go

Lines changed: 177 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ package mcp
77
import (
88
"context"
99
"encoding/json"
10+
"errors"
1011
"fmt"
1112
"iter"
1213
"slices"
14+
"strings"
1315
"sync"
1416
"sync/atomic"
1517
"time"
@@ -45,7 +47,7 @@ func NewClient(impl *Implementation, opts *ClientOptions) *Client {
4547
c := &Client{
4648
impl: impl,
4749
roots: newFeatureSet(func(r *Root) string { return r.URI }),
48-
sendingMethodHandler_: defaultSendingMethodHandler[*ClientSession],
50+
sendingMethodHandler_: defaultSendingMethodHandler,
4951
receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession],
5052
}
5153
if opts != nil {
@@ -134,7 +136,8 @@ func (c *Client) capabilities() *ClientCapabilities {
134136
// {"form":{}} for backward compatibility, but we explicitly set the form
135137
// capability.
136138
caps.Elicitation.Form = &FormElicitationCapabilities{}
137-
} else if slices.Contains(modes, "url") {
139+
}
140+
if slices.Contains(modes, "url") {
138141
caps.Elicitation.URL = &URLElicitationCapabilities{}
139142
}
140143
}
@@ -206,6 +209,10 @@ type ClientSession struct {
206209
// No mutex is (currently) required to guard the session state, because it is
207210
// only set synchronously during Client.Connect.
208211
state clientSessionState
212+
213+
// Pending URL elicitations waiting for completion notifications.
214+
pendingElicitationsMu sync.Mutex
215+
pendingElicitations map[string]chan struct{}
209216
}
210217

211218
type clientSessionState struct {
@@ -250,6 +257,46 @@ func (cs *ClientSession) Wait() error {
250257
return cs.conn.Wait()
251258
}
252259

260+
// registerElicitationWaiter registers a waiter for an elicitation complete
261+
// notification with the given elicitation ID. It returns two functions: an await
262+
// function that waits for the notification or context cancellation, and a cleanup
263+
// function that must be called to unregister the waiter. This must be called before
264+
// triggering the elicitation to avoid a race condition where the notification
265+
// arrives before the waiter is registered.
266+
//
267+
// The cleanup function must be called even if the await function is never called,
268+
// to prevent leaking the registration.
269+
func (cs *ClientSession) registerElicitationWaiter(elicitationID string) (await func(context.Context) error, cleanup func()) {
270+
// Create a channel for this elicitation.
271+
ch := make(chan struct{}, 1)
272+
273+
// Register the channel.
274+
cs.pendingElicitationsMu.Lock()
275+
if cs.pendingElicitations == nil {
276+
cs.pendingElicitations = make(map[string]chan struct{})
277+
}
278+
cs.pendingElicitations[elicitationID] = ch
279+
cs.pendingElicitationsMu.Unlock()
280+
281+
// Return await and cleanup functions.
282+
await = func(ctx context.Context) error {
283+
select {
284+
case <-ctx.Done():
285+
return fmt.Errorf("context cancelled while waiting for elicitation completion: %w", ctx.Err())
286+
case <-ch:
287+
return nil
288+
}
289+
}
290+
291+
cleanup = func() {
292+
cs.pendingElicitationsMu.Lock()
293+
delete(cs.pendingElicitations, elicitationID)
294+
cs.pendingElicitationsMu.Unlock()
295+
}
296+
297+
return await, cleanup
298+
}
299+
253300
// startKeepalive starts the keepalive mechanism for this client session.
254301
func (cs *ClientSession) startKeepalive(interval time.Duration) {
255302
startKeepalive(cs, interval, &cs.keepaliveCancel)
@@ -304,14 +351,118 @@ func (c *Client) listRoots(_ context.Context, req *ListRootsRequest) (*ListRoots
304351
func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) (*CreateMessageResult, error) {
305352
if c.opts.CreateMessageHandler == nil {
306353
// TODO: wrap or annotate this error? Pick a standard code?
307-
return nil, jsonrpc2.NewError(codeUnsupportedMethod, "client does not support CreateMessage")
354+
return nil, &jsonrpc.Error{Code: codeUnsupportedMethod, Message: "client does not support CreateMessage"}
308355
}
309356
return c.opts.CreateMessageHandler(ctx, req)
310357
}
311358

359+
// urlElicitationMiddleware returns middleware that automatically handles URL elicitation
360+
// required errors by executing the elicitation handler, waiting for completion notifications,
361+
// and retrying the operation.
362+
//
363+
// This middleware should be added to clients that want automatic URL elicitation handling:
364+
//
365+
// client := mcp.NewClient(impl, opts)
366+
// client.AddSendingMiddleware(mcp.urlElicitationMiddleware())
367+
//
368+
// TODO(rfindley): this isn't strictly necessary for the SEP, but may be
369+
// useful. Propose exporting it.
370+
func urlElicitationMiddleware() Middleware {
371+
return func(next MethodHandler) MethodHandler {
372+
return func(ctx context.Context, method string, req Request) (Result, error) {
373+
// Call the underlying handler.
374+
res, err := next(ctx, method, req)
375+
if err == nil {
376+
return res, nil
377+
}
378+
379+
// Check if this is a URL elicitation required error.
380+
var rpcErr *jsonrpc.Error
381+
if !errors.As(err, &rpcErr) || rpcErr.Code != CodeURLElicitationRequired {
382+
return res, err
383+
}
384+
385+
// Notifications don't support retries.
386+
if strings.HasPrefix(method, "notifications/") {
387+
return res, err
388+
}
389+
390+
// Extract the client session.
391+
cs, ok := req.GetSession().(*ClientSession)
392+
if !ok {
393+
return res, err
394+
}
395+
396+
// Check if the client has an elicitation handler.
397+
if cs.client.opts.ElicitationHandler == nil {
398+
return res, err
399+
}
400+
401+
// Parse the elicitations from the error data.
402+
var errorData struct {
403+
Elicitations []*ElicitParams `json:"elicitations"`
404+
}
405+
if rpcErr.Data != nil {
406+
if err := json.Unmarshal(rpcErr.Data, &errorData); err != nil {
407+
return nil, fmt.Errorf("failed to parse URL elicitation error data: %w", err)
408+
}
409+
}
410+
411+
// Validate that all elicitations are URL mode.
412+
for _, elicit := range errorData.Elicitations {
413+
mode := elicit.Mode
414+
if mode == "" {
415+
mode = "form" // Default mode.
416+
}
417+
if mode != "url" {
418+
return nil, fmt.Errorf("URLElicitationRequired error must only contain URL mode elicitations, got %q", mode)
419+
}
420+
}
421+
422+
// Register waiters for all elicitations before executing handlers
423+
// to avoid race condition where notification arrives before waiter is registered.
424+
type waiter struct {
425+
await func(context.Context) error
426+
cleanup func()
427+
}
428+
waiters := make([]waiter, 0, len(errorData.Elicitations))
429+
for _, elicitParams := range errorData.Elicitations {
430+
await, cleanup := cs.registerElicitationWaiter(elicitParams.ElicitationID)
431+
waiters = append(waiters, waiter{await: await, cleanup: cleanup})
432+
}
433+
434+
// Ensure cleanup happens even if we return early.
435+
defer func() {
436+
for _, w := range waiters {
437+
w.cleanup()
438+
}
439+
}()
440+
441+
// Execute the elicitation handler for each elicitation.
442+
for _, elicitParams := range errorData.Elicitations {
443+
elicitReq := newClientRequest(cs, elicitParams)
444+
_, elicitErr := cs.client.elicit(ctx, elicitReq)
445+
if elicitErr != nil {
446+
return nil, fmt.Errorf("URL elicitation failed: %w", elicitErr)
447+
}
448+
}
449+
450+
// Wait for all elicitations to complete.
451+
for _, w := range waiters {
452+
if err := w.await(ctx); err != nil {
453+
return nil, err
454+
}
455+
}
456+
457+
// All elicitations complete, retry the original operation.
458+
return next(ctx, method, req)
459+
}
460+
}
461+
}
462+
312463
func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) {
313464
if c.opts.ElicitationHandler == nil {
314-
return nil, jsonrpc2.NewError(codeInvalidParams, "client does not support elicitation")
465+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "client does not support elicitation"}
315466
}
316467

317468
// Validate the elicitation parameters based on the mode.
@@ -323,11 +474,11 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
323474
switch mode {
324475
case "form":
325476
if req.Params.URL != "" {
326-
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must not be set for form elicitation")
477+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must not be set for form elicitation"}
327478
}
328479
schema, err := validateElicitSchema(req.Params.RequestedSchema)
329480
if err != nil {
330-
return nil, jsonrpc2.NewError(codeInvalidParams, err.Error())
481+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: err.Error()}
331482
}
332483
res, err := c.opts.ElicitationHandler(ctx, req)
333484
if err != nil {
@@ -337,28 +488,28 @@ func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult,
337488
if schema != nil && res.Content != nil {
338489
resolved, err := schema.Resolve(nil)
339490
if err != nil {
340-
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err))
491+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to resolve requested schema: %v", err)}
341492
}
342493
if err := resolved.Validate(res.Content); err != nil {
343-
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err))
494+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("elicitation result content does not match requested schema: %v", err)}
344495
}
345496
err = resolved.ApplyDefaults(&res.Content)
346497
if err != nil {
347-
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err))
498+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("failed to apply schema defalts to elicitation result: %v", err)}
348499
}
349500
}
350501
return res, nil
351502
case "url":
352503
if req.Params.RequestedSchema != nil {
353-
return nil, jsonrpc2.NewError(codeInvalidParams, "requestedSchema must not be set for URL elicitation")
504+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "requestedSchema must not be set for URL elicitation"}
354505
}
355506
if req.Params.URL == "" {
356-
return nil, jsonrpc2.NewError(codeInvalidParams, "URL must be set for URL elicitation")
507+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: "URL must be set for URL elicitation"}
357508
}
358509
// No schema validation for URL mode, just pass through to handler.
359510
return c.opts.ElicitationHandler(ctx, req)
360511
default:
361-
return nil, jsonrpc2.NewError(codeInvalidParams, fmt.Sprintf("unsupported elicitation mode: %q", mode))
512+
return nil, &jsonrpc.Error{Code: jsonrpc.CodeInvalidParams, Message: fmt.Sprintf("unsupported elicitation mode: %q", mode)}
362513
}
363514
}
364515

@@ -723,6 +874,20 @@ func (cs *ClientSession) callProgressNotificationHandler(ctx context.Context, pa
723874
}
724875

725876
func (c *Client) callElicitationCompleteHandler(ctx context.Context, req *ElicitationCompleteNotificationRequest) (Result, error) {
877+
// Check if there's a pending elicitation waiting for this notification.
878+
if cs, ok := req.GetSession().(*ClientSession); ok {
879+
cs.pendingElicitationsMu.Lock()
880+
if ch, exists := cs.pendingElicitations[req.Params.ElicitationID]; exists {
881+
select {
882+
case ch <- struct{}{}:
883+
default:
884+
// Channel already signaled.
885+
}
886+
}
887+
cs.pendingElicitationsMu.Unlock()
888+
}
889+
890+
// Call the user's handler if provided.
726891
if h := c.opts.ElicitationCompleteHandler; h != nil {
727892
h(ctx, req)
728893
}

mcp/client_test.go

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,61 @@ func TestClientCapabilities(t *testing.T) {
222222
Sampling: &SamplingCapabilities{},
223223
},
224224
},
225+
{
226+
name: "With form elicitation",
227+
configureClient: func(s *Client) {},
228+
clientOpts: ClientOptions{
229+
ElicitationModes: []string{"form"},
230+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
231+
return nil, nil
232+
},
233+
},
234+
wantCapabilities: &ClientCapabilities{
235+
Roots: struct {
236+
ListChanged bool "json:\"listChanged,omitempty\""
237+
}{ListChanged: true},
238+
Elicitation: &ElicitationCapabilities{
239+
Form: &FormElicitationCapabilities{},
240+
},
241+
},
242+
},
243+
{
244+
name: "With URL elicitation",
245+
configureClient: func(s *Client) {},
246+
clientOpts: ClientOptions{
247+
ElicitationModes: []string{"url"},
248+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
249+
return nil, nil
250+
},
251+
},
252+
wantCapabilities: &ClientCapabilities{
253+
Roots: struct {
254+
ListChanged bool "json:\"listChanged,omitempty\""
255+
}{ListChanged: true},
256+
Elicitation: &ElicitationCapabilities{
257+
URL: &URLElicitationCapabilities{},
258+
},
259+
},
260+
},
261+
{
262+
name: "With both form and URL elicitation",
263+
configureClient: func(s *Client) {},
264+
clientOpts: ClientOptions{
265+
ElicitationModes: []string{"form", "url"},
266+
ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) {
267+
return nil, nil
268+
},
269+
},
270+
wantCapabilities: &ClientCapabilities{
271+
Roots: struct {
272+
ListChanged bool "json:\"listChanged,omitempty\""
273+
}{ListChanged: true},
274+
Elicitation: &ElicitationCapabilities{
275+
Form: &FormElicitationCapabilities{},
276+
URL: &URLElicitationCapabilities{},
277+
},
278+
},
279+
},
225280
}
226281

227282
for _, tc := range testCases {

mcp/elicitation_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/google/jsonschema-go/jsonschema"
15+
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
1516
)
1617

1718
// TODO: migrate other elicitation tests here.
@@ -74,7 +75,7 @@ func TestElicitationURLMode(t *testing.T) {
7475
Message: "URL is missing",
7576
},
7677
wantErrMsg: "URL must be set for URL elicitation",
77-
wantErrCode: codeInvalidParams,
78+
wantErrCode: jsonrpc.CodeInvalidParams,
7879
},
7980
{
8081
name: "schema not allowed",
@@ -90,7 +91,7 @@ func TestElicitationURLMode(t *testing.T) {
9091
},
9192
},
9293
wantErrMsg: "requestedSchema must not be set for URL elicitation",
93-
wantErrCode: codeInvalidParams,
94+
wantErrCode: jsonrpc.CodeInvalidParams,
9495
},
9596
}
9697
for _, tc := range testCases {

0 commit comments

Comments
 (0)