Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions jsonschema/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,19 @@ import (
// - unsafe pointers
//
// The types must not have cycles.
// It will return an error if there is a cycle in the types.
func For[T any]() (*Schema, error) {
// TODO: consider skipping incompatible fields, instead of failing.
s, err := forType(reflect.TypeFor[T]())
seen := make(map[reflect.Type]bool)
s, err := forType(reflect.TypeFor[T](), seen)
if err != nil {
var z T
return nil, fmt.Errorf("For[%T](): %w", z, err)
}
return s, nil
}

func forType(t reflect.Type) (*Schema, error) {
func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
// Follow pointers: the schema for *T is almost the same as for T, except that
// an explicit JSON "null" is allowed for the pointer.
allowNull := false
Expand All @@ -56,6 +58,16 @@ func forType(t reflect.Type) (*Schema, error) {
t = t.Elem()
}

// Check for cycles
// User defined types have a name, so we can skip those that are natively defined
if t.Name() != "" {
if seen[t] {
return nil, fmt.Errorf("cycle detected for type %v", t)
}
seen[t] = true
defer delete(seen, t)
}

var (
s = new(Schema)
err error
Expand All @@ -81,14 +93,14 @@ func forType(t reflect.Type) (*Schema, error) {
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
}
s.Type = "object"
s.AdditionalProperties, err = forType(t.Elem())
s.AdditionalProperties, err = forType(t.Elem(), seen)
if err != nil {
return nil, fmt.Errorf("computing map value schema: %v", err)
}

case reflect.Slice, reflect.Array:
s.Type = "array"
s.Items, err = forType(t.Elem())
s.Items, err = forType(t.Elem(), seen)
if err != nil {
return nil, fmt.Errorf("computing element schema: %v", err)
}
Expand All @@ -114,7 +126,7 @@ func forType(t reflect.Type) (*Schema, error) {
if s.Properties == nil {
s.Properties = make(map[string]*Schema)
}
s.Properties[info.Name], err = forType(field.Type)
s.Properties[info.Name], err = forType(field.Type, seen)
if err != nil {
return nil, err
}
Expand Down
93 changes: 93 additions & 0 deletions jsonschema/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,96 @@ func TestForType(t *testing.T) {
})
}
}

func TestForWithMutation(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

context: I'm keeping this unit test, I think they are still good to have

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

// This test ensures that the cached schema is not mutated when the caller
// mutates the returned schema.
type S struct {
A int
}
type T struct {
A int `json:"A"`
B map[string]int
C []S
D [3]S
E *bool
}
s, err := jsonschema.For[T]()
if err != nil {
t.Fatalf("For: %v", err)
}
s.Required[0] = "mutated"
s.Properties["A"].Type = "mutated"
s.Properties["C"].Items.Type = "mutated"
s.Properties["D"].MaxItems = jsonschema.Ptr(10)
s.Properties["D"].MinItems = jsonschema.Ptr(10)
s.Properties["E"].Types[0] = "mutated"

s2, err := jsonschema.For[T]()
if err != nil {
t.Fatalf("For: %v", err)
}
if s2.Properties["A"].Type == "mutated" {
t.Fatalf("ForWithMutation: expected A.Type to not be mutated")
}
if s2.Properties["B"].AdditionalProperties.Type == "mutated" {
t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated")
}
if s2.Properties["C"].Items.Type == "mutated" {
t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated")
}
if *s2.Properties["D"].MaxItems == 10 {
t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated")
}
if *s2.Properties["D"].MinItems == 10 {
t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated")
}
if s2.Properties["E"].Types[0] == "mutated" {
t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated")
}
if s2.Required[0] == "mutated" {
t.Fatalf("ForWithMutation: expected Required[0] to not be mutated")
}
}

type x struct {
Y y
}
type y struct {
X []x
}

func TestForWithCycle(t *testing.T) {
type a []*a
type b1 struct{ b *b1 } // unexported field should be skipped
type b2 struct{ B *b2 }
type c1 struct{ c map[string]*c1 } // unexported field should be skipped
type c2 struct{ C map[string]*c2 }

tests := []struct {
name string
shouldErr bool
fn func() error
}{
{"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }},
{"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }},
{"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }},
{"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }},
{"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }},
{"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }},
{"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }},
}

for _, test := range tests {
test := test // prevent loop shadowing
t.Run(test.name, func(t *testing.T) {
err := test.fn()
if test.shouldErr && err == nil {
t.Errorf("expected cycle error, got nil")
}
if !test.shouldErr && err != nil {
t.Errorf("unexpected error: %v", err)
}
})
}
}
Loading