Skip to content

Commit 0a57fc4

Browse files
committed
mcp: treat pointers equivalently to non-pointers when deriving schema
As reported in #199 and #200, the fact that we return a possibly "null" schema for pointer types breaks various clients, which expect schemas to be of type "object". This is an unfortunate footgun. For now, assume that the user wants us to treat pointers equivalently to non-pointers. If we want to change this behavior in the future, we can do so behind an option. + a test Also fix the handling of nil results in the case where the output schema is non-nil: we must provide structured content in this case. (This was causing the test to fail). Fixes #199 Fixes #200
1 parent 79f063b commit 0a57fc4

File tree

2 files changed

+147
-8
lines changed

2 files changed

+147
-8
lines changed

mcp/mcp_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) {
10841084
}
10851085

10861086
var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}
1087+
1088+
// This test checks that when we use pointer types for tools, we get the same
1089+
// schema as when using the non-pointer types. It is too much of a footgun for
1090+
// there to be a difference (see #199 and #200).
1091+
//
1092+
// If anyone asks, we can add an option that controls how pointers are treated.
1093+
func TestPointerArgEquivalence(t *testing.T) {
1094+
type input struct {
1095+
In string
1096+
}
1097+
type output struct {
1098+
Out string
1099+
}
1100+
cs, _ := basicConnection(t, func(s *Server) {
1101+
// Add two equivalent tools, one of which operates in the 'pointer' realm,
1102+
// the other of which does not.
1103+
//
1104+
// We handle a few different types of results, to assert they behave the
1105+
// same in all cases.
1106+
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParamsFor[*input]]) (*CallToolResultFor[*output], error) {
1107+
switch req.Params.Arguments.In {
1108+
case "":
1109+
return nil, fmt.Errorf("must provide input")
1110+
case "nil":
1111+
return nil, nil
1112+
case "empty":
1113+
return &CallToolResultFor[*output]{}, nil
1114+
case "ok":
1115+
return &CallToolResultFor[*output]{StructuredContent: &output{Out: "foo"}}, nil
1116+
default:
1117+
panic("unreachable")
1118+
}
1119+
})
1120+
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParamsFor[input]]) (*CallToolResultFor[output], error) {
1121+
switch req.Params.Arguments.In {
1122+
case "":
1123+
return nil, fmt.Errorf("must provide input")
1124+
case "nil":
1125+
return nil, nil
1126+
case "empty":
1127+
return &CallToolResultFor[output]{}, nil
1128+
case "ok":
1129+
return &CallToolResultFor[output]{StructuredContent: output{Out: "foo"}}, nil
1130+
default:
1131+
panic("unreachable")
1132+
}
1133+
})
1134+
})
1135+
defer cs.Close()
1136+
1137+
ctx := context.Background()
1138+
tools, err := cs.ListTools(ctx, nil)
1139+
if err != nil {
1140+
t.Fatal(err)
1141+
}
1142+
if got, want := len(tools.Tools), 2; got != want {
1143+
t.Fatalf("got %d tools, want %d", got, want)
1144+
}
1145+
t0 := tools.Tools[0]
1146+
t1 := tools.Tools[1]
1147+
1148+
// First, check that the tool schemas don't differ.
1149+
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
1150+
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1151+
}
1152+
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
1153+
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
1154+
}
1155+
1156+
// Then, check that we handle empty input equivalently.
1157+
for _, args := range []any{nil, struct{}{}} {
1158+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
1159+
if err != nil {
1160+
t.Fatal(err)
1161+
}
1162+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
1163+
if err != nil {
1164+
t.Fatal(err)
1165+
}
1166+
if diff := cmp.Diff(r0, r1); diff != "" {
1167+
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
1168+
}
1169+
}
1170+
1171+
// Then, check that we handle different types of output equivalently.
1172+
for _, in := range []string{"nil", "empty", "ok"} {
1173+
t.Run(in, func(t *testing.T) {
1174+
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
1175+
if err != nil {
1176+
t.Fatal(err)
1177+
}
1178+
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
1179+
if err != nil {
1180+
t.Fatal(err)
1181+
}
1182+
if diff := cmp.Diff(r0, r1); diff != "" {
1183+
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
1184+
}
1185+
})
1186+
}
1187+
}

mcp/tool.go

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,23 @@ type serverTool struct {
4242
func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool, error) {
4343
st := &serverTool{tool: t}
4444

45-
if err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil {
45+
if _, _, err := setSchema[In](&t.InputSchema, &st.inputResolved); err != nil {
4646
return nil, err
4747
}
48+
// Handling for zero values:
49+
//
50+
// In some cases (for example, if the response is nil yet we've specified an
51+
// output schema), we must use a synthetic zero value.
52+
//
53+
// If pointer is set, an indirection occured when deriving the schema.
54+
var (
55+
pointer bool
56+
zero any
57+
)
4858
if reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
49-
if err := setSchema[Out](&t.OutputSchema, &st.outputResolved); err != nil {
59+
var err error
60+
pointer, zero, err = setSchema[Out](&t.OutputSchema, &st.outputResolved)
61+
if err != nil {
5062
return nil, err
5163
}
5264
}
@@ -88,30 +100,56 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
88100
}
89101
var ctr CallToolResult
90102
// TODO(jba): What if res == nil? Is that valid?
91-
// TODO(jba): if t.OutputSchema != nil, check that StructuredContent is present and validates.
103+
// TODO(jba): if t.OutputSchema != nil, check that StructuredContent validates.
92104
if res != nil {
93105
// TODO(jba): future-proof this copy.
94106
ctr.Meta = res.Meta
95107
ctr.Content = res.Content
96108
ctr.IsError = res.IsError
97109
ctr.StructuredContent = res.StructuredContent
110+
if pointer {
111+
// Avoid typed nil, which will serialize as JSON null.
112+
var z Out
113+
if any(res.StructuredContent) == any(z) { // bypass comparable check: pointers are comparable
114+
ctr.StructuredContent = zero
115+
}
116+
}
117+
} else if t.OutputSchema != nil {
118+
// StructuredContent must be present.
119+
ctr.StructuredContent = zero
98120
}
99121
return &ctr, nil
100122
}
101123

102124
return st, nil
103125
}
104126

105-
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error {
106-
var err error
127+
// setSchema sets the schema and resolved schema corresponding to the type T.
128+
//
129+
// If sfield is nil, the schema is derived from T.
130+
//
131+
// Pointers are treated equivalently to non-pointers when deriving the schema.
132+
// The pointer result whether an indirection occurs to derive the schema.
133+
//
134+
// Note that if sfield is non-nil, pointer is false even if T is a pointer: if
135+
// the user provided the schema, they may have intentionally derived it from
136+
// the pointer type.
137+
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (pointer bool, zero any, err error) {
138+
rt := reflect.TypeFor[T]()
107139
if *sfield == nil {
108-
*sfield, err = jsonschema.For[T](nil)
140+
if rt.Kind() == reflect.Pointer {
141+
pointer = true
142+
rt = rt.Elem()
143+
}
144+
// TODO: we should be able to pass nil opts here.
145+
*sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
109146
}
147+
zero = reflect.Zero(rt).Interface()
110148
if err != nil {
111-
return err
149+
return pointer, zero, err
112150
}
113151
*rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
114-
return err
152+
return pointer, zero, err
115153
}
116154

117155
// unmarshalSchema unmarshals data into v and validates the result according to

0 commit comments

Comments
 (0)