diff --git a/jsonschema/infer.go b/jsonschema/infer.go index d1c1a5fb..ae441291 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -9,6 +9,7 @@ package jsonschema import ( "fmt" "reflect" + "regexp" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -36,8 +37,12 @@ import ( // - complex numbers // - unsafe pointers // -// The types must not have cycles. // It will return an error if there is a cycle in the types. +// +// For recognizes struct field tags named "jsonschema". +// A jsonschema tag on a field is used as the description for the corresponding property. +// For future compatibility, descriptions must not start with "WORD=", where WORD is a +// sequence of non-whitespace characters. func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. seen := make(map[reflect.Type]bool) @@ -126,10 +131,20 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[info.Name], err = forType(field.Type, seen) + fs, err := forType(field.Type, seen) if err != nil { return nil, err } + if tag, ok := field.Tag.Lookup("jsonschema"); ok { + if tag == "" { + return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name) + } + if disallowedPrefixRegexp.MatchString(tag) { + return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag) + } + fs.Description = tag + } + s.Properties[info.Name] = fs if !info.Settings["omitempty"] && !info.Settings["omitzero"] { s.Required = append(s.Required, info.Name) } @@ -144,3 +159,6 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { } return s, nil } + +// Disallow jsonschema tag values beginning "WORD=", for future expansion. +var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=") diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 0b1b769a..106e5375 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -5,6 +5,7 @@ package jsonschema_test import ( + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -20,8 +21,13 @@ func forType[T any]() *jsonschema.Schema { return s } -func TestForType(t *testing.T) { +func TestFor(t *testing.T) { type schema = jsonschema.Schema + + type S struct { + B int `jsonschema:"bdesc"` + } + tests := []struct { name string got *jsonschema.Schema @@ -44,9 +50,9 @@ func TestForType(t *testing.T) { { "struct", forType[struct { - F int `json:"f"` + F int `json:"f" jsonschema:"fdesc"` G []float64 - P *bool + P *bool `jsonschema:"pdesc"` Skip string `json:"-"` NoSkip string `json:",omitempty"` unexported float64 @@ -55,13 +61,13 @@ func TestForType(t *testing.T) { &schema{ Type: "object", Properties: map[string]*schema{ - "f": {Type: "integer"}, + "f": {Type: "integer", Description: "fdesc"}, "G": {Type: "array", Items: &schema{Type: "number"}}, - "P": {Types: []string{"null", "boolean"}}, + "P": {Types: []string{"null", "boolean"}, Description: "pdesc"}, "NoSkip": {Type: "string"}, }, Required: []string{"f", "G", "P"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + AdditionalProperties: falseSchema(), }, }, { @@ -74,7 +80,37 @@ func TestForType(t *testing.T) { "Y": {Type: "integer"}, }, Required: []string{"X", "Y"}, - AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}}, + AdditionalProperties: falseSchema(), + }, + }, + { + "nested and embedded", + forType[struct { + A S + S + }](), + &schema{ + Type: "object", + Properties: map[string]*schema{ + "A": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), + }, + "S": { + Type: "object", + Properties: map[string]*schema{ + "B": {Type: "integer", Description: "bdesc"}, + }, + Required: []string{"B"}, + AdditionalProperties: falseSchema(), + }, + }, + Required: []string{"A", "S"}, + AdditionalProperties: falseSchema(), }, }, } @@ -92,6 +128,38 @@ func TestForType(t *testing.T) { } } +func forErr[T any]() error { + _, err := jsonschema.For[T]() + return err +} + +func TestForErrors(t *testing.T) { + type ( + s1 struct { + Empty int `jsonschema:""` + } + s2 struct { + Bad int `jsonschema:"$foo=1,bar"` + } + ) + + for _, tt := range []struct { + got error + want string + }{ + {forErr[map[int]int](), "unsupported map key type"}, + {forErr[s1](), "empty jsonschema tag"}, + {forErr[s2](), "must not begin with"}, + {forErr[func()](), "unsupported"}, + } { + if tt.got == nil { + t.Errorf("got nil, want error containing %q", tt.want) + } else if !strings.Contains(tt.got.Error(), tt.want) { + t.Errorf("got %q\nwant it to contain %q", tt.got, tt.want) + } + } +} + func TestForWithMutation(t *testing.T) { // This test ensures that the cached schema is not mutated when the caller // mutates the returned schema. @@ -172,7 +240,6 @@ func TestForWithCycle(t *testing.T) { } for _, test := range tests { - test := test // prevent loop shadowing t.Run(test.name, func(t *testing.T) { err := test.fn() if test.shouldErr && err == nil { @@ -184,3 +251,7 @@ func TestForWithCycle(t *testing.T) { }) } } + +func falseSchema() *jsonschema.Schema { + return &jsonschema.Schema{Not: &jsonschema.Schema{}} +}