Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions jsonschema/infer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package jsonschema
import (
"fmt"
"reflect"
"regexp"

"github.com/modelcontextprotocol/go-sdk/internal/util"
)
Expand Down Expand Up @@ -36,8 +37,12 @@ import (
// - complex numbers
// - unsafe pointers
//
// The types must not have cycles.
// It will return an error if there is a cycle in the types.
//
// For recognizes struct field tags named "jsonschema".
// A jsonschema tag on a field is used as the description for the corresponding property.
// For future compatibility, descriptions must not start with "WORD=", where WORD is a
// sequence of non-whitespace characters.
func For[T any]() (*Schema, error) {
// TODO: consider skipping incompatible fields, instead of failing.
seen := make(map[reflect.Type]bool)
Expand Down Expand Up @@ -126,10 +131,20 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
if s.Properties == nil {
s.Properties = make(map[string]*Schema)
}
s.Properties[info.Name], err = forType(field.Type, seen)
fs, err := forType(field.Type, seen)
if err != nil {
return nil, err
}
if tag, ok := field.Tag.Lookup("jsonschema"); ok {
if tag == "" {
return nil, fmt.Errorf("empty jsonschema tag on struct field %s.%s", t, field.Name)
}
if disallowedPrefixRegexp.MatchString(tag) {
return nil, fmt.Errorf("tag must not begin with 'WORD=': %q", tag)
}
fs.Description = tag
}
s.Properties[info.Name] = fs
if !info.Settings["omitempty"] && !info.Settings["omitzero"] {
s.Required = append(s.Required, info.Name)
}
Expand All @@ -144,3 +159,6 @@ func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) {
}
return s, nil
}

// Disallow jsonschema tag values beginning "WORD=", for future expansion.
var disallowedPrefixRegexp = regexp.MustCompile("^[^ \t\n]*=")
87 changes: 79 additions & 8 deletions jsonschema/infer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package jsonschema_test

import (
"strings"
"testing"

"github.com/google/go-cmp/cmp"
Expand All @@ -20,8 +21,13 @@ func forType[T any]() *jsonschema.Schema {
return s
}

func TestForType(t *testing.T) {
func TestFor(t *testing.T) {
type schema = jsonschema.Schema

type S struct {
B int `jsonschema:"bdesc"`
}

tests := []struct {
name string
got *jsonschema.Schema
Expand All @@ -44,9 +50,9 @@ func TestForType(t *testing.T) {
{
"struct",
forType[struct {
F int `json:"f"`
F int `json:"f" jsonschema:"fdesc"`
G []float64
P *bool
P *bool `jsonschema:"pdesc"`
Skip string `json:"-"`
NoSkip string `json:",omitempty"`
unexported float64
Expand All @@ -55,13 +61,13 @@ func TestForType(t *testing.T) {
&schema{
Type: "object",
Properties: map[string]*schema{
"f": {Type: "integer"},
"f": {Type: "integer", Description: "fdesc"},
"G": {Type: "array", Items: &schema{Type: "number"}},
"P": {Types: []string{"null", "boolean"}},
"P": {Types: []string{"null", "boolean"}, Description: "pdesc"},
"NoSkip": {Type: "string"},
},
Required: []string{"f", "G", "P"},
AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}},
AdditionalProperties: falseSchema(),
},
},
{
Expand All @@ -74,7 +80,37 @@ func TestForType(t *testing.T) {
"Y": {Type: "integer"},
},
Required: []string{"X", "Y"},
AdditionalProperties: &jsonschema.Schema{Not: &jsonschema.Schema{}},
AdditionalProperties: falseSchema(),
},
},
{
"nested and embedded",
forType[struct {
A S
S
}](),
&schema{
Type: "object",
Properties: map[string]*schema{
"A": {
Type: "object",
Properties: map[string]*schema{
"B": {Type: "integer", Description: "bdesc"},
},
Required: []string{"B"},
AdditionalProperties: falseSchema(),
},
"S": {
Type: "object",
Properties: map[string]*schema{
"B": {Type: "integer", Description: "bdesc"},
},
Required: []string{"B"},
AdditionalProperties: falseSchema(),
},
},
Required: []string{"A", "S"},
AdditionalProperties: falseSchema(),
},
},
}
Expand All @@ -92,6 +128,38 @@ func TestForType(t *testing.T) {
}
}

func forErr[T any]() error {
_, err := jsonschema.For[T]()
return err
}

func TestForErrors(t *testing.T) {
type (
s1 struct {
Empty int `jsonschema:""`
}
s2 struct {
Bad int `jsonschema:"$foo=1,bar"`
}
)

for _, tt := range []struct {
got error
want string
}{
{forErr[map[int]int](), "unsupported map key type"},
{forErr[s1](), "empty jsonschema tag"},
{forErr[s2](), "must not begin with"},
{forErr[func()](), "unsupported"},
} {
if tt.got == nil {
t.Errorf("got nil, want error containing %q", tt.want)
} else if !strings.Contains(tt.got.Error(), tt.want) {
t.Errorf("got %q\nwant it to contain %q", tt.got, tt.want)
}
}
}

func TestForWithMutation(t *testing.T) {
// This test ensures that the cached schema is not mutated when the caller
// mutates the returned schema.
Expand Down Expand Up @@ -172,7 +240,6 @@ func TestForWithCycle(t *testing.T) {
}

for _, test := range tests {
test := test // prevent loop shadowing
t.Run(test.name, func(t *testing.T) {
err := test.fn()
if test.shouldErr && err == nil {
Expand All @@ -184,3 +251,7 @@ func TestForWithCycle(t *testing.T) {
})
}
}

func falseSchema() *jsonschema.Schema {
return &jsonschema.Schema{Not: &jsonschema.Schema{}}
}
Loading