Skip to content

Commit e27a36e

Browse files
jsonchema: add cycle detection (#97)
This PR add cycle detection where it return error instead of waiting for stack overflow. This is a fix for #77. The root cause of the issue was the type being processed is deeply nested and has cycle. Hence, the heap memory growth is faster than the stack growth, it went OOM before stack overflow panic can kick in.
1 parent 2b07560 commit e27a36e

File tree

2 files changed

+110
-5
lines changed

2 files changed

+110
-5
lines changed

jsonschema/infer.go

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,19 @@ import (
3737
// - unsafe pointers
3838
//
3939
// The types must not have cycles.
40+
// It will return an error if there is a cycle in the types.
4041
func For[T any]() (*Schema, error) {
4142
// TODO: consider skipping incompatible fields, instead of failing.
42-
s, err := forType(reflect.TypeFor[T]())
43+
seen := make(map[reflect.Type]bool)
44+
s, err := forType(reflect.TypeFor[T](), seen)
4345
if err != nil {
4446
var z T
4547
return nil, fmt.Errorf("For[%T](): %w", z, err)
4648
}
4749
return s, nil
4850
}
4951

50-
func forType(t reflect.Type) (*Schema, error) {
52+
func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
5153
// Follow pointers: the schema for *T is almost the same as for T, except that
5254
// an explicit JSON "null" is allowed for the pointer.
5355
allowNull := false
@@ -56,6 +58,16 @@ func forType(t reflect.Type) (*Schema, error) {
5658
t = t.Elem()
5759
}
5860

61+
// Check for cycles
62+
// User defined types have a name, so we can skip those that are natively defined
63+
if t.Name() != "" {
64+
if seen[t] {
65+
return nil, fmt.Errorf("cycle detected for type %v", t)
66+
}
67+
seen[t] = true
68+
defer delete(seen, t)
69+
}
70+
5971
var (
6072
s = new(Schema)
6173
err error
@@ -81,14 +93,14 @@ func forType(t reflect.Type) (*Schema, error) {
8193
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
8294
}
8395
s.Type = "object"
84-
s.AdditionalProperties, err = forType(t.Elem())
96+
s.AdditionalProperties, err = forType(t.Elem(), seen)
8597
if err != nil {
8698
return nil, fmt.Errorf("computing map value schema: %v", err)
8799
}
88100

89101
case reflect.Slice, reflect.Array:
90102
s.Type = "array"
91-
s.Items, err = forType(t.Elem())
103+
s.Items, err = forType(t.Elem(), seen)
92104
if err != nil {
93105
return nil, fmt.Errorf("computing element schema: %v", err)
94106
}
@@ -114,7 +126,7 @@ func forType(t reflect.Type) (*Schema, error) {
114126
if s.Properties == nil {
115127
s.Properties = make(map[string]*Schema)
116128
}
117-
s.Properties[info.Name], err = forType(field.Type)
129+
s.Properties[info.Name], err = forType(field.Type, seen)
118130
if err != nil {
119131
return nil, err
120132
}

jsonschema/infer_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,3 +91,96 @@ func TestForType(t *testing.T) {
9191
})
9292
}
9393
}
94+
95+
func TestForWithMutation(t *testing.T) {
96+
// This test ensures that the cached schema is not mutated when the caller
97+
// mutates the returned schema.
98+
type S struct {
99+
A int
100+
}
101+
type T struct {
102+
A int `json:"A"`
103+
B map[string]int
104+
C []S
105+
D [3]S
106+
E *bool
107+
}
108+
s, err := jsonschema.For[T]()
109+
if err != nil {
110+
t.Fatalf("For: %v", err)
111+
}
112+
s.Required[0] = "mutated"
113+
s.Properties["A"].Type = "mutated"
114+
s.Properties["C"].Items.Type = "mutated"
115+
s.Properties["D"].MaxItems = jsonschema.Ptr(10)
116+
s.Properties["D"].MinItems = jsonschema.Ptr(10)
117+
s.Properties["E"].Types[0] = "mutated"
118+
119+
s2, err := jsonschema.For[T]()
120+
if err != nil {
121+
t.Fatalf("For: %v", err)
122+
}
123+
if s2.Properties["A"].Type == "mutated" {
124+
t.Fatalf("ForWithMutation: expected A.Type to not be mutated")
125+
}
126+
if s2.Properties["B"].AdditionalProperties.Type == "mutated" {
127+
t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated")
128+
}
129+
if s2.Properties["C"].Items.Type == "mutated" {
130+
t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated")
131+
}
132+
if *s2.Properties["D"].MaxItems == 10 {
133+
t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated")
134+
}
135+
if *s2.Properties["D"].MinItems == 10 {
136+
t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated")
137+
}
138+
if s2.Properties["E"].Types[0] == "mutated" {
139+
t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated")
140+
}
141+
if s2.Required[0] == "mutated" {
142+
t.Fatalf("ForWithMutation: expected Required[0] to not be mutated")
143+
}
144+
}
145+
146+
type x struct {
147+
Y y
148+
}
149+
type y struct {
150+
X []x
151+
}
152+
153+
func TestForWithCycle(t *testing.T) {
154+
type a []*a
155+
type b1 struct{ b *b1 } // unexported field should be skipped
156+
type b2 struct{ B *b2 }
157+
type c1 struct{ c map[string]*c1 } // unexported field should be skipped
158+
type c2 struct{ C map[string]*c2 }
159+
160+
tests := []struct {
161+
name string
162+
shouldErr bool
163+
fn func() error
164+
}{
165+
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }},
166+
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }},
167+
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }},
168+
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }},
169+
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }},
170+
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }},
171+
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }},
172+
}
173+
174+
for _, test := range tests {
175+
test := test // prevent loop shadowing
176+
t.Run(test.name, func(t *testing.T) {
177+
err := test.fn()
178+
if test.shouldErr && err == nil {
179+
t.Errorf("expected cycle error, got nil")
180+
}
181+
if !test.shouldErr && err != nil {
182+
t.Errorf("unexpected error: %v", err)
183+
}
184+
})
185+
}
186+
}

0 commit comments

Comments
 (0)