Skip to content

Commit 828d4a6

Browse files
jbaBC-ACherednichenko
authored andcommitted
jsonschema: options for inference (modelcontextprotocol#185)
- Add ForOptions to hold options for schema inference. - Replace ForLax with ForOptions.IgnoreBadTypes. - Add an option to provide schemas for arbitrary types. - Provide a default mapping from types to schemas that includes stdlib types with MarshalJSON methods. - Add Schema.CloneSchemas. This is needed to make copies of the schemas in the above map: a schema cannot appear twice in a parent schema, because schema addresses matter when resolving internal references.
1 parent 5527ef2 commit 828d4a6

File tree

5 files changed

+173
-61
lines changed

5 files changed

+173
-61
lines changed

jsonschema/infer.go

Lines changed: 66 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,40 @@ package jsonschema
88

99
import (
1010
"fmt"
11+
"log/slog"
12+
"math/big"
1113
"reflect"
1214
"regexp"
15+
"time"
1316

1417
"github.com/modelcontextprotocol/go-sdk/internal/util"
1518
)
1619

20+
// ForOptions are options for the [For] function.
21+
type ForOptions struct {
22+
// If IgnoreInvalidTypes is true, fields that can't be represented as a JSON Schema
23+
// are ignored instead of causing an error.
24+
// This allows callers to adjust the resulting schema using custom knowledge.
25+
// For example, an interface type where all the possible implementations are
26+
// known can be described with "oneof".
27+
IgnoreInvalidTypes bool
28+
29+
// TypeSchemas maps types to their schemas.
30+
// If [For] encounters a type equal to a type of a key in this map, the
31+
// corresponding value is used as the resulting schema (after cloning to
32+
// ensure uniqueness).
33+
// Types in this map override the default translations, as described
34+
// in [For]'s documentation.
35+
TypeSchemas map[any]*Schema
36+
}
37+
1738
// For constructs a JSON schema object for the given type argument.
39+
// If non-nil, the provided options configure certain aspects of this contruction,
40+
// described below.
41+
42+
// It translates Go types into compatible JSON schema types, as follows.
43+
// These defaults can be overridden by [ForOptions.TypeSchemas].
1844
//
19-
// It translates Go types into compatible JSON schema types, as follows:
2045
// - Strings have schema type "string".
2146
// - Bools have schema type "boolean".
2247
// - Signed and unsigned integer types have schema type "integer".
@@ -29,48 +54,51 @@ import (
2954
// Their properties are derived from exported struct fields, using the
3055
// struct field JSON name. Fields that are marked "omitempty" are
3156
// considered optional; all other fields become required properties.
57+
// - Some types in the standard library that implement json.Marshaler
58+
// translate to schemas that match the values to which they marshal.
59+
// For example, [time.Time] translates to the schema for strings.
60+
//
61+
// For will return an error if there is a cycle in the types.
3262
//
33-
// For returns an error if t contains (possibly recursively) any of the following Go
34-
// types, as they are incompatible with the JSON schema spec.
63+
// By default, For returns an error if t contains (possibly recursively) any of the
64+
// following Go types, as they are incompatible with the JSON schema spec.
65+
// If [ForOptions.IgnoreInvalidTypes] is true, then these types are ignored instead.
3566
// - maps with key other than 'string'
3667
// - function types
3768
// - channel types
3869
// - complex numbers
3970
// - unsafe pointers
4071
//
41-
// It will return an error if there is a cycle in the types.
42-
//
4372
// This function recognizes struct field tags named "jsonschema".
4473
// A jsonschema tag on a field is used as the description for the corresponding property.
4574
// For future compatibility, descriptions must not start with "WORD=", where WORD is a
4675
// sequence of non-whitespace characters.
47-
func For[T any]() (*Schema, error) {
48-
// TODO: consider skipping incompatible fields, instead of failing.
49-
seen := make(map[reflect.Type]bool)
50-
s, err := forType(reflect.TypeFor[T](), seen, false)
51-
if err != nil {
52-
var z T
53-
return nil, fmt.Errorf("For[%T](): %w", z, err)
76+
func For[T any](opts *ForOptions) (*Schema, error) {
77+
if opts == nil {
78+
opts = &ForOptions{}
5479
}
55-
return s, nil
56-
}
57-
58-
// ForLax behaves like [For], except that it ignores struct fields with invalid types instead of
59-
// returning an error. That allows callers to adjust the resulting schema using custom knowledge.
60-
// For example, an interface type where all the possible implementations are known
61-
// can be described with "oneof".
62-
func ForLax[T any]() (*Schema, error) {
63-
// TODO: consider skipping incompatible fields, instead of failing.
64-
seen := make(map[reflect.Type]bool)
65-
s, err := forType(reflect.TypeFor[T](), seen, true)
80+
schemas := make(map[reflect.Type]*Schema)
81+
// Add types from the standard library that have MarshalJSON methods.
82+
ss := &Schema{Type: "string"}
83+
schemas[reflect.TypeFor[time.Time]()] = ss
84+
schemas[reflect.TypeFor[slog.Level]()] = ss
85+
schemas[reflect.TypeFor[big.Int]()] = &Schema{Types: []string{"null", "string"}}
86+
schemas[reflect.TypeFor[big.Rat]()] = ss
87+
schemas[reflect.TypeFor[big.Float]()] = ss
88+
89+
// Add types from the options. They override the default ones.
90+
for v, s := range opts.TypeSchemas {
91+
schemas[reflect.TypeOf(v)] = s
92+
}
93+
s, err := forType(reflect.TypeFor[T](), map[reflect.Type]bool{}, opts.IgnoreInvalidTypes, schemas)
6694
if err != nil {
6795
var z T
68-
return nil, fmt.Errorf("ForLax[%T](): %w", z, err)
96+
return nil, fmt.Errorf("For[%T](): %w", z, err)
6997
}
7098
return s, nil
7199
}
72100

73-
func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, error) {
101+
func forType(t reflect.Type, seen map[reflect.Type]bool, ignore bool, schemas map[reflect.Type]*Schema) (*Schema, error) {
74102
// Follow pointers: the schema for *T is almost the same as for T, except that
75103
// an explicit JSON "null" is allowed for the pointer.
76104
allowNull := false
@@ -89,6 +117,10 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
89117
defer delete(seen, t)
90118
}
91119

120+
if s := schemas[t]; s != nil {
121+
return s.CloneSchemas(), nil
122+
}
123+
92124
var (
93125
s = new(Schema)
94126
err error
@@ -111,30 +143,30 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
111143

112144
case reflect.Map:
113145
if t.Key().Kind() != reflect.String {
114-
if lax {
146+
if ignore {
115147
return nil, nil // ignore
116148
}
117149
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
118150
}
119151
if t.Key().Kind() != reflect.String {
120152
}
121153
s.Type = "object"
122-
s.AdditionalProperties, err = forType(t.Elem(), seen, lax)
154+
s.AdditionalProperties, err = forType(t.Elem(), seen, ignore, schemas)
123155
if err != nil {
124156
return nil, fmt.Errorf("computing map value schema: %v", err)
125157
}
126-
if lax && s.AdditionalProperties == nil {
158+
if ignore && s.AdditionalProperties == nil {
127159
// Ignore if the element type is invalid.
128160
return nil, nil
129161
}
130162

131163
case reflect.Slice, reflect.Array:
132164
s.Type = "array"
133-
s.Items, err = forType(t.Elem(), seen, lax)
165+
s.Items, err = forType(t.Elem(), seen, ignore, schemas)
134166
if err != nil {
135167
return nil, fmt.Errorf("computing element schema: %v", err)
136168
}
137-
if lax && s.Items == nil {
169+
if ignore && s.Items == nil {
138170
// Ignore if the element type is invalid.
139171
return nil, nil
140172
}
@@ -160,11 +192,11 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
160192
if s.Properties == nil {
161193
s.Properties = make(map[string]*Schema)
162194
}
163-
fs, err := forType(field.Type, seen, lax)
195+
fs, err := forType(field.Type, seen, ignore, schemas)
164196
if err != nil {
165197
return nil, err
166198
}
167-
if lax && fs == nil {
199+
if ignore && fs == nil {
168200
// Skip fields of invalid type.
169201
continue
170202
}
@@ -184,7 +216,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
184216
}
185217

186218
default:
187-
if lax {
219+
if ignore {
188220
// Ignore.
189221
return nil, nil
190222
}
@@ -194,6 +226,7 @@ func forType(t reflect.Type, seen map[reflect.Type]bool, lax bool) (*Schema, err
194226
s.Types = []string{"null", s.Type}
195227
s.Type = ""
196228
}
229+
schemas[t] = s
197230
return s, nil
198231
}
199232

jsonschema/infer_test.go

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,30 @@
55
package jsonschema_test
66

77
import (
8+
"log/slog"
9+
"math/big"
810
"strings"
911
"testing"
12+
"time"
1013

1114
"github.com/google/go-cmp/cmp"
1215
"github.com/google/go-cmp/cmp/cmpopts"
1316
"github.com/modelcontextprotocol/go-sdk/jsonschema"
1417
)
1518

16-
func forType[T any](lax bool) *jsonschema.Schema {
19+
type custom int
20+
21+
func forType[T any](ignore bool) *jsonschema.Schema {
1722
var s *jsonschema.Schema
1823
var err error
19-
if lax {
20-
s, err = jsonschema.ForLax[T]()
21-
} else {
22-
s, err = jsonschema.For[T]()
24+
25+
opts := &jsonschema.ForOptions{
26+
IgnoreInvalidTypes: ignore,
27+
TypeSchemas: map[any]*jsonschema.Schema{
28+
custom(0): {Type: "custom"},
29+
},
2330
}
31+
s, err = jsonschema.For[T](opts)
2432
if err != nil {
2533
panic(err)
2634
}
@@ -40,19 +48,23 @@ func TestFor(t *testing.T) {
4048
want *jsonschema.Schema
4149
}
4250

43-
tests := func(lax bool) []test {
51+
tests := func(ignore bool) []test {
4452
return []test{
45-
{"string", forType[string](lax), &schema{Type: "string"}},
46-
{"int", forType[int](lax), &schema{Type: "integer"}},
47-
{"int16", forType[int16](lax), &schema{Type: "integer"}},
48-
{"uint32", forType[int16](lax), &schema{Type: "integer"}},
49-
{"float64", forType[float64](lax), &schema{Type: "number"}},
50-
{"bool", forType[bool](lax), &schema{Type: "boolean"}},
51-
{"intmap", forType[map[string]int](lax), &schema{
53+
{"string", forType[string](ignore), &schema{Type: "string"}},
54+
{"int", forType[int](ignore), &schema{Type: "integer"}},
55+
{"int16", forType[int16](ignore), &schema{Type: "integer"}},
56+
{"uint32", forType[int16](ignore), &schema{Type: "integer"}},
57+
{"float64", forType[float64](ignore), &schema{Type: "number"}},
58+
{"bool", forType[bool](ignore), &schema{Type: "boolean"}},
59+
{"time", forType[time.Time](ignore), &schema{Type: "string"}},
60+
{"level", forType[slog.Level](ignore), &schema{Type: "string"}},
61+
{"bigint", forType[big.Int](ignore), &schema{Types: []string{"null", "string"}}},
62+
{"custom", forType[custom](ignore), &schema{Type: "custom"}},
63+
{"intmap", forType[map[string]int](ignore), &schema{
5264
Type: "object",
5365
AdditionalProperties: &schema{Type: "integer"},
5466
}},
55-
{"anymap", forType[map[string]any](lax), &schema{
67+
{"anymap", forType[map[string]any](ignore), &schema{
5668
Type: "object",
5769
AdditionalProperties: &schema{},
5870
}},
@@ -66,7 +78,7 @@ func TestFor(t *testing.T) {
6678
NoSkip string `json:",omitempty"`
6779
unexported float64
6880
unexported2 int `json:"No"`
69-
}](lax),
81+
}](ignore),
7082
&schema{
7183
Type: "object",
7284
Properties: map[string]*schema{
@@ -81,7 +93,7 @@ func TestFor(t *testing.T) {
8193
},
8294
{
8395
"no sharing",
84-
forType[struct{ X, Y int }](lax),
96+
forType[struct{ X, Y int }](ignore),
8597
&schema{
8698
Type: "object",
8799
Properties: map[string]*schema{
@@ -97,7 +109,7 @@ func TestFor(t *testing.T) {
97109
forType[struct {
98110
A S
99111
S
100-
}](lax),
112+
}](ignore),
101113
&schema{
102114
Type: "object",
103115
Properties: map[string]*schema{
@@ -165,7 +177,7 @@ func TestFor(t *testing.T) {
165177
}
166178

167179
func forErr[T any]() error {
168-
_, err := jsonschema.For[T]()
180+
_, err := jsonschema.For[T](nil)
169181
return err
170182
}
171183

@@ -209,7 +221,7 @@ func TestForWithMutation(t *testing.T) {
209221
D [3]S
210222
E *bool
211223
}
212-
s, err := jsonschema.For[T]()
224+
s, err := jsonschema.For[T](nil)
213225
if err != nil {
214226
t.Fatalf("For: %v", err)
215227
}
@@ -220,7 +232,7 @@ func TestForWithMutation(t *testing.T) {
220232
s.Properties["D"].MinItems = jsonschema.Ptr(10)
221233
s.Properties["E"].Types[0] = "mutated"
222234

223-
s2, err := jsonschema.For[T]()
235+
s2, err := jsonschema.For[T](nil)
224236
if err != nil {
225237
t.Fatalf("For: %v", err)
226238
}
@@ -266,13 +278,13 @@ func TestForWithCycle(t *testing.T) {
266278
shouldErr bool
267279
fn func() error
268280
}{
269-
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }},
270-
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }},
271-
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }},
272-
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }},
273-
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }},
274-
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }},
275-
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }},
281+
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](nil); return err }},
282+
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](nil); return err }},
283+
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](nil); return err }},
284+
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](nil); return err }},
285+
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](nil); return err }},
286+
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](nil); return err }},
287+
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](nil); return err }},
276288
}
277289

