diff --git a/mcp/client.go b/mcp/client.go index ec1dc456..6df12bb0 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -6,12 +6,14 @@ package mcp import ( "context" + "encoding/json" "fmt" "iter" "slices" "sync" "time" + "github.com/google/jsonschema-go/jsonschema" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" "github.com/modelcontextprotocol/go-sdk/jsonrpc" ) @@ -56,6 +58,9 @@ type ClientOptions struct { // Handler for sampling. // Called when a server calls CreateMessage. CreateMessageHandler func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) + // Handler for elicitation. + // Called when a server requests user input via Elicit. + ElicitationHandler func(context.Context, *ElicitRequest) (*ElicitResult, error) // Handlers for notifications from the server. ToolListChangedHandler func(context.Context, *ToolListChangedRequest) PromptListChangedHandler func(context.Context, *PromptListChangedRequest) @@ -111,6 +116,9 @@ func (c *Client) capabilities() *ClientCapabilities { if c.opts.CreateMessageHandler != nil { caps.Sampling = &SamplingCapabilities{} } + if c.opts.ElicitationHandler != nil { + caps.Elicitation = &ElicitationCapabilities{} + } return caps } @@ -268,6 +276,168 @@ func (c *Client) createMessage(ctx context.Context, req *CreateMessageRequest) ( return c.opts.CreateMessageHandler(ctx, req) } +func (c *Client) elicit(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + if c.opts.ElicitationHandler == nil { + // TODO: wrap or annotate this error? Pick a standard code? + return nil, jsonrpc2.NewError(CodeUnsupportedMethod, "client does not support elicitation") + } + + // Validate that the requested schema only contains top-level properties without nesting + if err := validateElicitSchema(req.Params.RequestedSchema); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, err.Error()) + } + + res, err := c.opts.ElicitationHandler(ctx, req) + if err != nil { + return nil, err + } + + // Validate elicitation result content against requested schema + if req.Params.RequestedSchema != nil && res.Content != nil { + resolved, err := req.Params.RequestedSchema.Resolve(nil) + if err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("failed to resolve requested schema: %v", err)) + } + + if err := resolved.Validate(res.Content); err != nil { + return nil, jsonrpc2.NewError(CodeInvalidParams, fmt.Sprintf("elicitation result content does not match requested schema: %v", err)) + } + } + + return res, nil +} + +// validateElicitSchema validates that the schema conforms to MCP elicitation schema requirements. +// Per the MCP specification, elicitation schemas are limited to flat objects with primitive properties only. +func validateElicitSchema(schema *jsonschema.Schema) error { + if schema == nil { + return nil // nil schema is allowed + } + + // The root schema must be of type "object" if specified + if schema.Type != "" && schema.Type != "object" { + return fmt.Errorf("elicit schema must be of type 'object', got %q", schema.Type) + } + + // Check if the schema has properties + if schema.Properties != nil { + for propName, propSchema := range schema.Properties { + if propSchema == nil { + continue + } + + if err := validateElicitProperty(propName, propSchema); err != nil { + return err + } + } + } + + return nil +} + +// validateElicitProperty validates a single property in an elicitation schema. +func validateElicitProperty(propName string, propSchema *jsonschema.Schema) error { + // Check if this property has nested properties (not allowed) + if len(propSchema.Properties) > 0 { + return fmt.Errorf("elicit schema property %q contains nested properties, only primitive properties are allowed", propName) + } + + // Validate based on the property type - only primitives are supported + switch propSchema.Type { + case "string": + return validateElicitStringProperty(propName, propSchema) + case "number", "integer": + return validateElicitNumberProperty(propName, propSchema) + case "boolean": + return validateElicitBooleanProperty(propName, propSchema) + default: + return fmt.Errorf("elicit schema property %q has unsupported type %q, only string, number, integer, and boolean are allowed", propName, propSchema.Type) + } +} + +// validateElicitStringProperty validates string-type properties, including enums. +func validateElicitStringProperty(propName string, propSchema *jsonschema.Schema) error { + // Handle enum validation (enums are a special case of strings) + if len(propSchema.Enum) > 0 { + // Enums must be string type (or untyped which defaults to string) + if propSchema.Type != "" && propSchema.Type != "string" { + return fmt.Errorf("elicit schema property %q has enum values but type is %q, enums are only supported for string type", propName, propSchema.Type) + } + // Enum values themselves are validated by the JSON schema library + // Validate enumNames if present - must match enum length + if propSchema.Extra != nil { + if enumNamesRaw, exists := propSchema.Extra["enumNames"]; exists { + // Type check enumNames - should be a slice + if enumNamesSlice, ok := enumNamesRaw.([]interface{}); ok { + if len(enumNamesSlice) != len(propSchema.Enum) { + return fmt.Errorf("elicit schema property %q has %d enum values but %d enumNames, they must match", propName, len(propSchema.Enum), len(enumNamesSlice)) + } + } else { + return fmt.Errorf("elicit schema property %q has invalid enumNames type, must be an array", propName) + } + } + } + return nil + } + + // Validate format if specified - only specific formats are allowed + if propSchema.Format != "" { + allowedFormats := map[string]bool{ + "email": true, + "uri": true, + "date": true, + "date-time": true, + } + if !allowedFormats[propSchema.Format] { + return fmt.Errorf("elicit schema property %q has unsupported format %q, only email, uri, date, and date-time are allowed", propName, propSchema.Format) + } + } + + // Validate minLength constraint if specified + if propSchema.MinLength != nil { + if *propSchema.MinLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid minLength %d, must be non-negative", propName, *propSchema.MinLength) + } + } + + // Validate maxLength constraint if specified + if propSchema.MaxLength != nil { + if *propSchema.MaxLength < 0 { + return fmt.Errorf("elicit schema property %q has invalid maxLength %d, must be non-negative", propName, *propSchema.MaxLength) + } + // Check that maxLength >= minLength if both are specified + if propSchema.MinLength != nil && *propSchema.MaxLength < *propSchema.MinLength { + return fmt.Errorf("elicit schema property %q has maxLength %d less than minLength %d", propName, *propSchema.MaxLength, *propSchema.MinLength) + } + } + + return nil +} + +// validateElicitNumberProperty validates number and integer-type properties. +func validateElicitNumberProperty(propName string, propSchema *jsonschema.Schema) error { + if propSchema.Minimum != nil && propSchema.Maximum != nil { + if *propSchema.Maximum < *propSchema.Minimum { + return fmt.Errorf("elicit schema property %q has maximum %g less than minimum %g", propName, *propSchema.Maximum, *propSchema.Minimum) + } + } + + return nil +} + +// validateElicitBooleanProperty validates boolean-type properties. +func validateElicitBooleanProperty(propName string, propSchema *jsonschema.Schema) error { + // Validate default value if specified - must be a valid boolean + if propSchema.Default != nil { + var defaultValue bool + if err := json.Unmarshal(propSchema.Default, &defaultValue); err != nil { + return fmt.Errorf("elicit schema property %q has invalid default value, must be a boolean: %v", propName, err) + } + } + + return nil +} + // AddSendingMiddleware wraps the current sending method handler using the provided // middleware. Middleware is applied from right to left, so that the first one is // executed first. @@ -308,6 +478,7 @@ var clientMethodInfos = map[string]methodInfo{ methodPing: newClientMethodInfo(clientSessionMethod((*ClientSession).ping), missingParamsOK), methodListRoots: newClientMethodInfo(clientMethod((*Client).listRoots), missingParamsOK), methodCreateMessage: newClientMethodInfo(clientMethod((*Client).createMessage), 0), + methodElicit: newClientMethodInfo(clientMethod((*Client).elicit), missingParamsOK), notificationCancelled: newClientMethodInfo(clientSessionMethod((*ClientSession).cancel), notification|missingParamsOK), notificationToolListChanged: newClientMethodInfo(clientMethod((*Client).callToolChangedHandler), notification|missingParamsOK), notificationPromptListChanged: newClientMethodInfo(clientMethod((*Client).callPromptChangedHandler), notification|missingParamsOK), diff --git a/mcp/elicitation_example_test.go b/mcp/elicitation_example_test.go new file mode 100644 index 00000000..526a4881 --- /dev/null +++ b/mcp/elicitation_example_test.go @@ -0,0 +1,86 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +package mcp_test + +import ( + "context" + "fmt" + "log" + + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +func Example_elicitation() { + ctx := context.Background() + clientTransport, serverTransport := mcp.NewInMemoryTransports() + + // Create server + server := mcp.NewServer(&mcp.Implementation{Name: "config-server", Version: "v1.0.0"}, nil) + + serverSession, err := server.Connect(ctx, serverTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Create client with elicitation handler + // Note: Never use elicitation for sensitive data like API keys or passwords + client := mcp.NewClient(&mcp.Implementation{Name: "config-client", Version: "v1.0.0"}, &mcp.ClientOptions{ + ElicitationHandler: func(ctx context.Context, request *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + fmt.Printf("Server requests: %s\n", request.Params.Message) + + // In a real application, this would prompt the user for input + // Here we simulate user providing configuration data + return &mcp.ElicitResult{ + Action: "accept", + Content: map[string]any{ + "serverEndpoint": "https://api.example.com", + "maxRetries": float64(3), + "enableLogs": true, + }, + }, nil + }, + }) + + _, err = client.Connect(ctx, clientTransport, nil) + if err != nil { + log.Fatal(err) + } + + // Server requests user configuration via elicitation + configSchema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "serverEndpoint": {Type: "string", Description: "Server endpoint URL"}, + "maxRetries": {Type: "number", Minimum: ptr(1.0), Maximum: ptr(10.0)}, + "enableLogs": {Type: "boolean", Description: "Enable debug logging"}, + }, + Required: []string{"serverEndpoint"}, + } + + result, err := serverSession.Elicit(ctx, &mcp.ElicitParams{ + Message: "Please provide your configuration settings", + RequestedSchema: configSchema, + }) + if err != nil { + log.Fatal(err) + } + + if result.Action == "accept" { + fmt.Printf("Configuration received: Endpoint: %v, Max Retries: %.0f, Logs: %v\n", + result.Content["serverEndpoint"], + result.Content["maxRetries"], + result.Content["enableLogs"]) + } + + // Output: + // Server requests: Please provide your configuration settings + // Configuration received: Endpoint: https://api.example.com, Max Retries: 3, Logs: true +} + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 9c578392..fef0c91e 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -7,6 +7,7 @@ package mcp import ( "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -132,6 +133,15 @@ func TestEndToEnd(t *testing.T) { CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, + ElicitationHandler: func(ctx context.Context, req *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{ + Action: "accept", + Content: map[string]any{ + "name": "Test User", + "email": "test@example.com", + }, + }, nil + }, ToolListChangedHandler: func(context.Context, *ToolListChangedRequest) { notificationChans["tools"] <- 0 }, @@ -474,6 +484,19 @@ func TestEndToEnd(t *testing.T) { } }) + t.Run("elicitation", func(t *testing.T) { + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Please provide information", + }) + if err != nil { + t.Fatal(err) + } + if result.Action != "accept" { + t.Errorf("got action %q, want %q", result.Action, "accept") + } + + }) + // Disconnect. cs.Close() clientWG.Wait() @@ -906,6 +929,518 @@ func TestKeepAlive(t *testing.T) { } } +func TestElicitationUnsupportedMethod(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + // Server + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + // Client without ElicitationHandler + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test that elicitation fails when no handler is provided + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "test": {Type: "string"}, + }, + }, + }) + + if err == nil { + t.Error("expected error when ElicitationHandler is not provided, got nil") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + if !strings.Contains(err.Error(), "does not support elicitation") { + t.Errorf("error should mention unsupported elicitation, got: %v", err) + } +} + +func TestElicitationSchemaValidation(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept", Content: map[string]any{"test": "value"}}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Test valid schemas - these should not return errors + validSchemas := []struct { + name string + schema *jsonschema.Schema + }{ + { + name: "nil schema", + schema: nil, + }, + { + name: "empty object schema", + schema: &jsonschema.Schema{ + Type: "object", + }, + }, + { + name: "simple string property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + { + name: "string with valid formats", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "email": {Type: "string", Format: "email"}, + "website": {Type: "string", Format: "uri"}, + "birthday": {Type: "string", Format: "date"}, + "created": {Type: "string", Format: "date-time"}, + }, + }, + }, + { + name: "string with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(1), MaxLength: ptr(100)}, + }, + }, + }, + { + name: "number with constraints", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "age": {Type: "integer", Minimum: ptr(0.0), Maximum: ptr(150.0)}, + "score": {Type: "number", Minimum: ptr(0.0), Maximum: ptr(100.0)}, + }, + }, + }, + { + name: "boolean with default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage("true")}, + }, + }, + }, + { + name: "string enum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + "pending", + }, + }, + }, + }, + }, + { + name: "enum with matching enumNames", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []interface{}{"High Priority", "Medium Priority", "Low Priority"}, + }, + }, + }, + }, + }, + } + + for _, tc := range validSchemas { + t.Run("valid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test valid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err != nil { + t.Errorf("expected no error for valid schema %q, got: %v", tc.name, err) + } + }) + } + + // Test invalid schemas - these should return errors + invalidSchemas := []struct { + name string + schema *jsonschema.Schema + expectedError string + }{ + { + name: "root schema non-object type", + schema: &jsonschema.Schema{ + Type: "string", + }, + expectedError: "elicit schema must be of type 'object', got \"string\"", + }, + { + name: "nested object property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "user": { + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string"}, + }, + }, + }, + }, + expectedError: "elicit schema property \"user\" contains nested properties, only primitive properties are allowed", + }, + { + name: "property with explicit object type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "config": {Type: "object"}, + }, + }, + expectedError: "elicit schema property \"config\" has unsupported type \"object\", only string, number, integer, and boolean are allowed", + }, + { + name: "array property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "tags": {Type: "array", Items: &jsonschema.Schema{Type: "string"}}, + }, + }, + expectedError: "elicit schema property \"tags\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "array without items", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "items": {Type: "array"}, + }, + }, + expectedError: "elicit schema property \"items\" has unsupported type \"array\", only string, number, integer, and boolean are allowed", + }, + { + name: "unsupported string format", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "phone": {Type: "string", Format: "phone"}, + }, + }, + expectedError: "elicit schema property \"phone\" has unsupported format \"phone\", only email, uri, date, and date-time are allowed", + }, + { + name: "unsupported type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {Type: "null"}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"null\"", + }, + { + name: "string with invalid minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(-1)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid minLength -1, must be non-negative", + }, + { + name: "string with invalid maxLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MaxLength: ptr(-5)}, + }, + }, + expectedError: "elicit schema property \"name\" has invalid maxLength -5, must be non-negative", + }, + { + name: "string with maxLength less than minLength", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "name": {Type: "string", MinLength: ptr(10), MaxLength: ptr(5)}, + }, + }, + expectedError: "elicit schema property \"name\" has maxLength 5 less than minLength 10", + }, + { + name: "number with maximum less than minimum", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "score": {Type: "number", Minimum: ptr(100.0), Maximum: ptr(50.0)}, + }, + }, + expectedError: "elicit schema property \"score\" has maximum 50 less than minimum 100", + }, + { + name: "boolean with invalid default", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "enabled": {Type: "boolean", Default: json.RawMessage(`"not-a-boolean"`)}, + }, + }, + expectedError: "elicit schema property \"enabled\" has invalid default value, must be a boolean", + }, + { + name: "enum with mismatched enumNames length", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Type: "string", + Enum: []any{ + "high", + "medium", + "low", + }, + Extra: map[string]any{ + "enumNames": []interface{}{"High Priority", "Medium Priority"}, // Only 2 names for 3 values + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has 3 enum values but 2 enumNames, they must match", + }, + { + name: "enum with invalid enumNames type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "status": { + Type: "string", + Enum: []any{ + "active", + "inactive", + }, + Extra: map[string]any{ + "enumNames": "not an array", // Should be array + }, + }, + }, + }, + expectedError: "elicit schema property \"status\" has invalid enumNames type, must be an array", + }, + { + name: "enum without explicit type", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "priority": { + Enum: []any{ + "high", + "medium", + "low", + }, + }, + }, + }, + expectedError: "elicit schema property \"priority\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + { + name: "untyped property", + schema: &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "data": {}, + }, + }, + expectedError: "elicit schema property \"data\" has unsupported type \"\", only string, number, integer, and boolean are allowed", + }, + } + + for _, tc := range invalidSchemas { + t.Run("invalid_"+tc.name, func(t *testing.T) { + _, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test invalid schema: " + tc.name, + RequestedSchema: tc.schema, + }) + if err == nil { + t.Errorf("expected error for invalid schema %q, got nil", tc.name) + return + } + if code := errorCode(err); code != CodeInvalidParams { + t.Errorf("got error code %d, want %d (CodeInvalidParams)", code, CodeInvalidParams) + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Errorf("error message %q does not contain expected text %q", err.Error(), tc.expectedError) + } + }) + } +} + +func TestElicitationProgressToken(t *testing.T) { + ctx := context.Background() + ct, st := NewInMemoryTransports() + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "accept"}, nil + }, + }) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + params := &ElicitParams{ + Message: "Test progress token", + Meta: Meta{}, + } + params.SetProgressToken("test-token") + + if token := params.GetProgressToken(); token != "test-token" { + t.Errorf("got progress token %v, want %q", token, "test-token") + } + + _, err = ss.Elicit(ctx, params) + if err != nil { + t.Fatal(err) + } +} + +func TestElicitationCapabilityDeclaration(t *testing.T) { + ctx := context.Background() + + t.Run("with_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client with ElicitationHandler should declare capability + c := NewClient(testImpl, &ClientOptions{ + ElicitationHandler: func(context.Context, *ElicitRequest) (*ElicitResult, error) { + return &ElicitResult{Action: "cancel"}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // The client should have declared elicitation capability during initialization + // We can verify this worked by successfully making an elicitation call + result, err := ss.Elicit(ctx, &ElicitParams{ + Message: "Test capability", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + if err != nil { + t.Errorf("elicitation should work when capability is declared, got error: %v", err) + } + if result.Action != "cancel" { + t.Errorf("got action %q, want %q", result.Action, "cancel") + } + }) + + t.Run("without_handler", func(t *testing.T) { + ct, st := NewInMemoryTransports() + + // Client without ElicitationHandler should not declare capability + c := NewClient(testImpl, &ClientOptions{ + CreateMessageHandler: func(context.Context, *CreateMessageRequest) (*CreateMessageResult, error) { + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil + }, + }) + + s := NewServer(testImpl, nil) + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + defer ss.Close() + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + defer cs.Close() + + // Elicitation should fail with UnsupportedMethod + _, err = ss.Elicit(ctx, &ElicitParams{ + Message: "This should fail", + RequestedSchema: &jsonschema.Schema{Type: "object"}, + }) + + if err == nil { + t.Error("expected UnsupportedMethod error when no capability declared") + } + if code := errorCode(err); code != CodeUnsupportedMethod { + t.Errorf("got error code %d, want %d (CodeUnsupportedMethod)", code, CodeUnsupportedMethod) + } + }) +} + func TestKeepAliveFailure(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -1185,3 +1720,8 @@ func TestPointerArgEquivalence(t *testing.T) { }) } } + +// ptr is a helper function to create pointers for schema constraints +func ptr[T any](v T) *T { + return &v +} diff --git a/mcp/protocol.go b/mcp/protocol.go index 75db7613..382f745f 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -968,7 +968,39 @@ func (*ResourceUpdatedNotificationParams) isParams() {} // TODO(jba): add CompleteRequest and related types. -// TODO(jba): add ElicitRequest and related types. +// A request from the server to elicit additional information from the user via the client. +type ElicitParams struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The message to present to the user. + Message string `json:"message"` + // A restricted subset of JSON Schema. + // Only top-level properties are allowed, without nesting. + RequestedSchema *jsonschema.Schema `json:"requestedSchema"` +} + +func (x *ElicitParams) isParams() {} + +func (x *ElicitParams) GetProgressToken() any { return getProgressToken(x) } +func (x *ElicitParams) SetProgressToken(t any) { setProgressToken(x, t) } + +// The client's response to an elicitation/create request from the server. +type ElicitResult struct { + // This property is reserved by the protocol to allow clients and servers to + // attach additional metadata to their responses. + Meta `json:"_meta,omitempty"` + // The user action in response to the elicitation. + // - "accept": User submitted the form/confirmed the action + // - "decline": User explicitly declined the action + // - "cancel": User dismissed without making an explicit choice + Action string `json:"action"` + // The submitted form data, only present when action is "accept". + // Contains values matching the requested schema. + Content map[string]any `json:"content,omitempty"` +} + +func (*ElicitResult) isResult() {} // An Implementation describes the name and version of an MCP implementation, with an optional // title for UI representation. diff --git a/mcp/requests.go b/mcp/requests.go index 46ff4f8d..c50ea99b 100644 --- a/mcp/requests.go +++ b/mcp/requests.go @@ -25,6 +25,7 @@ type ( type ( CreateMessageRequest = ClientRequest[*CreateMessageParams] + ElicitRequest = ClientRequest[*ElicitParams] InitializedClientRequest = ClientRequest[*InitializedParams] InitializeRequest = ClientRequest[*InitializeParams] ListRootsRequest = ClientRequest[*ListRootsParams] diff --git a/mcp/server.go b/mcp/server.go index f9ddcd66..68dbd591 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -834,6 +834,11 @@ func (ss *ServerSession) CreateMessage(ctx context.Context, params *CreateMessag return handleSend[*CreateMessageResult](ctx, methodCreateMessage, newServerRequest(ss, orZero[Params](params))) } +// Elicit sends an elicitation request to the client asking for user input. +func (ss *ServerSession) Elicit(ctx context.Context, params *ElicitParams) (*ElicitResult, error) { + return handleSend[*ElicitResult](ctx, methodElicit, newServerRequest(ss, orZero[Params](params))) +} + // Log sends a log message to the client. // The message is not sent if the client has not called SetLevel, or if its level // is below that of the last SetLevel.