diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 654e6197..9ff0ddd5 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -47,7 +47,7 @@ import ( func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. seen := make(map[reflect.Type]bool) - s, err := forType(reflect.TypeFor[T](), seen) + s, err := forType(reflect.TypeFor[T](), seen, false) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -55,7 +55,22 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { +// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of +// returning an error. That allows callers to adjust the resulting schema using custom knowledge. +// For example, an interface type where all the possible implementations are known +// can be described with "oneof". +func ForLax[T any]() (*Schema, error) { + // TODO: consider skipping incompatible fields, instead of failing. + seen := make(map[reflect.Type]bool) + s, err := forType(reflect.TypeFor[T](), seen, true) + if err != nil { + var z T + return nil, fmt.Errorf("ForLax[%T](): %w", z, err) + } + return s, nil +} + +func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -96,20 +111,33 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { case reflect.Map: if t.Key().Kind() != reflect.String { + if lax { + return nil, nil // ignore + } return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } + if t.Key().Kind() != reflect.String { + } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem(), seen) + s.AdditionalProperties, err = forType(t.Elem(), seen, lax) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } + if lax && s.AdditionalProperties == nil { + // Ignore if the element type is invalid. + return nil, nil + } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem(), seen) + s.Items, err = forType(t.Elem(), seen, lax) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } + if lax && s.Items == nil { + // Ignore if the element type is invalid. + return nil, nil + } if t.Kind() == reflect.Array { s.MinItems = Ptr(t.Len()) s.MaxItems = Ptr(t.Len()) @@ -132,10 +160,14 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - fs, err := forType(field.Type, seen) + fs, err := forType(field.Type, seen, lax) if err != nil { return nil, err } + if lax && fs == nil { + // Skip fields of invalid type. + continue + } if tag, ok := field.Tag.Lookup("jsonschema"); ok { if tag == "" { return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) @@ -152,6 +184,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { } default: + if lax { + // Ignore. + return nil, nil + } return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) } if allowNull && s.Type != "" { diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 106e5375..8c8feec0 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -13,8 +13,14 @@ import ( "github.com/modelcontextprotocol/go-sdk/jsonschema" ) -func forType[T any]() *jsonschema.Schema { - s, err := jsonschema.For[T]() +func forType[T any](lax bool) *jsonschema.Schema { + var s *jsonschema.Schema + var err error + if lax { + s, err = jsonschema.ForLax[T]() + } else { + s, err = jsonschema.For[T]() + } if err != nil { panic(err) } @@ -28,104 +34,134 @@ func TestFor(t *testing.T) { B int `jsonschema:"bdesc"` } - tests := []struct { + type test struct { name string got *jsonschema.Schema want *jsonschema.Schema - }{ - {"string", forType[string](), &schema{Type: "string"}}, - {"int", forType[int](), &schema{Type: "integer"}}, - {"int16", forType[int16](), &schema{Type: "integer"}}, - {"uint32", forType[int16](), &schema{Type: "integer"}}, - {"float64", forType[float64](), &schema{Type: "number"}}, - {"bool", forType[bool](), &schema{Type: "boolean"}}, - {"intmap", forType[map[string]int](), &schema{ - Type: "object", - AdditionalProperties: &schema{Type: "integer"}, - }}, - {"anymap", forType[map[string]any](), &schema{ - Type: "object", - AdditionalProperties: &schema{}, - }}, - { - "struct", - forType[struct { - F int `json:"f" jsonschema:"fdesc"` - G []float64 - P *bool `jsonschema:"pdesc"` - Skip string `json:"-"` - NoSkip string `json:",omitempty"` - unexported float64 - unexported2 int `json:"No"` - }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "f": {Type: "integer", Description: "fdesc"}, - "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, - "NoSkip": {Type: "string"}, + } + + tests := func(lax bool) []test { + return []test{ + {"string", forType[string](lax), &schema{Type: "string"}}, + {"int", forType[int](lax), &schema{Type: "integer"}}, + {"int16", forType[int16](lax), &schema{Type: "integer"}}, + {"uint32", forType[int16](lax), &schema{Type: "integer"}}, + {"float64", forType[float64](lax), &schema{Type: "number"}}, + {"bool", forType[bool](lax), &schema{Type: "boolean"}}, + {"intmap", forType[map[string]int](lax), &schema{ + Type: "object", + AdditionalProperties: &schema{Type: "integer"}, + }}, + {"anymap", forType[map[string]any](lax), &schema{ + Type: "object", + AdditionalProperties: &schema{}, + }}, + { + "struct", + forType[struct { + F int `json:"f" jsonschema:"fdesc"` + G []float64 + P *bool `jsonschema:"pdesc"` + Skip string `json:"-"` + NoSkip string `json:",omitempty"` + unexported float64 + unexported2 int `json:"No"` + }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "f": {Type: "integer", Description: "fdesc"}, + "G": {Type: "array", Items: &schema{Type: "number"}}, + "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, + "NoSkip": {Type: "string"}, + }, + Required: []string{"f", "G", "P"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"f", "G", "P"}, - AdditionalProperties: falseSchema(), }, - }, - { - "no sharing", - forType[struct{ X, Y int }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "X": {Type: "integer"}, - "Y": {Type: "integer"}, + { + "no sharing", + forType[struct{ X, Y int }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "X": {Type: "integer"}, + "Y": {Type: "integer"}, + }, + Required: []string{"X", "Y"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"X", "Y"}, - AdditionalProperties: falseSchema(), }, - }, - { - "nested and embedded", - forType[struct { - A S - S - }](), - &schema{ - Type: "object", - Properties: map[string]*schema{ - "A": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, + { + "nested and embedded", + forType[struct { + A S + S + }](lax), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), - }, - "S": { - Type: "object", - Properties: map[string]*schema{ - "B": {Type: "integer", Description: "bdesc"}, + "S": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"B"}, - AdditionalProperties: falseSchema(), }, + Required: []string{"A", "S"}, + AdditionalProperties: falseSchema(), }, - Required: []string{"A", "S"}, - AdditionalProperties: falseSchema(), }, - }, + } } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - if diff := cmp.Diff(test.want, test.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { - t.Fatalf("ForType mismatch (-want +got):\n%s", diff) - } - // These schemas should all resolve. - if _, err := test.got.Resolve(nil); err != nil { - t.Fatalf("Resolving: %v", err) - } - }) + run := func(t *testing.T, tt test) { + if diff := cmp.Diff(tt.want, tt.got, cmpopts.IgnoreUnexported(jsonschema.Schema{})); diff != "" { + t.Fatalf("ForType mismatch (-want +got):\n%s", diff) + } + // These schemas should all resolve. + if _, err := tt.got.Resolve(nil); err != nil { + t.Fatalf("Resolving: %v", err) + } } + + t.Run("strict", func(t *testing.T) { + for _, test := range tests(false) { + t.Run(test.name, func(t *testing.T) { run(t, test) }) + } + }) + + laxTests := append(tests(true), test{ + "ignore", + forType[struct { + A int + B map[int]int + C func() + }](true), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": {Type: "integer"}, + }, + Required: []string{"A"}, + AdditionalProperties: falseSchema(), + }, + }) + t.Run("lax", func(t *testing.T) { + for _, test := range laxTests { + t.Run(test.name, func(t *testing.T) { run(t, test) }) + } + }) } func forErr[T any]() error {