Skip to content

Commit bacca7a

Browse files
authored
Implement json.Unmarshaler to CreateMessageResult (#191)
When calling `ServerSession.CreateMessage`, will get an json unmarshal error because `CreateMessageResult.Content` is an interface, and it doesn't implement `json.Unmarshaller`. Implement the `json.Unmarshaller` to `CreateMessageResult`, let it able to unmarshal the client response without error
1 parent 982a0bc commit bacca7a

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

mcp/mcp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ func TestEndToEnd(t *testing.T) {
125125
loggingMessages := make(chan *LoggingMessageParams, 100) // big enough for all logging
126126
opts := &ClientOptions{
127127
CreateMessageHandler: func(context.Context, *ClientSession, *CreateMessageParams) (*CreateMessageResult, error) {
128-
return &CreateMessageResult{Model: "aModel"}, nil
128+
return &CreateMessageResult{Model: "aModel", Content: &TextContent{}}, nil
129129
},
130130
ToolListChangedHandler: func(context.Context, *ClientSession, *ToolListChangedParams) { notificationChans["tools"] <- 0 },
131131
PromptListChangedHandler: func(context.Context, *ClientSession, *PromptListChangedParams) { notificationChans["prompts"] <- 0 },

mcp/protocol.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,23 @@ type CreateMessageResult struct {
264264
StopReason string `json:"stopReason,omitempty"`
265265
}
266266

267+
func (r *CreateMessageResult) UnmarshalJSON(data []byte) error {
268+
type result CreateMessageResult // avoid recursion
269+
var wire struct {
270+
result
271+
Content *wireContent `json:"content"`
272+
}
273+
if err := json.Unmarshal(data, &wire); err != nil {
274+
return err
275+
}
276+
var err error
277+
if wire.result.Content, err = contentFromWire(wire.Content, map[string]bool{"text": true, "image": true, "audio": true}); err != nil {
278+
return err
279+
}
280+
*r = CreateMessageResult(wire.result)
281+
return nil
282+
}
283+
267284
type GetPromptParams struct {
268285
// This property is reserved by the protocol to allow clients and servers to
269286
// attach additional metadata to their responses.

0 commit comments

Comments
 (0)