Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions mcp/mcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1084,3 +1084,104 @@ func TestNoDistributedDeadlock(t *testing.T) {
}

var testImpl = &Implementation{Name: "test", Version: "v1.0.0"}

// This test checks that when we use pointer types for tools, we get the same
// schema as when using the non-pointer types. It is too much of a footgun for
// there to be a difference (see #199 and #200).
//
// If anyone asks, we can add an option that controls how pointers are treated.
func TestPointerArgEquivalence(t *testing.T) {
type input struct {
In string
}
type output struct {
Out string
}
cs, _ := basicConnection(t, func(s *Server) {
// Add two equivalent tools, one of which operates in the 'pointer' realm,
// the other of which does not.
//
// We handle a few different types of results, to assert they behave the
// same in all cases.
AddTool(s, &Tool{Name: "pointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in *input) (*CallToolResult, *output, error) {
switch in.In {
case "":
return nil, nil, fmt.Errorf("must provide input")
case "nil":
return nil, nil, nil
case "empty":
return &CallToolResult{}, nil, nil
case "ok":
return &CallToolResult{}, &output{Out: "foo"}, nil
default:
panic("unreachable")
}
})
AddTool(s, &Tool{Name: "nonpointer"}, func(_ context.Context, req *ServerRequest[*CallToolParams], in input) (*CallToolResult, output, error) {
switch in.In {
case "":
return nil, output{}, fmt.Errorf("must provide input")
case "nil":
return nil, output{}, nil
case "empty":
return &CallToolResult{}, output{}, nil
case "ok":
return &CallToolResult{}, output{Out: "foo"}, nil
default:
panic("unreachable")
}
})
})
defer cs.Close()

ctx := context.Background()
tools, err := cs.ListTools(ctx, nil)
if err != nil {
t.Fatal(err)
}
if got, want := len(tools.Tools), 2; got != want {
t.Fatalf("got %d tools, want %d", got, want)
}
t0 := tools.Tools[0]
t1 := tools.Tools[1]

// First, check that the tool schemas don't differ.
if diff := cmp.Diff(t0.InputSchema, t1.InputSchema); diff != "" {
t.Errorf("input schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}
if diff := cmp.Diff(t0.OutputSchema, t1.OutputSchema); diff != "" {
t.Errorf("output schemas do not match (-%s +%s):\n%s", t0.Name, t1.Name, diff)
}

// Then, check that we handle empty input equivalently.
for _, args := range []any{nil, struct{}{}} {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: args})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: args})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1); diff != "" {
t.Errorf("CallTool(%v) with no arguments mismatch (-%s +%s):\n%s", args, t0.Name, t1.Name, diff)
}
}

// Then, check that we handle different types of output equivalently.
for _, in := range []string{"nil", "empty", "ok"} {
t.Run(in, func(t *testing.T) {
r0, err := cs.CallTool(ctx, &CallToolParams{Name: t0.Name, Arguments: input{In: in}})
if err != nil {
t.Fatal(err)
}
r1, err := cs.CallTool(ctx, &CallToolParams{Name: t1.Name, Arguments: input{In: in}})
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(r0, r1); diff != "" {
t.Errorf("CallTool({\"In\": %q}) mismatch (-%s +%s):\n%s", in, t0.Name, t1.Name, diff)
}
})
}
}
83 changes: 60 additions & 23 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,31 +189,26 @@ func toolFor[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandle

// TODO(v0.3.0): test
func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHandler, error) {
var err error
tt := *t
tt.InputSchema = t.InputSchema
if tt.InputSchema == nil {
tt.InputSchema, err = jsonschema.For[In](nil)
var inputResolved *jsonschema.Resolved
if _, err := setSchema[In](&tt.InputSchema, &inputResolved); err != nil {
return nil, nil, fmt.Errorf("input schema: %w", err)
}

// Handling for zero values:
//
// If Out is a pointer type and we've derived the output schema from its
// element type, use the zero value of its element type in place of a typed
// nil.
var (
zero any
outputResolved *jsonschema.Resolved
)
if reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
var err error
zero, err = setSchema[Out](&t.OutputSchema, &outputResolved)
if err != nil {
return nil, nil, fmt.Errorf("input schema: %w", err)
}
}
inputResolved, err := tt.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
if err != nil {
return nil, nil, fmt.Errorf("resolving input schema: %w", err)
}

if tt.OutputSchema == nil && reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
tt.OutputSchema, err = jsonschema.For[Out](nil)
}
if err != nil {
return nil, nil, fmt.Errorf("output schema: %w", err)
}
var outputResolved *jsonschema.Resolved
if tt.OutputSchema != nil {
outputResolved, err = tt.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
if err != nil {
return nil, nil, fmt.Errorf("resolving output schema: %w", err)
return nil, nil, fmt.Errorf("output schema: %v", err)
}
}

Expand Down Expand Up @@ -255,12 +250,54 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
res = &CallToolResult{}
}
res.StructuredContent = out
if zero != nil {
// Avoid typed nil, which will serialize as JSON null.
// Instead, use the zero value of the non-zero
var z Out
if any(out) == any(z) { // bypass comparable check: pointers are comparable
res.StructuredContent = zero
}
}
if tt.OutputSchema != nil && zero != nil {
res.StructuredContent = zero
}
return res, nil
}

return &tt, th, nil
}

// setSchema sets the schema and resolved schema corresponding to the type T.
//
// If sfield is nil, the schema is derived from T.
//
// Pointers are treated equivalently to non-pointers when deriving the schema.
// If an indirection occurred to derive the schema, a non-nil zero value is
// returned to be used in place of the typed nil zero value.
//
// Note that if sfield already holds a schema, zero will be nil even if T is a
// pointer: if the user provided the schema, they may have intentionally
// derived it from the pointer type, and handling of zero values is up to them.
//
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
// should have a jsonschema.Zero(schema) helper?
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) {
rt := reflect.TypeFor[T]()
if *sfield == nil {
if rt.Kind() == reflect.Pointer {
rt = rt.Elem()
zero = reflect.Zero(rt).Interface()
}
// TODO: we should be able to pass nil opts here.
*sfield, err = jsonschema.ForType(rt, &jsonschema.ForOptions{})
}
if err != nil {
return zero, err
}
*rfield, err = (*sfield).Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
return zero, err
}

// AddTool adds a tool and handler to the server.
//
// A shallow copy of the tool is made first.
Expand Down