Skip to content

Commit 4434100

Browse files
committed
fix: prevent null 'required' in MCP tool schemas breaking OpenAI API
OpenAI strictly validates JSON Schema and rejects 'required': null (expects an array). When MCP tools had no required fields, the nil []string serialized as null. Initialize required as []string{} and sanitize nested schemas recursively to remove null/invalid required fields from MCP server responses.
1 parent 71bdc76 commit 4434100

File tree

2 files changed

+116
-222
lines changed

2 files changed

+116
-222
lines changed

internal/tools/mcp.go

Lines changed: 18 additions & 222 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,9 @@ import (
66
"fmt"
77
"slices"
88
"strings"
9-
"time"
109

1110
"charm.land/fantasy"
12-
"github.com/mark3labs/mcp-go/client"
13-
"github.com/mark3labs/mcp-go/client/transport"
1411
"github.com/mark3labs/mcp-go/mcp"
15-
"github.com/mark3labs/mcphost/internal/builtin"
1612
"github.com/mark3labs/mcphost/internal/config"
1713
)
1814

@@ -68,75 +64,6 @@ func (m *MCPToolManager) SetDebugLogger(logger DebugLogger) {
6864
}
6965
}
7066

71-
// samplingHandler implements the MCP sampling handler interface using a fantasy LanguageModel
72-
type samplingHandler struct {
73-
model fantasy.LanguageModel
74-
}
75-
76-
// CreateMessage handles sampling requests from MCP servers by forwarding them to the configured LLM model.
77-
// It converts MCP message formats to fantasy message formats, invokes the model for generation,
78-
// and converts the response back to MCP format. Returns an error if no model is available
79-
// or if generation fails.
80-
func (h *samplingHandler) CreateMessage(ctx context.Context, request mcp.CreateMessageRequest) (*mcp.CreateMessageResult, error) {
81-
if h.model == nil {
82-
return nil, fmt.Errorf("no model available for sampling")
83-
}
84-
85-
// Build fantasy messages from MCP sampling request
86-
var messages []fantasy.Message
87-
88-
// Add system message if provided
89-
if request.SystemPrompt != "" {
90-
messages = append(messages, fantasy.NewSystemMessage(request.SystemPrompt))
91-
}
92-
93-
// Convert sampling messages
94-
for _, msg := range request.Messages {
95-
var content string
96-
if textContent, ok := msg.Content.(mcp.TextContent); ok {
97-
content = textContent.Text
98-
} else {
99-
content = fmt.Sprintf("%v", msg.Content)
100-
}
101-
102-
switch msg.Role {
103-
case mcp.RoleUser:
104-
messages = append(messages, fantasy.NewUserMessage(content))
105-
case mcp.RoleAssistant:
106-
messages = append(messages, fantasy.Message{
107-
Role: fantasy.MessageRoleAssistant,
108-
Content: []fantasy.MessagePart{fantasy.TextPart{Text: content}},
109-
})
110-
default:
111-
messages = append(messages, fantasy.NewUserMessage(content))
112-
}
113-
}
114-
115-
// Generate response using the fantasy model
116-
call := fantasy.Call{
117-
Prompt: fantasy.Prompt(messages),
118-
}
119-
response, err := h.model.Generate(ctx, call)
120-
if err != nil {
121-
return nil, fmt.Errorf("model generation failed: %w", err)
122-
}
123-
124-
// Convert response back to MCP format
125-
result := &mcp.CreateMessageResult{
126-
Model: h.model.Model(),
127-
StopReason: "endTurn",
128-
}
129-
result.SamplingMessage = mcp.SamplingMessage{
130-
Role: mcp.RoleAssistant,
131-
Content: mcp.TextContent{
132-
Type: "text",
133-
Text: response.Content.Text(),
134-
},
135-
}
136-
137-
return result, nil
138-
}
139-
14067
// LoadTools loads tools from all configured MCP servers based on the provided configuration.
14168
// It initializes the connection pool, connects to each configured server, and loads their tools.
14269
// Tools from different servers are prefixed with the server name to avoid naming conflicts.
@@ -229,16 +156,15 @@ func (m *MCPToolManager) loadServerTools(ctx context.Context, serverName string,
229156

230157
// Extract properties and required from the schema
231158
parameters := make(map[string]any)
232-
var required []string
159+
required := []string{}
233160

234161
if props, ok := schemaMap["properties"].(map[string]any); ok {
235162
parameters = props
236163
}
237164

238-
// Fix for issue #89: Ensure object schemas have a properties field
239-
if schemaType, ok := schemaMap["type"].(string); ok && schemaType == "object" && len(parameters) == 0 {
240-
// Keep empty parameters map - fantasy handles this fine
241-
}
165+
// Fix for issue #89: Ensure object schemas have a properties field.
166+
// When schema type is "object" with no properties, we keep the
167+
// empty parameters map — fantasy handles this fine.
242168

243169
if req, ok := schemaMap["required"].([]any); ok {
244170
for _, r := range req {
@@ -312,148 +238,6 @@ func (m *MCPToolManager) shouldExcludeTool(toolName string, serverConfig config.
312238
return false
313239
}
314240

315-
func (m *MCPToolManager) createMCPClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
316-
transportType := serverConfig.GetTransportType()
317-
318-
switch transportType {
319-
case "stdio":
320-
var env []string
321-
var command string
322-
var args []string
323-
324-
if len(serverConfig.Command) > 0 {
325-
command = serverConfig.Command[0]
326-
if len(serverConfig.Command) > 1 {
327-
args = serverConfig.Command[1:]
328-
} else if len(serverConfig.Args) > 0 {
329-
args = serverConfig.Args
330-
}
331-
}
332-
333-
if serverConfig.Environment != nil {
334-
for k, v := range serverConfig.Environment {
335-
env = append(env, fmt.Sprintf("%s=%s", k, v))
336-
}
337-
}
338-
339-
if serverConfig.Env != nil {
340-
for k, v := range serverConfig.Env {
341-
env = append(env, fmt.Sprintf("%s=%v", k, v))
342-
}
343-
}
344-
345-
stdioTransport := transport.NewStdio(command, env, args...)
346-
stdioClient := client.NewClient(stdioTransport)
347-
348-
if err := stdioTransport.Start(ctx); err != nil {
349-
return nil, fmt.Errorf("failed to start stdio transport: %v", err)
350-
}
351-
352-
time.Sleep(100 * time.Millisecond)
353-
return stdioClient, nil
354-
355-
case "sse":
356-
var options []transport.ClientOption
357-
358-
if len(serverConfig.Headers) > 0 {
359-
headers := make(map[string]string)
360-
for _, header := range serverConfig.Headers {
361-
parts := strings.SplitN(header, ":", 2)
362-
if len(parts) == 2 {
363-
key := strings.TrimSpace(parts[0])
364-
value := strings.TrimSpace(parts[1])
365-
headers[key] = value
366-
}
367-
}
368-
if len(headers) > 0 {
369-
options = append(options, transport.WithHeaders(headers))
370-
}
371-
}
372-
373-
sseClient, err := client.NewSSEMCPClient(serverConfig.URL, options...)
374-
if err != nil {
375-
return nil, err
376-
}
377-
378-
if err := sseClient.Start(ctx); err != nil {
379-
return nil, fmt.Errorf("failed to start SSE client: %v", err)
380-
}
381-
382-
return sseClient, nil
383-
384-
case "streamable":
385-
var options []transport.StreamableHTTPCOption
386-
387-
if len(serverConfig.Headers) > 0 {
388-
headers := make(map[string]string)
389-
for _, header := range serverConfig.Headers {
390-
parts := strings.SplitN(header, ":", 2)
391-
if len(parts) == 2 {
392-
key := strings.TrimSpace(parts[0])
393-
value := strings.TrimSpace(parts[1])
394-
headers[key] = value
395-
}
396-
}
397-
if len(headers) > 0 {
398-
options = append(options, transport.WithHTTPHeaders(headers))
399-
}
400-
}
401-
402-
streamableClient, err := client.NewStreamableHttpClient(serverConfig.URL, options...)
403-
if err != nil {
404-
return nil, err
405-
}
406-
407-
if err := streamableClient.Start(ctx); err != nil {
408-
return nil, fmt.Errorf("failed to start streamable HTTP client: %v", err)
409-
}
410-
411-
return streamableClient, nil
412-
413-
case "inprocess":
414-
return m.createBuiltinClient(ctx, serverName, serverConfig)
415-
416-
default:
417-
return nil, fmt.Errorf("unsupported transport type '%s' for server %s", transportType, serverName)
418-
}
419-
}
420-
421-
func (m *MCPToolManager) initializeClient(ctx context.Context, client client.MCPClient) error {
422-
initCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
423-
defer cancel()
424-
425-
initRequest := mcp.InitializeRequest{}
426-
initRequest.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
427-
initRequest.Params.ClientInfo = mcp.Implementation{
428-
Name: "mcphost",
429-
Version: "1.0.0",
430-
}
431-
initRequest.Params.Capabilities = mcp.ClientCapabilities{}
432-
433-
_, err := client.Initialize(initCtx, initRequest)
434-
if err != nil {
435-
return fmt.Errorf("initialization timeout or failed: %v", err)
436-
}
437-
return nil
438-
}
439-
440-
// createBuiltinClient creates an in-process MCP client for builtin servers
441-
func (m *MCPToolManager) createBuiltinClient(ctx context.Context, serverName string, serverConfig config.MCPServerConfig) (client.MCPClient, error) {
442-
registry := builtin.NewRegistry()
443-
444-
builtinServer, err := registry.CreateServer(serverConfig.Name, serverConfig.Options, m.model)
445-
if err != nil {
446-
return nil, fmt.Errorf("failed to create builtin server: %v", err)
447-
}
448-
449-
inProcessClient, err := client.NewInProcessClient(builtinServer.GetServer())
450-
if err != nil {
451-
return nil, fmt.Errorf("failed to create in-process client: %v", err)
452-
}
453-
454-
return inProcessClient, nil
455-
}
456-
457241
// debugLogConnectionInfo logs detailed connection information for debugging
458242
func (m *MCPToolManager) debugLogConnectionInfo(serverName string, serverConfig config.MCPServerConfig) {
459243
if m.debugLogger == nil || !m.debugLogger.IsDebugEnabled() {
@@ -497,8 +281,9 @@ func convertExclusiveBoundsToBoolean(schemaJSON []byte) []byte {
497281
return result
498282
}
499283

500-
// convertSchemaRecursive recursively processes a schema map and converts
501-
// numeric exclusiveMinimum/exclusiveMaximum to boolean format.
284+
// convertSchemaRecursive recursively processes a schema map to:
285+
// - Convert numeric exclusiveMinimum/exclusiveMaximum to boolean format (draft-07 → draft-04)
286+
// - Remove null "required" fields that cause OpenAI API validation errors
502287
func convertSchemaRecursive(schema map[string]any) {
503288
if exMin, ok := schema["exclusiveMinimum"]; ok {
504289
if num, isNum := exMin.(float64); isNum {
@@ -514,6 +299,17 @@ func convertSchemaRecursive(schema map[string]any) {
514299
}
515300
}
516301

302+
// Fix null "required" fields — OpenAI rejects "required": null,
303+
// it must be an array or absent entirely.
304+
if req, exists := schema["required"]; exists {
305+
if req == nil {
306+
delete(schema, "required")
307+
} else if _, isArr := req.([]any); !isArr {
308+
// Not an array — remove invalid value
309+
delete(schema, "required")
310+
}
311+
}
312+
517313
if props, ok := schema["properties"].(map[string]any); ok {
518314
for _, prop := range props {
519315
if propSchema, ok := prop.(map[string]any); ok {

internal/tools/mcp_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,104 @@ func TestConvertExclusiveBoundsToBoolean(t *testing.T) {
302302
}
303303
}
304304

305+
// TestNullRequiredFieldSanitization tests that null "required" fields are removed
306+
// from schemas to prevent OpenAI API validation errors like:
307+
// "None is not of type 'array'"
308+
func TestNullRequiredFieldSanitization(t *testing.T) {
309+
tests := []struct {
310+
name string
311+
input string
312+
wantKey bool // should "required" key exist in output?
313+
wantJSON string // expected JSON output (if checking more than key presence)
314+
}{
315+
{
316+
name: "null required is removed",
317+
input: `{"type": "object", "properties": {"name": {"type": "string"}}, "required": null}`,
318+
wantKey: false,
319+
},
320+
{
321+
name: "valid required array is preserved",
322+
input: `{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}`,
323+
wantKey: true,
324+
},
325+
{
326+
name: "empty required array is preserved",
327+
input: `{"type": "object", "properties": {}, "required": []}`,
328+
wantKey: true,
329+
},
330+
{
331+
name: "nested null required is removed",
332+
input: `{"type": "object", "properties": {"config": {"type": "object", "properties": {"key": {"type": "string"}}, "required": null}}}`,
333+
wantKey: false,
334+
},
335+
{
336+
name: "required as wrong type (string) is removed",
337+
input: `{"type": "object", "properties": {}, "required": "name"}`,
338+
wantKey: false,
339+
},
340+
}
341+
342+
for _, tt := range tests {
343+
t.Run(tt.name, func(t *testing.T) {
344+
result := convertExclusiveBoundsToBoolean([]byte(tt.input))
345+
346+
var got map[string]any
347+
if err := json.Unmarshal(result, &got); err != nil {
348+
t.Fatalf("Failed to unmarshal result: %v", err)
349+
}
350+
351+
// Check top-level or nested required field
352+
checkSchema := got
353+
if nested, ok := got["properties"].(map[string]any); ok {
354+
if cfg, ok := nested["config"].(map[string]any); ok {
355+
checkSchema = cfg
356+
}
357+
}
358+
359+
_, hasRequired := checkSchema["required"]
360+
if hasRequired != tt.wantKey {
361+
t.Errorf("required key present = %v, want %v. Schema: %s", hasRequired, tt.wantKey, string(result))
362+
}
363+
})
364+
}
365+
}
366+
367+
// TestToolInfoRequiredNeverNull verifies that MCP tool conversion always produces
368+
// a non-nil Required slice, preventing "required": null in serialized JSON.
369+
func TestToolInfoRequiredNeverNull(t *testing.T) {
370+
// Simulate the schema extraction logic from loadServerTools
371+
schemaJSON := `{"type": "object", "properties": {"name": {"type": "string"}}}`
372+
373+
var schemaMap map[string]any
374+
if err := json.Unmarshal([]byte(schemaJSON), &schemaMap); err != nil {
375+
t.Fatalf("Failed to unmarshal schema: %v", err)
376+
}
377+
378+
// This mirrors the code in loadServerTools
379+
required := []string{}
380+
if req, ok := schemaMap["required"].([]any); ok {
381+
for _, r := range req {
382+
if s, ok := r.(string); ok {
383+
required = append(required, s)
384+
}
385+
}
386+
}
387+
388+
// required must never be nil
389+
if required == nil {
390+
t.Fatal("required is nil — would serialize as \"required\": null")
391+
}
392+
393+
// Verify JSON serialization
394+
data, err := json.Marshal(map[string]any{"required": required})
395+
if err != nil {
396+
t.Fatalf("Failed to marshal: %v", err)
397+
}
398+
if string(data) != `{"required":[]}` {
399+
t.Errorf("Expected {\"required\":[]}, got %s", string(data))
400+
}
401+
}
402+
305403
// TestConvertExclusiveBoundsToBoolean_InvalidJSON tests that invalid JSON is returned unchanged
306404
func TestConvertExclusiveBoundsToBoolean_InvalidJSON(t *testing.T) {
307405
invalidJSON := []byte(`{invalid json}`)

0 commit comments

Comments
 (0)