diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 1334bdf1..d1c1a5fb 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -37,9 +37,11 @@ import ( // - unsafe pointers // // The types must not have cycles. +// It will return an error if there is a cycle in the types. func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. - s, err := forType(reflect.TypeFor[T]()) + seen := make(map[reflect.Type]bool) + s, err := forType(reflect.TypeFor[T](), seen) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -47,7 +49,7 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type) (*Schema, error) { +func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -56,6 +58,16 @@ func forType(t reflect.Type) (*Schema, error) { t = t.Elem() } + // Check for cycles + // User defined types have a name, so we can skip those that are natively defined + if t.Name() != "" { + if seen[t] { + return nil, fmt.Errorf("cycle detected for type %v", t) + } + seen[t] = true + defer delete(seen, t) + } + var ( s = new(Schema) err error @@ -81,14 +93,14 @@ func forType(t reflect.Type) (*Schema, error) { return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem()) + s.AdditionalProperties, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem()) + s.Items, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } @@ -114,7 +126,7 @@ func forType(t reflect.Type) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[info.Name], err = forType(field.Type) + s.Properties[info.Name], err = forType(field.Type, seen) if err != nil { return nil, err } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 9325b832..0b1b769a 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -91,3 +91,96 @@ func TestForType(t *testing.T) { }) } } + +func TestForWithMutation(t *testing.T) { + // This test ensures that the cached schema is not mutated when the caller + // mutates the returned schema. + type S struct { + A int + } + type T struct { + A int `json:"A"` + B map[string]int + C []S + D [3]S + E *bool + } + s, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + s.Required[0] = "mutated" + s.Properties["A"].Type = "mutated" + s.Properties["C"].Items.Type = "mutated" + s.Properties["D"].MaxItems = jsonschema.Ptr(10) + s.Properties["D"].MinItems = jsonschema.Ptr(10) + s.Properties["E"].Types[0] = "mutated" + + s2, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + if s2.Properties["A"].Type == "mutated" { + t.Fatalf("ForWithMutation: expected A.Type to not be mutated") + } + if s2.Properties["B"].AdditionalProperties.Type == "mutated" { + t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated") + } + if s2.Properties["C"].Items.Type == "mutated" { + t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated") + } + if *s2.Properties["D"].MaxItems == 10 { + t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated") + } + if *s2.Properties["D"].MinItems == 10 { + t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated") + } + if s2.Properties["E"].Types[0] == "mutated" { + t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated") + } + if s2.Required[0] == "mutated" { + t.Fatalf("ForWithMutation: expected Required[0] to not be mutated") + } +} + +type x struct { + Y y +} +type y struct { + X []x +} + +func TestForWithCycle(t *testing.T) { + type a []*a + type b1 struct{ b *b1 } // unexported field should be skipped + type b2 struct{ B *b2 } + type c1 struct{ c map[string]*c1 } // unexported field should be skipped + type c2 struct{ C map[string]*c2 } + + tests := []struct { + name string + shouldErr bool + fn func() error + }{ + {"slice alias (a)", true, func() error { _, err := jsonschema.For[a](); return err }}, + {"unexported self cycle (b1)", false, func() error { _, err := jsonschema.For[b1](); return err }}, + {"exported self cycle (b2)", true, func() error { _, err := jsonschema.For[b2](); return err }}, + {"unexported map self cycle (c1)", false, func() error { _, err := jsonschema.For[c1](); return err }}, + {"exported map self cycle (c2)", true, func() error { _, err := jsonschema.For[c2](); return err }}, + {"cross-cycle x -> y -> x", true, func() error { _, err := jsonschema.For[x](); return err }}, + {"cross-cycle y -> x -> y", true, func() error { _, err := jsonschema.For[y](); return err }}, + } + + for _, test := range tests { + test := test // prevent loop shadowing + t.Run(test.name, func(t *testing.T) { + err := test.fn() + if test.shouldErr && err == nil { + t.Errorf("expected cycle error, got nil") + } + if !test.shouldErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +}