diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 3c222eb4..4c728107 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -125,7 +125,7 @@ func TestEndToEnd(t *testing.T) { loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging opts := &ClientOptions{ CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) { - return &CreateMessageResult{Model: "aModel"}, nil + return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil }, ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 }, PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 }, diff --git a/mcp/protocol.go b/mcp/protocol.go index 00dcd14d..3ca6cb5e 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -264,6 +264,23 @@ type CreateMessageResult struct { StopReason string `json:"stopReason,omitempty"` } +func (r *CreateMessageResult) UnmarshalJSON(data []byte) error { + type result CreateMessageResult // avoid recursion + var wire struct { + result + Content *wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil { + return err + } + *r = CreateMessageResult(wire.result) + return nil +} + type GetPromptParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses.