Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 60 additions & 30 deletions jsonschema/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,33 @@ package jsonschema

import (
"fmt"
"log/slog"
"math/big"
"reflect"
"regexp"
"time"

"github.com/modelcontextprotocol/go-sdk/internal/util"
)

// ForOptions are options for the [For] function.
type ForOptions struct {
// If IgnoreBadTypes is true, fields that can't be represented as a JSON Schema
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any better word than 'Bad'? I'd say 'Unsupported' is better, but it's too long.

Probably Bad is best, I just wanted to note this decision.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replaced with "Invalid"

// are ignored instead of causing an error.
// This allows callers to adjust the resulting schema using custom knowledge.
// For example, an interface type where all the possible implementations are
// known can be described with "oneof".
IgnoreBadTypes bool

// TypeSchemas maps types to their schemas.
// If [For] encounters a type equal to a type of a key in this map, the
// corresponding value is used as the resulting schema (after cloning to
// ensure uniqueness).
// Types in this map override the default translations, as described
// in [For]'s documentation.
TypeSchemas map[any]*Schema
}

// For constructs a JSON schema object for the given type argument.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// If non-nil, the provided options configures certain aspects of this contruction, described below.

(we need to document somewhere that opts may be nil)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rewrote doc

//
// It translates Go types into compatible JSON schema types, as follows:
Expand All @@ -29,48 +50,52 @@ import (
// Their properties are derived from exported struct fields, using the
// struct field JSON name. Fields that are marked "omitempty" are
// considered optional; all other fields become required properties.
// - Some types in the standard library that implement json.Marshaler
// translate to schemas that match the values to which they marshal.
// For example, [time.Time] translates to the schema for strings.
//
// For returns an error if t contains (possibly recursively) any of the following Go
// types, as they are incompatible with the JSON schema spec.
// By default, For returns an error if t contains (possibly recursively) any of the
// following Go types, as they are incompatible with the JSON schema spec.
// - maps with key other than 'string'
// - function types
// - channel types
// - complex numbers
// - unsafe pointers
//
// If [ForOptions.IgnoreBadTypes] is true, then these types are ignored instead.
//
// It will return an error if there is a cycle in the types.
//
// This function recognizes struct field tags named "jsonschema".
// A jsonschema tag on a field is used as the description for the corresponding property.
// For future compatibility, descriptions must not start with "WORD=", where WORD is a
// sequence of non-whitespace characters.
func For[T any]() (*Schema, error) {
// TODO: consider skipping incompatible fields, instead of failing.
seen := make(map[reflect.Type]bool)
s, err := forType(reflect.TypeFor[T](), seen, false)
if err != nil {
var z T
return nil, fmt.Errorf("For[%T](): %w", z, err)
func For[T any](opts *ForOptions) (*Schema, error) {
if opts == nil {
opts = &ForOptions{}
}
return s, nil
}

// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of
// returning an error. That allows callers to adjust the resulting schema using custom knowledge.
// For example, an interface type where all the possible implementations are known
// can be described with "oneof".
func ForLax[T any]() (*Schema, error) {
// TODO: consider skipping incompatible fields, instead of failing.
seen := make(map[reflect.Type]bool)
s, err := forType(reflect.TypeFor[T](), seen, true)
schemas := make(map[reflect.Type]*Schema)
// Add types from the standard library that have MarshalJSON methods.
ss := &Schema{Type: "string"}
schemas[reflect.TypeFor[time.Time]()] = ss
schemas[reflect.TypeFor[slog.Level]()] = ss
schemas[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}}
schemas[reflect.TypeFor[big.Rat]()] = ss
schemas[reflect.TypeFor[big.Float]()] = ss

// Add types from the options. They override the default ones.
for v, s := range opts.TypeSchemas {
schemas[reflect.TypeOf(v)] = s
}
s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreBadTypes, schemas)
if err != nil {
var z T
return nil, fmt.Errorf("ForLax[%T](): %w", z, err)
return nil, fmt.Errorf("For[%T](): %w", z, err)
}
return s, nil
}

func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) {
func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) {
// Follow pointers: the schema for *T is almost the same as for T, except that
// an explicit JSON "null" is allowed for the pointer.
allowNull := false
Expand All @@ -89,6 +114,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
defer delete(seen, t)
}

