Skip to content

Commit 81f975c

Browse files
jsonchema: add memoization and cycle detection (#77)
1 parent 328a25d commit 81f975c

File tree

2 files changed

+120
-6
lines changed

2 files changed

+120
-6
lines changed

jsonschema/infer.go

Lines changed: 69 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ package jsonschema
99
import (
1010
"fmt"
1111
"reflect"
12+
"sync"
1213

1314
"github.com/modelcontextprotocol/go-sdk/internal/util"
1415
)
@@ -39,15 +40,18 @@ import (
3940
// The types must not have cycles.
4041
func For[T any]() (*Schema, error) {
4142
// TODO: consider skipping incompatible fields, instead of failing.
42-
s, err := forType(reflect.TypeFor[T]())
43+
seen := make(map[reflect.Type]bool)
44+
s, err := forType(reflect.TypeFor[T](), seen)
4345
if err != nil {
4446
var z T
4547
return nil, fmt.Errorf("For[%T](): %w", z, err)
4648
}
4749
return s, nil
4850
}
4951

50-
func forType(t reflect.Type) (*Schema, error) {
52+
var typeSchema sync.Map // map[reflect.Type]*Schema
53+
54+
func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
5155
// Follow pointers: the schema for *T is almost the same as for T, except that
5256
// an explicit JSON "null" is allowed for the pointer.
5357
allowNull := false
@@ -56,11 +60,23 @@ func forType(t reflect.Type) (*Schema, error) {
5660
t = t.Elem()
5761
}
5862

63+
if cachedS, ok := typeSchema.Load(t); ok {
64+
s := deepCopySchema(cachedS.(*Schema))
65+
adjustTypesForPointer(s, allowNull)
66+
return s, nil
67+
}
68+
5969
var (
6070
s = new(Schema)
6171
err error
6272
)
6373

74+
if seen[t] {
75+
return nil, fmt.Errorf("cycle detected for type %v", t)
76+
}
77+
seen[t] = true
78+
defer delete(seen, t)
79+
6480
switch t.Kind() {
6581
case reflect.Bool:
6682
s.Type = "boolean"
@@ -81,14 +97,14 @@ func forType(t reflect.Type) (*Schema, error) {
8197
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
8298
}
8399
s.Type = "object"
84-
s.AdditionalProperties, err = forType(t.Elem())
100+
s.AdditionalProperties, err = forType(t.Elem(), seen)
85101
if err != nil {
86102
return nil, fmt.Errorf("computing map value schema: %v", err)
87103
}
88104

89105
case reflect.Slice, reflect.Array:
90106
s.Type = "array"
91-
s.Items, err = forType(t.Elem())
107+
s.Items, err = forType(t.Elem(), seen)
92108
if err != nil {
93109
return nil, fmt.Errorf("computing element schema: %v", err)
94110
}
@@ -114,7 +130,7 @@ func forType(t reflect.Type) (*Schema, error) {
114130
if s.Properties == nil {
115131
s.Properties = make(map[string]*Schema)
116132
}
117-
s.Properties[info.Name], err = forType(field.Type)
133+
s.Properties[info.Name], err = forType(field.Type, seen)
118134
if err != nil {
119135
return nil, err
120136
}
@@ -126,9 +142,56 @@ func forType(t reflect.Type) (*Schema, error) {
126142
default:
127143
return nil, fmt.Errorf("type %v is unsupported by jsonschema", t)
128144
}
145+
typeSchema.Store(t, deepCopySchema(s))
146+
adjustTypesForPointer(s, allowNull)
147+
return s, nil
148+
}
149+
150+
func adjustTypesForPointer(s *Schema, allowNull bool) {
129151
if allowNull && s.Type != "" {
130152
s.Types = []string{"null", s.Type}
131153
s.Type = ""
132154
}
133-
return s, nil
155+
}
156+
157+
// deepCopySchema makes a deep copy of a Schema.
158+
// Only fields that are modified by forType are cloned.
159+
func deepCopySchema(s *Schema) *Schema {
160+
if s == nil {
161+
return nil
162+
}
163+
164+
clone := new(Schema)
165+
clone.Type = s.Type
166+
167+
if s.Items != nil {
168+
clone.Items = deepCopySchema(s.Items)
169+
}
170+
if s.AdditionalProperties != nil {
171+
clone.AdditionalProperties = deepCopySchema(s.AdditionalProperties)
172+
}
173+
if s.MinItems != nil {
174+
minItems := *s.MinItems
175+
clone.MinItems = &minItems
176+
}
177+
if s.MaxItems != nil {
178+
maxItems := *s.MaxItems
179+
clone.MaxItems = &maxItems
180+
}
181+
if s.Types != nil {
182+
clone.Types = make([]string, len(s.Types))
183+
copy(clone.Types, s.Types)
184+
}
185+
if s.Required != nil {
186+
clone.Required = make([]string, len(s.Required))
187+
copy(clone.Required, s.Required)
188+
}
189+
if s.Properties != nil {
190+
clone.Properties = make(map[string]*Schema)
191+
for k, v := range s.Properties {
192+
clone.Properties[k] = deepCopySchema(v)
193+
}
194+
}
195+
196+
return clone
134197
}

jsonschema/infer_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,54 @@ func TestForType(t *testing.T) {
9191
})
9292
}
9393
}
94+
95+
func TestForWithMutation(t *testing.T) {
96+
// This test ensures that the cached schema is not mutated when the caller
97+
// mutates the returned schema.
98+
type S struct {
99+
A int
100+
}
101+
type T struct {
102+
A int `json:"A"`
103+
B map[string]int
104+
C []S
105+
D [3]S
106+
E *bool
107+
}
108+
s, err := jsonschema.For[T]()
109+
if err != nil {
110+
t.Fatalf("For: %v", err)
111+
}
112+
s.Required[0] = "mutated"
113+
s.Properties["A"].Type = "mutated"
114+
s.Properties["C"].Items.Type = "mutated"
115+
s.Properties["D"].MaxItems = jsonschema.Ptr(10)
116+
s.Properties["D"].MinItems = jsonschema.Ptr(10)
117+
s.Properties["E"].Types[0] = "mutated"
118+
119+
s2, err := jsonschema.For[T]()
120+
if err != nil {
121+
t.Fatalf("For: %v", err)
122+
}
123+
if s2.Properties["A"].Type == "mutated" {
124+
t.Fatalf("ForWithMutation: expected A.Type to not be mutated")
125+
}
126+
if s2.Properties["B"].AdditionalProperties.Type == "mutated" {
127+
t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated")
128+
}
129+
if s2.Properties["C"].Items.Type == "mutated" {
130+
t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated")
131+
}
132+
if *s2.Properties["D"].MaxItems == 10 {
133+
t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated")
134+
}
135+
if *s2.Properties["D"].MinItems == 10 {
136+
t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated")
137+
}
138+
if s2.Properties["E"].Types[0] == "mutated" {
139+
t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated")
140+
}
141+
if s2.Required[0] == "mutated" {
142+
t.Fatalf("ForWithMutation: expected Required[0] to not be mutated")
143+
}
144+
}

0 commit comments

Comments
 (0)