Skip to content

Commit 3026172

Browse files
authored
mcp: validate user-provided output schemas (#408)
toolForErr was ignoring the output schema if the output type was any. That neglected the case where the user provided their own output schema. Fixes #371.
1 parent 07b9cee commit 3026172

File tree

3 files changed

+59
-3
lines changed

3 files changed

+59
-3
lines changed

mcp/server.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
212212
elemZero any // only non-nil if Out is a pointer type
213213
outputResolved *jsonschema.Resolved
214214
)
215-
if reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
215+
if t.OutputSchema != nil || reflect.TypeFor[Out]() != reflect.TypeFor[any]() {
216216
var err error
217217
elemZero, err = setSchema[Out](&tt.OutputSchema, &outputResolved)
218218
if err != nil {
@@ -302,8 +302,8 @@ func toolForErr[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*Tool, ToolHan
302302
// TODO(rfindley): we really shouldn't ever return 'null' results. Maybe we
303303
// should have a jsonschema.Zero(schema) helper?
304304
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) (zero any, err error) {
305-
rt := reflect.TypeFor[T]()
306305
if *sfield == nil {
306+
rt := reflect.TypeFor[T]()
307307
if rt.Kind() == reflect.Pointer {
308308
rt = rt.Elem()
309309
zero = reflect.Zero(rt).Interface()

mcp/server_test.go

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package mcp
66

77
import (
88
"context"
9+
"encoding/json"
910
"log"
1011
"slices"
1112
"testing"
@@ -487,3 +488,59 @@ func TestAddTool(t *testing.T) {
487488
t.Error("bad Out: expected panic")
488489
}
489490
}
491+
492+
type schema = jsonschema.Schema
493+
494+
func testToolForSchema[In, Out any](t *testing.T, tool *Tool, in string, out Out, wantIn, wantOut *schema, wantErr bool) {
495+
t.Helper()
496+
th := func(context.Context, *CallToolRequest, In) (*CallToolResult, Out, error) {
497+
return nil, out, nil
498+
}
499+
gott, goth, err := toolForErr(tool, th)
500+
if err != nil {
501+
t.Fatal(err)
502+
}
503+
if diff := cmp.Diff(wantIn, gott.InputSchema); diff != "" {
504+
t.Errorf("input: mismatch (-want, +got):\n%s", diff)
505+
}
506+
if diff := cmp.Diff(wantOut, gott.OutputSchema); diff != "" {
507+
t.Errorf("output: mismatch (-want, +got):\n%s", diff)
508+
}
509+
ctr := &CallToolRequest{
510+
Params: &CallToolParamsRaw{
511+
Arguments: json.RawMessage(in),
512+
},
513+
}
514+
_, err = goth(context.Background(), ctr)
515+
516+
if gotErr := err != nil; gotErr != wantErr {
517+
t.Errorf("got error: %t, want error: %t", gotErr, wantErr)
518+
}
519+
}
520+
521+
func TestToolForSchemas(t *testing.T) {
522+
// Validate that ToolFor handles schemas properly.
523+
524+
// Infer both schemas.
525+
testToolForSchema[int](t, &Tool{}, "3", true,
526+
&schema{Type: "integer"}, &schema{Type: "boolean"}, false)
527+
// Validate the input schema: expect an error if it's wrong.
528+
// We can't test that the output schema is validated, because it's typed.
529+
testToolForSchema[int](t, &Tool{}, `"x"`, true,
530+
&schema{Type: "integer"}, &schema{Type: "boolean"}, true)
531+
532+
// Ignore type any for output.
533+
testToolForSchema[int, any](t, &Tool{}, "3", 0,
534+
&schema{Type: "integer"}, nil, false)
535+
// Input is still validated.
536+
testToolForSchema[int, any](t, &Tool{}, `"x"`, 0,
537+
&schema{Type: "integer"}, nil, true)
538+
539+
// Tool sets input schema: that is what's used.
540+
testToolForSchema[int, any](t, &Tool{InputSchema: &schema{Type: "string"}}, "3", 0,
541+
&schema{Type: "string"}, nil, true) // error: 3 is not a string
542+
543+
// Tool sets output schema: that is what's used, and validation happens.
544+
testToolForSchema[string, any](t, &Tool{OutputSchema: &schema{Type: "integer"}}, "3", "x",
545+
&schema{Type: "string"}, &schema{Type: "integer"}, true) // error: "x" is not an integer
546+
}

mcp/tool.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import (
99
"context"
1010
"encoding/json"
1111
"fmt"
12-
// "log"
1312

1413
"github.com/google/jsonschema-go/jsonschema"
1514
)

0 commit comments

Comments
 (0)