diff --git a/mcp/protocol.go b/mcp/protocol.go index 439d07a0..6b6f4790 100644 --- a/mcp/protocol.go +++ b/mcp/protocol.go @@ -129,6 +129,25 @@ type CallToolResultFor[Out any] struct { StructuredContent Out `json:"structuredContent,omitempty"` } +// UnmarshalJSON handles the unmarshalling of content into the Content +// interface. +func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error { + type res CallToolResultFor[Out] // avoid recursion + var wire struct { + res + Content []*wireContent `json:"content"` + } + if err := json.Unmarshal(data, &wire); err != nil { + return err + } + var err error + if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil { + return err + } + *x = CallToolResultFor[Out](wire.res) + return nil +} + type CancelledParams struct { // This property is reserved by the protocol to allow clients and servers to // attach additional metadata to their responses. diff --git a/mcp/protocol_test.go b/mcp/protocol_test.go index d68fb738..823c409c 100644 --- a/mcp/protocol_test.go +++ b/mcp/protocol_test.go @@ -8,6 +8,8 @@ import ( "encoding/json" "maps" "testing" + + "github.com/google/go-cmp/cmp" ) func TestParamsMeta(t *testing.T) { @@ -67,3 +69,47 @@ func TestParamsMeta(t *testing.T) { p.SetProgressToken(int32(1)) p.SetProgressToken(int64(1)) } + +func TestContentUnmarshal(t *testing.T) { + // Verify that types with a Content field round-trip properly. + roundtrip := func(in, out any) { + t.Helper() + data, err := json.Marshal(in) + if err != nil { + t.Fatal(err) + } + if err := json.Unmarshal(data, out); err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(in, out); diff != "" { + t.Errorf("mismatch (-want, +got):\n%s", diff) + } + } + + content := []Content{&TextContent{Text: "t"}} + + ctr := &CallToolResult{ + Meta: Meta{"m": true}, + Content: content, + IsError: true, + StructuredContent: map[string]any{"s": "x"}, + } + var got CallToolResult + roundtrip(ctr, &got) + + ctrf := &CallToolResultFor[int]{ + Meta: Meta{"m": true}, + Content: content, + IsError: true, + StructuredContent: 3, + } + var gotf CallToolResultFor[int] + roundtrip(ctrf, &gotf) + + pm := &PromptMessage{ + Content: content[0], + Role: "", + } + var gotpm PromptMessage + roundtrip(pm, &gotpm) +}