278290
for _, test := range tests {

jsonschema/schema.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,42 @@ func (s *Schema) String() string {
152152
return "<anonymous schema>"
153153
}
154154

155+
// CloneSchemas returns a copy of s.
156+
// The copy is shallow except for sub-schemas, which are themelves copied with CloneSchemas.
157+
// This allows both s and s.CloneSchemas() to appear as sub-schemas in the same parent.
158+
func (s *Schema) CloneSchemas() *Schema {
159+
if s == nil {
160+
return nil
161+
}
162+
s2 := *s
163+
v := reflect.ValueOf(&s2)
164+
for _, info := range schemaFieldInfos {
165+
fv := v.Elem().FieldByIndex(info.sf.Index)
166+
switch info.sf.Type {
167+
case schemaType:
168+
sscss := fv.Interface().(*Schema)
169+
fv.Set(reflect.ValueOf(sscss.CloneSchemas()))
170+
171+
case schemaSliceType:
172+
slice := fv.Interface().([]*Schema)
173+
slice = slices.Clone(slice)
174+
for i, ss := range slice {
175+
slice[i] = ss.CloneSchemas()
176+
}
177+
fv.Set(reflect.ValueOf(slice))
178+
179+
case schemaMapType:
180+
m := fv.Interface().(map[string]*Schema)
181+
m = maps.Clone(m)
182+
for k, ss := range m {
183+
m[k] = ss.CloneSchemas()
184+
}
185+
fv.Set(reflect.ValueOf(m))
186+
}
187+
}
188+
return &s2
189+
}
190+
155191
func (s *Schema) basicChecks() error {
156192
if s.Type != "" && s.Types != nil {
157193
return errors.New("both Type and Types are set; at most one should be")

0 commit comments

Comments
 (0)