Skip to content

Commit b674c5d

Browse files
remove map and keep cycle detection
1 parent de4dd4d commit b674c5d

File tree

2 files changed

+45
-68
lines changed

2 files changed

+45
-68
lines changed

jsonschema/infer.go

Lines changed: 10 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ package jsonschema
99
import (
1010
"fmt"
1111
"reflect"
12-
"sync"
1312

1413
"github.com/modelcontextprotocol/go-sdk/internal/util"
1514
)
@@ -38,6 +37,7 @@ import (
3837
// - unsafe pointers
3938
//
4039
// The types must not have cycles.
40+
// It will return an error if there is a cycle in the types.
4141
func For[T any]() (*Schema, error) {
4242
// TODO: consider skipping incompatible fields, instead of failing.
4343
seen := make(map[reflect.Type]bool)
@@ -49,8 +49,6 @@ func For[T any]() (*Schema, error) {
4949
return s, nil
5050
}
5151

52-
var typeSchema sync.Map // map[reflect.Type]*Schema
53-
5452
func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
5553
// Follow pointers: the schema for *T is almost the same as for T, except that
5654
// an explicit JSON "null" is allowed for the pointer.
@@ -60,23 +58,21 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
6058
t = t.Elem()
6159
}
6260

63-
if cachedS, ok := typeSchema.Load(t); ok {
64-
s := deepCopySchema(cachedS.(*Schema))
65-
adjustTypesForPointer(s, allowNull)
66-
return s, nil
61+
// Check for cycles
62+
// User defined types have a name, so we can skip those that are natively defined
63+
if t.Name() != "" {
64+
if seen[t] {
65+
return nil, fmt.Errorf("cycle detected for type %v", t)
66+
}
67+
seen[t] = true
68+
defer delete(seen, t)
6769
}
6870

6971
var (
7072
s = new(Schema)
7173
err error
7274
)
7375

74-
if seen[t] {
75-
return nil, fmt.Errorf("cycle detected for type %v", t)
76-
}
77-
seen[t] = true
78-
defer delete(seen, t)
79-
8076
switch t.Kind() {
8177
case reflect.Bool:
8278
s.Type = "boolean"
@@ -142,56 +138,9 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
142138
default:
143139
return nil, fmt.Errorf("type %v is unsupported by jsonschema", t)
144140
}
145-
typeSchema.Store(t, deepCopySchema(s))
146-
adjustTypesForPointer(s, allowNull)
147-
return s, nil
148-
}
149-
150-
func adjustTypesForPointer(s *Schema, allowNull bool) {
151141
if allowNull && s.Type != "" {
152142
s.Types = []string{"null", s.Type}
153143
s.Type = ""
154144
}
155-
}
156-
157-
// deepCopySchema makes a deep copy of a Schema.
158-
// Only fields that are modified by forType are cloned.
159-
func deepCopySchema(s *Schema) *Schema {
160-
if s == nil {
161-
return nil
162-
}
163-
164-
clone := new(Schema)
165-
clone.Type = s.Type
166-
167-
if s.Items != nil {
168-
clone.Items = deepCopySchema(s.Items)
169-
}
170-
if s.AdditionalProperties != nil {
171-
clone.AdditionalProperties = deepCopySchema(s.AdditionalProperties)
172-
}
173-
if s.MinItems != nil {
174-
minItems := *s.MinItems
175-
clone.MinItems = &minItems
176-
}
177-
if s.MaxItems != nil {
178-
maxItems := *s.MaxItems
179-
clone.MaxItems = &maxItems
180-
}
181-
if s.Types != nil {
182-
clone.Types = make([]string, len(s.Types))
183-
copy(clone.Types, s.Types)
184-
}
185-
if s.Required != nil {
186-
clone.Required = make([]string, len(s.Required))
187-
copy(clone.Required, s.Required)
188-
}
189-
if s.Properties != nil {
190-
clone.Properties = make(map[string]*Schema)
191-
for k, v := range s.Properties {
192-
clone.Properties[k] = deepCopySchema(v)
193-
}
194-
}
195-
196-
return clone
145+
return s, nil
197146
}

jsonschema/infer_test.go

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -143,16 +143,44 @@ func TestForWithMutation(t *testing.T) {
143143
}
144144
}
145145

146-
type s struct {
147-
A t
146+
type x struct {
147+
Y y
148148
}
149-
type t struct {
150-
B []s
149+
type y struct {
150+
X []x
151151
}
152152

153153
func TestForWithCycle(t *testing.T) {
154-
_, err := jsonschema.For[s]()
155-
if err == nil {
156-
t.Fatalf("ForWithCycle: expected error, got nil")
154+
type a []*a
155+
type b1 struct{ b *b1 } // unexported field should be skipped
156+
type b2 struct{ B *b2 }
157+
type c1 struct{ c map[string]*c1 } // unexported field should be skipped
158+
type c2 struct{ C map[string]*c2 }
159+
160+
tests := []struct {
161+
name string
162+
shouldErr bool
163+
fn func() error
164+
}{
165+
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }},
166+
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }},
167+
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }},
168+
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }},
169+
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }},
170+
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }},
171+
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }},
172+
}
173+
174+
for _, test := range tests {
175+
test := test // prevent loop shadowing
176+
t.Run(test.name, func(t *testing.T) {
177+
err := test.fn()
178+
if test.shouldErr && err == nil {
179+
t.Errorf("expected cycle error, got nil")
180+
}
181+
if !test.shouldErr && err != nil {
182+
t.Errorf("unexpected error: %v", err)
183+
}
184+
})
157185
}
158186
}

0 commit comments

Comments
 (0)