From 3dd00744c4ad7825899b5e6eed37bfa5379f8666 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Wed, 20 Aug 2025 23:12:45 +0000 Subject: [PATCH 1/2] mcp: treat pointers equivalently to non-pointers when deriving schema As reported in #199 and #200, the fact that we return a possibly "null" schema for pointer types breaks various clients, which expect schemas to be of type "object". This is an unfortunate footgun. For now, assume that the user wants us to treat pointers equivalently to non-pointers. If we want to change this behavior in the future, we can do so behind an option. + a test Also fix the handling of nil results in the case where the output schema is non-nil: we must provide structured content in this case. (This was causing the test to fail). Fixes #199 Fixes #200 --- mcp/mcp_test.go | 101 ++++++++++++++++++++++++++++++++++++++++++++++++ mcp/server.go | 83 ++++++++++++++++++++++++++++----------- 2 files changed, 161 insertions(+), 23 deletions(-) diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 44dd76d2..446c7ba6 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) { } var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} + +// This test checks that when we use pointer types for tools, we get the same +// schema as when using the non-pointer types. It is too much of a footgun for +// there to be a difference (see #199 and #200). +// +// If anyone asks, we can add an option that controls how pointers are treated. +func TestPointerArgEquivalence(t *testing.T) { + type input struct { + In string + } + type output struct { + Out string + } + cs, _ := basicConnection(t, func(s *Server) { + // Add two equivalent tools, one of which operates in the 'pointer' realm, + // the other of which does not. + // + // We handle a few different types of results, to assert they behave the + // same in all cases. + AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) { + switch in.In { + case "": + return nil, nil, fmt.Errorf("must provide input") + case "nil": + return nil, nil, nil + case "empty": + return &CallToolResult{}, nil, nil + case "ok": + return &CallToolResult{}, &output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) { + switch in.In { + case "": + return nil, output{}, fmt.Errorf("must provide input") + case "nil": + return nil, output{}, nil + case "empty": + return &CallToolResult{}, output{}, nil + case "ok": + return &CallToolResult{}, output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + }) + defer cs.Close() + + ctx := context.Background() + tools, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + if got, want := len(tools.Tools), 2; got != want { + t.Fatalf("got %d tools, want %d", got, want) + } + t0 := tools.Tools[0] + t1 := tools.Tools[1] + + // First, check that the tool schemas don't differ. + if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { + t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { + t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + + // Then, check that we handle empty input equivalently. + for _, args := range []any{nil, struct{}{}} { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) + } + } + + // Then, check that we handle different types of output equivalently. + for _, in := range []string{"nil", "empty", "ok"} { + t.Run(in, func(t *testing.T) { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index b8e72907..4899d3f0 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -189,31 +189,26 @@ func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle // TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { - var err error tt := *t - tt.InputSchema = t.InputSchema - if tt.InputSchema == nil { - tt.InputSchema, err = jsonschema.For[In](nil) + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + zero any + outputResolved *jsonschema.Resolved + ) + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + zero, err = setSchema[Out](&t.OutputSchema, &outputResolved) if err != nil { - return nil, nil, fmt.Errorf("input schema: %w", err) - } - } - inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving input schema: %w", err) - } - - if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() { - tt.OutputSchema, err = jsonschema.For[Out](nil) - } - if err != nil { - return nil, nil, fmt.Errorf("output schema: %w", err) - } - var outputResolved *jsonschema.Resolved - if tt.OutputSchema != nil { - outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving output schema: %w", err) + return nil, nil, fmt.Errorf("output schema: %v", err) } } @@ -255,12 +250,54 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan res = &CallToolResult{} } res.StructuredContent = out + if zero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the non-zero + var z Out + if any(out) == any(z) { // bypass comparable check: pointers are comparable + res.StructuredContent = zero + } + } + if tt.OutputSchema != nil && zero != nil { + res.StructuredContent = zero + } return res, nil } return &tt, th, nil } +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { + rt := reflect.TypeFor[T]() + if *sfield == nil { + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + *sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + } + if err != nil { + return zero, err + } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + // AddTool adds a tool and handler to the server. // // A shallow copy of the tool is made first. From b8c81c633cbfcb2720d33a22d41ae5dce8cd7093 Mon Sep 17 00:00:00 2001 From: Rob Findley Date: Thu, 21 Aug 2025 14:25:55 +0000 Subject: [PATCH 2/2] address comments --- mcp/server.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mcp/server.go b/mcp/server.go index 4899d3f0..13ecb079 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -201,12 +201,12 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan // element type, use the zero value of its element type in place of a typed // nil. var ( - zero any + elemZero any // only non-nil if Out is a pointer type outputResolved *jsonschema.Resolved ) if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { var err error - zero, err = setSchema[Out](&t.OutputSchema, &outputResolved) + elemZero, err = setSchema[Out](&t.OutputSchema, &outputResolved) if err != nil { return nil, nil, fmt.Errorf("output schema: %v", err) } @@ -250,16 +250,16 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan res = &CallToolResult{} } res.StructuredContent = out - if zero != nil { + if elemZero != nil { // Avoid typed nil, which will serialize as JSON null. // Instead, use the zero value of the non-zero var z Out - if any(out) == any(z) { // bypass comparable check: pointers are comparable - res.StructuredContent = zero + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + res.StructuredContent = elemZero } } - if tt.OutputSchema != nil && zero != nil { - res.StructuredContent = zero + if tt.OutputSchema != nil && elemZero != nil { + res.StructuredContent = elemZero } return res, nil }