if s := schemas[t]; s != nil {
return s.CloneSchemas(), nil
}

var (
s = new(Schema)
err error
Expand All @@ -111,30 +140,30 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err

case reflect.Map:
if t.Key().Kind() != reflect.String {
if lax {
if ignore {
return nil, nil // ignore
}
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
}
if t.Key().Kind() != reflect.String {
}
s.Type = "object"
s.AdditionalProperties, err = forType(t.Elem(), seen, lax)
s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas)
if err != nil {
return nil, fmt.Errorf("computing map value schema: %v", err)
}
if lax && s.AdditionalProperties == nil {
if ignore && s.AdditionalProperties == nil {
// Ignore if the element type is invalid.
return nil, nil
}

case reflect.Slice, reflect.Array:
s.Type = "array"
s.Items, err = forType(t.Elem(), seen, lax)
s.Items, err = forType(t.Elem(), seen, ignore, schemas)
if err != nil {
return nil, fmt.Errorf("computing element schema: %v", err)
}
if lax && s.Items == nil {
if ignore && s.Items == nil {
// Ignore if the element type is invalid.
return nil, nil
}
Expand All @@ -160,11 +189,11 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
if s.Properties == nil {
s.Properties = make(map[string]*Schema)
}
fs, err := forType(field.Type, seen, lax)
fs, err := forType(field.Type, seen, ignore, schemas)
if err != nil {
return nil, err
}
if lax && fs == nil {
if ignore && fs == nil {
// Skip fields of invalid type.
continue
}
Expand All @@ -184,7 +213,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
}

default:
if lax {
if ignore {
// Ignore.
return nil, nil
}
Expand All @@ -194,6 +223,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
s.Types = []string{"null", s.Type}
s.Type = ""
}
schemas[t] = s
return s, nil
}

Expand Down
66 changes: 39 additions & 27 deletions jsonschema/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@
package jsonschema_test

import (
"log/slog"
"math/big"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/modelcontextprotocol/go-sdk/jsonschema"
)

func forType[T any](lax bool) *jsonschema.Schema {
type custom int

func forType[T any](ignore bool) *jsonschema.Schema {
var s *jsonschema.Schema
var err error
if lax {
s, err = jsonschema.ForLax[T]()
} else {
s, err = jsonschema.For[T]()

opts := &jsonschema.ForOptions{
IgnoreBadTypes: ignore,
TypeSchemas: map[any]*jsonschema.Schema{
custom(0): {Type: "custom"},
},
}
s, err = jsonschema.For[T](opts)
if err != nil {
panic(err)
}
Expand All @@ -40,19 +48,23 @@ func TestFor(t *testing.T) {
want *jsonschema.Schema
}

tests := func(lax bool) []test {
tests := func(ignore bool) []test {
return []test{
{"string", forType[string](lax), &schema{Type: "string"}},
{"int", forType[int](lax), &schema{Type: "integer"}},
{"int16", forType[int16](lax), &schema{Type: "integer"}},
{"uint32", forType[int16](lax), &schema{Type: "integer"}},
{"float64", forType[float64](lax), &schema{Type: "number"}},
{"bool", forType[bool](lax), &schema{Type: "boolean"}},
{"intmap", forType[map[string]int](lax), &schema{
{"string", forType[string](ignore), &schema{Type: "string"}},
{"int", forType[int](ignore), &schema{Type: "integer"}},
{"int16", forType[int16](ignore), &schema{Type: "integer"}},
{"uint32", forType[int16](ignore), &schema{Type: "integer"}},
{"float64", forType[float64](ignore), &schema{Type: "number"}},
{"bool", forType[bool](ignore), &schema{Type: "boolean"}},
{"time", forType[time.Time](ignore), &schema{Type: "string"}},
{"level", forType[slog.Level](ignore), &schema{Type: "string"}},
{"bigint", forType[big.Int](ignore), &schema{Types: []string{"null", "string"}}},
{"custom", forType[custom](ignore), &schema{Type: "custom"}},
{"intmap", forType[map[string]int](ignore), &schema{
Type: "object",
AdditionalProperties: &schema{Type: "integer"},
}},
{"anymap", forType[map[string]any](lax), &schema{
{"anymap", forType[map[string]any](ignore), &schema{
Type: "object",
AdditionalProperties: &schema{},
}},
Expand All @@ -66,7 +78,7 @@ func TestFor(t *testing.T) {
NoSkip string `json:",omitempty"`
unexported float64
unexported2 int `json:"No"`
}](lax),
}](ignore),
&schema{
Type: "object",
Properties: map[string]*schema{
Expand All @@ -81,7 +93,7 @@ func TestFor(t *testing.T) {
},
{
"no sharing",
forType[struct{ X, Y int }](lax),
forType[struct{ X, Y int }](ignore),
&schema{
Type: "object",
Properties: map[string]*schema{
Expand All @@ -97,7 +109,7 @@ func TestFor(t *testing.T) {
forType[struct {
A S
S
}](lax),
}](ignore),
&schema{
Type: "object",
Properties: map[string]*schema{
Expand Down Expand Up @@ -165,7 +177,7 @@ func TestFor(t *testing.T) {
}

func forErr[T any]() error {
_, err := jsonschema.For[T]()
_, err := jsonschema.For[T](nil)
return err
}

Expand Down Expand Up @@ -209,7 +221,7 @@ func TestForWithMutation(t *testing.T) {
D [3]S
E *bool
}
s, err := jsonschema.For[T]()
s, err := jsonschema.For[T](nil)
if err != nil {
t.Fatalf("For: %v", err)
}
Expand All @@ -220,7 +232,7 @@ func TestForWithMutation(t *testing.T) {
s.Properties["D"].MinItems = jsonschema.Ptr(10)
s.Properties["E"].Types[0] = "mutated"

s2, err := jsonschema.For[T]()
s2, err := jsonschema.For[T](nil)
if err != nil {
t.Fatalf("For: %v", err)
}
Expand Down Expand Up @@ -266,13 +278,13 @@ func TestForWithCycle(t *testing.T) {
shouldErr bool
fn func() error
}{
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }},
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }},
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }},
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }},
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }},
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }},
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }},
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](nil); return err }},
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](nil); return err }},
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](nil); return err }},
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](nil); return err }},
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](nil); return err }},
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](nil); return err }},
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](nil); return err }},
}

