Skip to content

Commit 600ba61

Browse files
authored
mcp: fix CallToolResultFor[T] unmarshaling (#46)
Add an UnmarshalJSON method for that type. Add a test.
1 parent 057f525 commit 600ba61

File tree

2 files changed

+65
-0
lines changed

2 files changed

+65
-0
lines changed

mcp/protocol.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,25 @@ type CallToolResultFor[Out any] struct {
129129
StructuredContent Out `json:"structuredContent,omitempty"`
130130
}
131131

132+
// UnmarshalJSON handles the unmarshalling of content into the Content
133+
// interface.
134+
func (x *CallToolResultFor[Out]) UnmarshalJSON(data []byte) error {
135+
type res CallToolResultFor[Out] // avoid recursion
136+
var wire struct {
137+
res
138+
Content []*wireContent `json:"content"`
139+
}
140+
if err := json.Unmarshal(data, &wire); err != nil {
141+
return err
142+
}
143+
var err error
144+
if wire.res.Content, err = contentsFromWire(wire.Content, nil); err != nil {
145+
return err
146+
}
147+
*x = CallToolResultFor[Out](wire.res)
148+
return nil
149+
}
150+
132151
type CancelledParams struct {
133152
// This property is reserved by the protocol to allow clients and servers to
134153
// attach additional metadata to their responses.

mcp/protocol_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import (
88
"encoding/json"
99
"maps"
1010
"testing"
11+
12+
"github.com/google/go-cmp/cmp"
1113
)
1214

1315
func TestParamsMeta(t *testing.T) {
@@ -67,3 +69,47 @@ func TestParamsMeta(t *testing.T) {
6769
p.SetProgressToken(int32(1))
6870
p.SetProgressToken(int64(1))
6971
}
72+
73+
func TestContentUnmarshal(t *testing.T) {
74+
// Verify that types with a Content field round-trip properly.
75+
roundtrip := func(in, out any) {
76+
t.Helper()
77+
data, err := json.Marshal(in)
78+
if err != nil {
79+
t.Fatal(err)
80+
}
81+
if err := json.Unmarshal(data, out); err != nil {
82+
t.Fatal(err)
83+
}
84+
if diff := cmp.Diff(in, out); diff != "" {
85+
t.Errorf("mismatch (-want, +got):\n%s", diff)
86+
}
87+
}
88+
89+
content := []Content{&TextContent{Text: "t"}}
90+
91+
ctr := &CallToolResult{
92+
Meta: Meta{"m": true},
93+
Content: content,
94+
IsError: true,
95+
StructuredContent: map[string]any{"s": "x"},
96+
}
97+
var got CallToolResult
98+
roundtrip(ctr, &got)
99+
100+
ctrf := &CallToolResultFor[int]{
101+
Meta: Meta{"m": true},
102+
Content: content,
103+
IsError: true,
104+
StructuredContent: 3,
105+
}
106+
var gotf CallToolResultFor[int]
107+
roundtrip(ctrf, &gotf)
108+
109+
pm := &PromptMessage{
110+
Content: content[0],
111+
Role: "",
112+
}
113+
var gotpm PromptMessage
114+
roundtrip(pm, &gotpm)
115+
}

0 commit comments

Comments
 (0)