diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 9ff0ddd5..7b6b7e2b 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -8,15 +8,40 @@ package jsonschema import ( "fmt" + "log/slog" + "math/big" "reflect" "regexp" + "time" "github.com/modelcontextprotocol/go-sdk/internal/util" ) +// ForOptions are options for the [For] function. +type ForOptions struct { + // If IgnoreInvalidTypes is true, fields that can't be represented as a JSON Schema + // are ignored instead of causing an error. + // This allows callers to adjust the resulting schema using custom knowledge. + // For example, an interface type where all the possible implementations are + // known can be described with "oneof". + IgnoreInvalidTypes bool + + // TypeSchemas maps types to their schemas. + // If [For] encounters a type equal to a type of a key in this map, the + // corresponding value is used as the resulting schema (after cloning to + // ensure uniqueness). + // Types in this map override the default translations, as described + // in [For]'s documentation. + TypeSchemas map[any]*Schema +} + // For constructs a JSON schema object for the given type argument. +// If non-nil, the provided options configure certain aspects of this contruction, +// described below. + +// It translates Go types into compatible JSON schema types, as follows. +// These defaults can be overridden by [ForOptions.TypeSchemas]. // -// It translates Go types into compatible JSON schema types, as follows: // - Strings have schema type "string". // - Bools have schema type "boolean". // - Signed and unsigned integer types have schema type "integer". @@ -29,48 +54,51 @@ import ( // Their properties are derived from exported struct fields, using the // struct field JSON name. Fields that are marked "omitempty" are // considered optional; all other fields become required properties. +// - Some types in the standard library that implement json.Marshaler +// translate to schemas that match the values to which they marshal. +// For example, [time.Time] translates to the schema for strings. +// +// For will return an error if there is a cycle in the types. // -// For returns an error if t contains (possibly recursively) any of the following Go -// types, as they are incompatible with the JSON schema spec. +// By default, For returns an error if t contains (possibly recursively) any of the +// following Go types, as they are incompatible with the JSON schema spec. +// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead. // - maps with key other than 'string' // - function types // - channel types // - complex numbers // - unsafe pointers // -// It will return an error if there is a cycle in the types. -// // This function recognizes struct field tags named "jsonschema". // A jsonschema tag on a field is used as the description for the corresponding property. // For future compatibility, descriptions must not start with "WORD=", where WORD is a // sequence of non-whitespace characters. -func For[T any]() (*Schema, error) { - // TODO: consider skipping incompatible fields, instead of failing. - seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen, false) - if err != nil { - var z T - return nil, fmt.Errorf("For[%T](): %w", z, err) +func For[T any](opts *ForOptions) (*Schema, error) { + if opts == nil { + opts = &ForOptions{} } - return s, nil -} - -// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of -// returning an error. That allows callers to adjust the resulting schema using custom knowledge. -// For example, an interface type where all the possible implementations are known -// can be described with "oneof". -func ForLax[T any]() (*Schema, error) { - // TODO: consider skipping incompatible fields, instead of failing. - seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen, true) + schemas := make(map[reflect.Type]*Schema) + // Add types from the standard library that have MarshalJSON methods. + ss := &Schema{Type: "string"} + schemas[reflect.TypeFor[time.Time]()] = ss + schemas[reflect.TypeFor[slog.Level]()] = ss + schemas[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}} + schemas[reflect.TypeFor[big.Rat]()] = ss + schemas[reflect.TypeFor[big.Float]()] = ss + + // Add types from the options. They override the default ones. + for v, s := range opts.TypeSchemas { + schemas[reflect.TypeOf(v)] = s + } + s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas) if err != nil { var z T - return nil, fmt.Errorf("ForLax[%T](): %w", z, err) + return nil, fmt.Errorf("For[%T](): %w", z, err) } return s, nil } -func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) { +func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -89,6 +117,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err defer delete(seen, t) } + if s := schemas[t]; s != nil { + return s.CloneSchemas(), nil + } + var ( s = new(Schema) err error @@ -111,7 +143,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err case reflect.Map: if t.Key().Kind() != reflect.String { - if lax { + if ignore { return nil, nil // ignore } return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) @@ -119,22 +151,22 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err if t.Key().Kind() != reflect.String { } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen, lax) + s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } - if lax && s.AdditionalProperties == nil { + if ignore && s.AdditionalProperties == nil { // Ignore if the element type is invalid. return nil, nil } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem(), seen, lax) + s.Items, err = forType(t.Elem(), seen, ignore, schemas) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } - if lax && s.Items == nil { + if ignore && s.Items == nil { // Ignore if the element type is invalid. return nil, nil } @@ -160,11 +192,11 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err if s.Properties == nil { s.Properties = make(map[string]*Schema) } - fs, err := forType(field.Type, seen, lax) + fs, err := forType(field.Type, seen, ignore, schemas) if err != nil { return nil, err } - if lax && fs == nil { + if ignore && fs == nil { // Skip fields of invalid type. continue } @@ -184,7 +216,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err } default: - if lax { + if ignore { // Ignore. return nil, nil } @@ -194,6 +226,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err s.Types = []string{"null", s.Type} s.Type = "" } + schemas[t] = s return s, nil } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 8c8feec0..1a0895b4 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -5,22 +5,30 @@ package jsonschema_test import ( + "log/slog" + "math/big" "strings" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/modelcontextprotocol/go-sdk/jsonschema" ) -func forType[T any](lax bool) *jsonschema.Schema { +type custom int + +func forType[T any](ignore bool) *jsonschema.Schema { var s *jsonschema.Schema var err error - if lax { - s, err = jsonschema.ForLax[T]() - } else { - s, err = jsonschema.For[T]() + + opts := &jsonschema.ForOptions{ + IgnoreInvalidTypes: ignore, + TypeSchemas: map[any]*jsonschema.Schema{ + custom(0): {Type: "custom"}, + }, } + s, err = jsonschema.For[T](opts) if err != nil { panic(err) } @@ -40,19 +48,23 @@ func TestFor(t *testing.T) { want *jsonschema.Schema } - tests := func(lax bool) []test { + tests := func(ignore bool) []test { return []test{ - {"string", forType[string](lax), &schema{Type: "string"}}, - {"int", forType[int](lax), &schema{Type: "integer"}}, - {"int16", forType[int16](lax), &schema{Type: "integer"}}, - {"uint32", forType[int16](lax), &schema{Type: "integer"}}, - {"float64", forType[float64](lax), &schema{Type: "number"}}, - {"bool", forType[bool](lax), &schema{Type: "boolean"}}, - {"intmap", forType[map[string]int](lax), &schema{ + {"string", forType[string](ignore), &schema{Type: "string"}}, + {"int", forType[int](ignore), &schema{Type: "integer"}}, + {"int16", forType[int16](ignore), &schema{Type: "integer"}}, + {"uint32", forType[int16](ignore), &schema{Type: "integer"}}, + {"float64", forType[float64](ignore), &schema{Type: "number"}}, + {"bool", forType[bool](ignore), &schema{Type: "boolean"}}, + {"time", forType[time.Time](ignore), &schema{Type: "string"}}, + {"level", forType[slog.Level](ignore), &schema{Type: "string"}}, + {"bigint", forType[big.Int](ignore), &schema{Types: []string{"null", "string"}}}, + {"custom", forType[custom](ignore), &schema{Type: "custom"}}, + {"intmap", forType[map[string]int](ignore), &schema{ Type: "object", AdditionalProperties: &schema{Type: "integer"}, }}, - {"anymap", forType[map[string]any](lax), &schema{ + {"anymap", forType[map[string]any](ignore), &schema{ Type: "object", AdditionalProperties: &schema{}, }}, @@ -66,7 +78,7 @@ func TestFor(t *testing.T) { NoSkip string `json:",omitempty"` unexported float64 unexported2 int `json:"No"` - }](lax), + }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -81,7 +93,7 @@ func TestFor(t *testing.T) { }, { "no sharing", - forType[struct{ X, Y int }](lax), + forType[struct{ X, Y int }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -97,7 +109,7 @@ func TestFor(t *testing.T) { forType[struct { A S S - }](lax), + }](ignore), &schema{ Type: "object", Properties: map[string]*schema{ @@ -165,7 +177,7 @@ func TestFor(t *testing.T) { } func forErr[T any]() error { - _, err := jsonschema.For[T]() + _, err := jsonschema.For[T](nil) return err } @@ -209,7 +221,7 @@ func TestForWithMutation(t *testing.T) { D [3]S E *bool } - s, err := jsonschema.For[T]() + s, err := jsonschema.For[T](nil) if err != nil { t.Fatalf("For: %v", err) } @@ -220,7 +232,7 @@ func TestForWithMutation(t *testing.T) { s.Properties["D"].MinItems = jsonschema.Ptr(10) s.Properties["E"].Types[0] = "mutated" - s2, err := jsonschema.For[T]() + s2, err := jsonschema.For[T](nil) if err != nil { t.Fatalf("For: %v", err) } @@ -266,13 +278,13 @@ func TestForWithCycle(t *testing.T) { shouldErr bool fn func() error }{ - {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }}, - {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }}, - {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }}, - {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }}, - {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }}, - {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }}, - {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }}, + {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](nil); return err }}, + {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](nil); return err }}, + {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](nil); return err }}, + {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](nil); return err }}, + {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](nil); return err }}, + {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](nil); return err }}, + {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](nil); return err }}, } for _, test := range tests { diff --git a/jsonschema/schema.go b/jsonschema/schema.go index 4b1d6eed..1d60de12 100644 --- a/jsonschema/schema.go +++ b/jsonschema/schema.go @@ -152,6 +152,42 @@ func (s *Schema) String() string { return "" } +// CloneSchemas returns a copy of s. +// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas. +// This allows both s and s.CloneSchemas() to appear as sub-schemas in the same parent. +func (s *Schema) CloneSchemas() *Schema { + if s == nil { + return nil + } + s2 := *s + v := reflect.ValueOf(&s2) + for _, info := range schemaFieldInfos { + fv := v.Elem().FieldByIndex(info.sf.Index) + switch info.sf.Type { + case schemaType: + sscss := fv.Interface().(*Schema) + fv.Set(reflect.ValueOf(sscss.CloneSchemas())) + + case schemaSliceType: + slice := fv.Interface().([]*Schema) + slice = slices.Clone(slice) + for i, ss := range slice { + slice[i] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(slice)) + + case schemaMapType: + m := fv.Interface().(map[string]*Schema) + m = maps.Clone(m) + for k, ss := range m { + m[k] = ss.CloneSchemas() + } + fv.Set(reflect.ValueOf(m)) + } + } + return &s2 +} + func (s *Schema) basicChecks() error { if s.Type != "" && s.Types != nil { return errors.New("both Type and Types are set; at most one should be") diff --git a/jsonschema/schema_test.go b/jsonschema/schema_test.go index 4ceb1ee1..4b6df511 100644 --- a/jsonschema/schema_test.go +++ b/jsonschema/schema_test.go @@ -142,3 +142,34 @@ func (s *Schema) jsonIndent() string { } return string(data) } + +func TestCloneSchemas(t *testing.T) { + ss1 := &Schema{Type: "string"} + ss2 := &Schema{Type: "integer"} + ss3 := &Schema{Type: "boolean"} + ss4 := &Schema{Type: "number"} + ss5 := &Schema{Contains: ss4} + + s1 := Schema{ + Contains: ss1, + PrefixItems: []*Schema{ss2, ss3}, + Properties: map[string]*Schema{"a": ss5}, + } + s2 := s1.CloneSchemas() + + // The clones should appear identical. + if g, w := s1.json(), s2.json(); g != w { + t.Errorf("\ngot %s\nwant %s", g, w) + } + // None of the schemas should overlap. + schemas1 := map[*Schema]bool{ss1: true, ss2: true, ss3: true, ss4: true, ss5: true} + for ss := range s2.all() { + if schemas1[ss] { + t.Errorf("uncloned schema %s", ss.json()) + } + } + // s1's original schemas should be intact. + if s1.Contains != ss1 || s1.PrefixItems[0] != ss2 || s1.PrefixItems[1] != ss3 || ss5.Contains != ss4 || s1.Properties["a"] != ss5 { + t.Errorf("s1 modified") + } +} diff --git a/mcp/tool.go b/mcp/tool.go index fc154991..ed80b660 100644 --- a/mcp/tool.go +++ b/mcp/tool.go @@ -89,7 +89,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error { var err error if *sfield == nil { - *sfield, err = jsonschema.For[T]() + *sfield, err = jsonschema.For[T](nil) } if err != nil { return err