Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 69 additions & 6 deletions jsonschema/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package jsonschema
import (
"fmt"
"reflect"
"sync"

"github.com/modelcontextprotocol/go-sdk/internal/util"
)
Expand Down Expand Up @@ -39,15 +40,18 @@ import (
// The types must not have cycles.
func For[T any]() (*Schema, error) {
// TODO: consider skipping incompatible fields, instead of failing.
s, err := forType(reflect.TypeFor[T]())
seen := make(map[reflect.Type]bool)
s, err := forType(reflect.TypeFor[T](), seen)
if err != nil {
var z T
return nil, fmt.Errorf("For[%T](): %w", z, err)
}
return s, nil
}

func forType(t reflect.Type) (*Schema, error) {
var typeSchema sync.Map // map[reflect.Type]*Schema

func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
// Follow pointers: the schema for *T is almost the same as for T, except that
// an explicit JSON "null" is allowed for the pointer.
allowNull := false
Expand All @@ -56,11 +60,23 @@ func forType(t reflect.Type) (*Schema, error) {
t = t.Elem()
}

if cachedS, ok := typeSchema.Load(t); ok {
s := deepCopySchema(cachedS.(*Schema))
adjustTypesForPointer(s, allowNull)
return s, nil
}

var (
s = new(Schema)
err error
)

if seen[t] {
return nil, fmt.Errorf("cycle detected for type %v", t)
}
seen[t] = true
defer delete(seen, t)

switch t.Kind() {
case reflect.Bool:
s.Type = "boolean"
Expand All @@ -81,14 +97,14 @@ func forType(t reflect.Type) (*Schema, error) {
return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind())
}
s.Type = "object"
s.AdditionalProperties, err = forType(t.Elem())
s.AdditionalProperties, err = forType(t.Elem(), seen)
if err != nil {
return nil, fmt.Errorf("computing map value schema: %v", err)
}

case reflect.Slice, reflect.Array:
s.Type = "array"
s.Items, err = forType(t.Elem())
s.Items, err = forType(t.Elem(), seen)
if err != nil {
return nil, fmt.Errorf("computing element schema: %v", err)
}
Expand All @@ -114,7 +130,7 @@ func forType(t reflect.Type) (*Schema, error) {
if s.Properties == nil {
s.Properties = make(map[string]*Schema)
}
s.Properties[info.Name], err = forType(field.Type)
s.Properties[info.Name], err = forType(field.Type, seen)
if err != nil {
return nil, err
}
Expand All @@ -126,9 +142,56 @@ func forType(t reflect.Type) (*Schema, error) {
default:
return nil, fmt.Errorf("type %v is unsupported by jsonschema", t)
}
typeSchema.Store(t, deepCopySchema(s))
adjustTypesForPointer(s, allowNull)
return s, nil
}

func adjustTypesForPointer(s *Schema, allowNull bool) {
if allowNull && s.Type != "" {
s.Types = []string{"null", s.Type}
s.Type = ""
}
return s, nil
}

// deepCopySchema makes a deep copy of a Schema.
// Only fields that are modified by forType are cloned.
func deepCopySchema(s *Schema) *Schema {
if s == nil {
return nil
}

clone := new(Schema)
clone.Type = s.Type

if s.Items != nil {
clone.Items = deepCopySchema(s.Items)
}
if s.AdditionalProperties != nil {
clone.AdditionalProperties = deepCopySchema(s.AdditionalProperties)
}
if s.MinItems != nil {
minItems := *s.MinItems
clone.MinItems = &minItems
}
if s.MaxItems != nil {
maxItems := *s.MaxItems
clone.MaxItems = &maxItems
}
if s.Types != nil {
clone.Types = make([]string, len(s.Types))
copy(clone.Types, s.Types)
}
if s.Required != nil {
clone.Required = make([]string, len(s.Required))
copy(clone.Required, s.Required)
}
if s.Properties != nil {
clone.Properties = make(map[string]*Schema)
for k, v := range s.Properties {
clone.Properties[k] = deepCopySchema(v)
}
}

return clone
}
65 changes: 65 additions & 0 deletions jsonschema/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,68 @@ func TestForType(t *testing.T) {
})
}
}

func TestForWithMutation(t *testing.T) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

context: I'm keeping this unit test, I think they are still good to have

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

// This test ensures that the cached schema is not mutated when the caller
// mutates the returned schema.
type S struct {
A int
}
type T struct {
A int `json:"A"`
B map[string]int
C []S
D [3]S
E *bool
}
s, err := jsonschema.For[T]()
if err != nil {
t.Fatalf("For: %v", err)
}
s.Required[0] = "mutated"
s.Properties["A"].Type = "mutated"
s.Properties["C"].Items.Type = "mutated"
s.Properties["D"].MaxItems = jsonschema.Ptr(10)
s.Properties["D"].MinItems = jsonschema.Ptr(10)
s.Properties["E"].Types[0] = "mutated"

s2, err := jsonschema.For[T]()
if err != nil {
t.Fatalf("For: %v", err)
}
if s2.Properties["A"].Type == "mutated" {
t.Fatalf("ForWithMutation: expected A.Type to not be mutated")
}
if s2.Properties["B"].AdditionalProperties.Type == "mutated" {
t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated")
}
if s2.Properties["C"].Items.Type == "mutated" {
t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated")
}
if *s2.Properties["D"].MaxItems == 10 {
t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated")
}
if *s2.Properties["D"].MinItems == 10 {
t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated")
}
if s2.Properties["E"].Types[0] == "mutated" {
t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated")
}
if s2.Required[0] == "mutated" {
t.Fatalf("ForWithMutation: expected Required[0] to not be mutated")
}
}

type s struct {
A t
}
type t struct {
B []s
}

func TestForWithCycle(t *testing.T) {
_, err := jsonschema.For[s]()
if err == nil {
t.Fatalf("ForWithCycle: expected error, got nil")
}
}
Loading