for _, test := range tests {
Expand Down
36 changes: 36 additions & 0 deletions jsonschema/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,42 @@ func (s *Schema) String() string {
return "<anonymous schema>"
}

// CloneSchemas returns a copy of s.
// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas.
// This allows both s and s.CloneSchemas() to appear as sub-schemas in the same parent.
func (s *Schema) CloneSchemas() *Schema {
if s == nil {
return nil
}
s2 := *s
v := reflect.ValueOf(&s2)
for _, info := range schemaFieldInfos {
fv := v.Elem().FieldByIndex(info.sf.Index)
switch info.sf.Type {
case schemaType:
sscss := fv.Interface().(*Schema)
fv.Set(reflect.ValueOf(sscss.CloneSchemas()))

case schemaSliceType:
slice := fv.Interface().([]*Schema)
slice = slices.Clone(slice)
for i, ss := range slice {
slice[i] = ss.CloneSchemas()
}
fv.Set(reflect.ValueOf(slice))

case schemaMapType:
m := fv.Interface().(map[string]*Schema)
m = maps.Clone(m)
for k, ss := range m {
m[k] = ss.CloneSchemas()
}
fv.Set(reflect.ValueOf(m))
}
}
return &s2
}

func (s *Schema) basicChecks() error {
if s.Type != "" && s.Types != nil {
return errors.New("both Type and Types are set; at most one should be")
Expand Down
Loading
Loading