diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 44dd76d2..446c7ba6 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) { } var testImpl = &Implementation{Name: "test", Version: "v1.0.0"} + +// This test checks that when we use pointer types for tools, we get the same +// schema as when using the non-pointer types. It is too much of a footgun for +// there to be a difference (see #199 and #200). +// +// If anyone asks, we can add an option that controls how pointers are treated. +func TestPointerArgEquivalence(t *testing.T) { + type input struct { + In string + } + type output struct { + Out string + } + cs, _ := basicConnection(t, func(s *Server) { + // Add two equivalent tools, one of which operates in the 'pointer' realm, + // the other of which does not. + // + // We handle a few different types of results, to assert they behave the + // same in all cases. + AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) { + switch in.In { + case "": + return nil, nil, fmt.Errorf("must provide input") + case "nil": + return nil, nil, nil + case "empty": + return &CallToolResult{}, nil, nil + case "ok": + return &CallToolResult{}, &output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) { + switch in.In { + case "": + return nil, output{}, fmt.Errorf("must provide input") + case "nil": + return nil, output{}, nil + case "empty": + return &CallToolResult{}, output{}, nil + case "ok": + return &CallToolResult{}, output{Out: "foo"}, nil + default: + panic("unreachable") + } + }) + }) + defer cs.Close() + + ctx := context.Background() + tools, err := cs.ListTools(ctx, nil) + if err != nil { + t.Fatal(err) + } + if got, want := len(tools.Tools), 2; got != want { + t.Fatalf("got %d tools, want %d", got, want) + } + t0 := tools.Tools[0] + t1 := tools.Tools[1] + + // First, check that the tool schemas don't differ. + if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" { + t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" { + t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff) + } + + // Then, check that we handle empty input equivalently. + for _, args := range []any{nil, struct{}{}} { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff) + } + } + + // Then, check that we handle different types of output equivalently. + for _, in := range []string{"nil", "empty", "ok"} { + t.Run(in, func(t *testing.T) { + r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}}) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(r0, r1); diff != "" { + t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff) + } + }) + } +} diff --git a/mcp/server.go b/mcp/server.go index b8e72907..13ecb079 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -189,31 +189,26 @@ func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle // TODO(v0.3.0): test func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) { - var err error tt := *t - tt.InputSchema = t.InputSchema - if tt.InputSchema == nil { - tt.InputSchema, err = jsonschema.For[In](nil) + var inputResolved *jsonschema.Resolved + if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil { + return nil, nil, fmt.Errorf("input schema: %w", err) + } + + // Handling for zero values: + // + // If Out is a pointer type and we've derived the output schema from its + // element type, use the zero value of its element type in place of a typed + // nil. + var ( + elemZero any // only non-nil if Out is a pointer type + outputResolved *jsonschema.Resolved + ) + if reflect.TypeFor[Out]() != reflect.TypeFor[any]() { + var err error + elemZero, err = setSchema[Out](&t.OutputSchema, &outputResolved) if err != nil { - return nil, nil, fmt.Errorf("input schema: %w", err) - } - } - inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving input schema: %w", err) - } - - if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() { - tt.OutputSchema, err = jsonschema.For[Out](nil) - } - if err != nil { - return nil, nil, fmt.Errorf("output schema: %w", err) - } - var outputResolved *jsonschema.Resolved - if tt.OutputSchema != nil { - outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) - if err != nil { - return nil, nil, fmt.Errorf("resolving output schema: %w", err) + return nil, nil, fmt.Errorf("output schema: %v", err) } } @@ -255,12 +250,54 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan res = &CallToolResult{} } res.StructuredContent = out + if elemZero != nil { + // Avoid typed nil, which will serialize as JSON null. + // Instead, use the zero value of the non-zero + var z Out + if any(out) == any(z) { // zero is only non-nil if Out is a pointer type + res.StructuredContent = elemZero + } + } + if tt.OutputSchema != nil && elemZero != nil { + res.StructuredContent = elemZero + } return res, nil } return &tt, th, nil } +// setSchema sets the schema and resolved schema corresponding to the type T. +// +// If sfield is nil, the schema is derived from T. +// +// Pointers are treated equivalently to non-pointers when deriving the schema. +// If an indirection occurred to derive the schema, a non-nil zero value is +// returned to be used in place of the typed nil zero value. +// +// Note that if sfield already holds a schema, zero will be nil even if T is a +// pointer: if the user provided the schema, they may have intentionally +// derived it from the pointer type, and handling of zero values is up to them. +// +// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we +// should have a jsonschema.Zero(schema) helper? +func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) { + rt := reflect.TypeFor[T]() + if *sfield == nil { + if rt.Kind() == reflect.Pointer { + rt = rt.Elem() + zero = reflect.Zero(rt).Interface() + } + // TODO: we should be able to pass nil opts here. + *sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{}) + } + if err != nil { + return zero, err + } + *rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true}) + return zero, err +} + // AddTool adds a tool and handler to the server. // // A shallow copy of the tool is made first.