Skip to content

Commit bfb9215

Browse files
jsonchema: support additional keywords
Support additional (extra) keywords beyond those specified in the `Schema` type. The current behavior is additional keywords are dropped. To follow up: - There are similar util functions in `mcp/util.go` which don't seem to be used. Maybe we should move it to `internal/util` instead of `jsonchema/util.go` package? - The method `[Un]MarshalJSON` is only on `pointer receiver`, is this intended? (that means it will only work if we [un]marshal pointer to the `Schema` instance) - The fix is using the [un]marshalStructWithMap from `mcp/util.go` and there are requirements for the function to work properly but these are not strictly checked e.g. need `json:"-"` tag for the Extra field. Fixes modelcontextprotocol#69.
1 parent 2facfc6 commit bfb9215

File tree

4 files changed

+211
-4
lines changed

4 files changed

+211
-4
lines changed

jsonschema/schema.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,9 @@ type Schema struct {
127127
// https://json-schema.org/draft/2020-12/draft-bhutton-json-schema-validation-00#rfc.section.7
128128
Format string `json:"format,omitempty"`
129129

130+
// Extra allows for additional keywords beyond those specified.
131+
Extra map[string]any `json:"-"`
132+
130133
// computed fields
131134

132135
// This schema's base schema.
@@ -236,7 +239,7 @@ func (s *Schema) MarshalJSON() ([]byte, error) {
236239
Type: typ,
237240
schemaWithoutMethods: (*schemaWithoutMethods)(s),
238241
}
239-
return json.Marshal(ms)
242+
return marshalStructWithMap(&ms, "Extra")
240243
}
241244

242245
func (s *Schema) UnmarshalJSON(data []byte) error {
@@ -269,7 +272,7 @@ func (s *Schema) UnmarshalJSON(data []byte) error {
269272
}{
270273
schemaWithoutMethods: (*schemaWithoutMethods)(s),
271274
}
272-
if err := json.Unmarshal(data, &ms); err != nil {
275+
if err := unmarshalStructWithMap(data, &ms, "Extra"); err != nil {
273276
return err
274277
}
275278
// Unmarshal "type" as either Type or Types.

jsonschema/schema_test.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ func TestGoRoundTrip(t *testing.T) {
2626
{Const: Ptr(any(map[string]any{}))},
2727
{Default: mustMarshal(1)},
2828
{Default: mustMarshal(nil)},
29+
{Extra: map[string]any{"test": "value"}},
2930
} {
3031
data, err := json.Marshal(s)
3132
if err != nil {
@@ -64,11 +65,19 @@ func TestJSONRoundTrip(t *testing.T) {
6465
`{"$vocabulary":{"b":true, "a":false}}`,
6566
`{"$vocabulary":{"a":false,"b":true}}`,
6667
},
67-
{`{"unk":0}`, `{}`}, // unknown fields are dropped, unfortunately
68+
{`{"unk":0}`, `{"unk":0}`}, // unknown fields are not dropped
69+
{
70+
// known and unknown fields are not dropped
71+
// note that the order will be by the declaration order in the anonymous struct inside MarshalJSON
72+
`{"comment":"test","type":"example","unk":0}`,
73+
`{"type":"example","comment":"test","unk":0}`,
74+
},
75+
{`{"extra":0}`, `{"extra":0}`}, // extra is not a special keyword and should not be dropped
76+
{`{"Extra":0}`, `{"Extra":0}`}, // Extra is not a special keyword and should not be dropped
6877
} {
6978
var s Schema
7079
mustUnmarshal(t, []byte(tt.in), &s)
71-
data, err := json.Marshal(s)
80+
data, err := json.Marshal(&s)
7281
if err != nil {
7382
t.Fatal(err)
7483
}

jsonschema/util.go

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ import (
1515
"math/big"
1616
"reflect"
1717
"slices"
18+
"sync"
19+
20+
"github.com/modelcontextprotocol/go-sdk/internal/util"
1821
)
1922

2023
// Equal reports whether two Go values representing JSON values are equal according
@@ -282,3 +285,136 @@ func assert(cond bool, msg string) {
282285
panic("assertion failed: " + msg)
283286
}
284287
}
288+
289+
// marshalStructWithMap marshals its first argument to JSON, treating the field named
290+
// mapField as an embedded map. The first argument must be a pointer to
291+
// a struct. The underlying type of mapField must be a map[string]any, and it must have
292+
// a "-" json tag, meaning it will not be marshaled.
293+
//
294+
// For example, given this struct:
295+
//
296+
// type S struct {
297+
// A int
298+
// Extra map[string] any `json:"-"`
299+
// }
300+
//
301+
// and this value:
302+
//
303+
// s := S{A: 1, Extra: map[string]any{"B": 2}}
304+
//
305+
// the call marshalJSONWithMap(s, "Extra") would return
306+
//
307+
// {"A": 1, "B": 2}
308+
//
309+
// It is an error if the map contains the same key as another struct field's
310+
// JSON name.
311+
//
312+
// marshalStructWithMap calls json.Marshal on a value of type T, so T must not
313+
// have a MarshalJSON method that calls this function, on pain of infinite regress.
314+
//
315+
// Note that there is a similar function in mcp/util.go, but they are not the same.
316+
// Here the function requires `-` json tag, does not clear the mapField map,
317+
// and handles embedded struct due to the implementation of jsonNames in this package.
318+
//
319+
// TODO: avoid this restriction on T by forcing it to marshal in a default way.
320+
// See https://go.dev/play/p/EgXKJHxEx_R.
321+
func marshalStructWithMap[T any](s *T, mapField string) ([]byte, error) {
322+
// Marshal the struct and the map separately, and concatenate the bytes.
323+
// This strategy is dramatically less complicated than
324+
// constructing a synthetic struct or map with the combined keys.
325+
if s == nil {
326+
return []byte("null"), nil
327+
}
328+
s2 := *s
329+
vMapField := reflect.ValueOf(&s2).Elem().FieldByName(mapField)
330+
mapVal := vMapField.Interface().(map[string]any)
331+
332+
// Check for duplicates.
333+
names := jsonNames(reflect.TypeFor[T]())
334+
for key := range mapVal {
335+
if names[key] {
336+
return nil, fmt.Errorf("map key %q duplicates struct field", key)
337+
}
338+
}
339+
340+
structBytes, err := json.Marshal(s2)
341+
if err != nil {
342+
return nil, fmt.Errorf("marshalStructWithMap(%+v): %w", s, err)
343+
}
344+
if len(mapVal) == 0 {
345+
return structBytes, nil
346+
}
347+
mapBytes, err := json.Marshal(mapVal)
348+
if err != nil {
349+
return nil, err
350+
}
351+
if len(structBytes) == 2 { // must be "{}"
352+
return mapBytes, nil
353+
}
354+
// "{X}" + "{Y}" => "{X,Y}"
355+
res := append(structBytes[:len(structBytes)-1], ',')
356+
res = append(res, mapBytes[1:]...)
357+
return res, nil
358+
}
359+
360+
// unmarshalStructWithMap is the inverse of marshalStructWithMap.
361+
// T has the same restrictions as in that function.
362+
//
363+
// Note that there is a similar function in mcp/util.go, but they are not the same.
364+
// Here jsonNames also returns fields from embedded structs, hence this function
365+
// handles embedded structs as well.
366+
func unmarshalStructWithMap[T any](data []byte, v *T, mapField string) error {
367+
// Unmarshal into the struct, ignoring unknown fields.
368+
if err := json.Unmarshal(data, v); err != nil {
369+
return err
370+
}
371+
// Unmarshal into the map.
372+
m := map[string]any{}
373+
if err := json.Unmarshal(data, &m); err != nil {
374+
return err
375+
}
376+
// Delete from the map the fields of the struct.
377+
for n := range jsonNames(reflect.TypeFor[T]()) {
378+
delete(m, n)
379+
}
380+
if len(m) != 0 {
381+
reflect.ValueOf(v).Elem().FieldByName(mapField).Set(reflect.ValueOf(m))
382+
}
383+
return nil
384+
}
385+
386+
var jsonNamesMap sync.Map // from reflect.Type to map[string]bool
387+
388+
// jsonNames returns the set of JSON object keys that t will marshal into,
389+
// including fields from embedded structs in t.
390+
// t must be a struct type.
391+
//
392+
// Note that there is a similar function in mcp/util.go, but they are not the same
393+
// Here the function recurses over embedded structs and includes fields from them.
394+
func jsonNames(t reflect.Type) map[string]bool {
395+
// Lock not necessary: at worst we'll duplicate work.
396+
if val, ok := jsonNamesMap.Load(t); ok {
397+
return val.(map[string]bool)
398+
}
399+
m := map[string]bool{}
400+
for i := range t.NumField() {
401+
field := t.Field(i)
402+
// handle embedded structs
403+
if field.Anonymous {
404+
fieldType := field.Type
405+
if fieldType.Kind() == reflect.Ptr {
406+
fieldType = fieldType.Elem()
407+
}
408+
for n := range jsonNames(fieldType) {
409+
m[n] = true
410+
}
411+
continue
412+
}
413+
info := util.FieldJSONInfo(field)
414+
if !info.Omit {
415+
m[info.Name] = true
416+
}
417+
}
418+
jsonNamesMap.Store(t, m)
419+
return m
420+
}

jsonschema/util_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ import (
88
"encoding/json"
99
"hash/maphash"
1010
"reflect"
11+
"strings"
1112
"testing"
13+
14+
"github.com/google/go-cmp/cmp"
15+
"github.com/google/go-cmp/cmp/cmpopts"
1216
)
1317

1418
func TestEqual(t *testing.T) {
@@ -125,3 +129,58 @@ func TestHash(t *testing.T) {
125129
}
126130
_ = hash(null)
127131
}
132+
133+
func TestMarshalStructWithMap(t *testing.T) {
134+
type S struct {
135+
A int
136+
B string `json:"b,omitempty"`
137+
u bool
138+
M map[string]any `json:"-"`
139+
}
140+
t.Run("basic", func(t *testing.T) {
141+
s := S{A: 1, B: "two", M: map[string]any{"!@#": true}}
142+
got, err := marshalStructWithMap(&s, "M")
143+
if err != nil {
144+
t.Fatal(err)
145+
}
146+
want := `{"A":1,"b":"two","!@#":true}`
147+
if g := string(got); g != want {
148+
t.Errorf("\ngot %s\nwant %s", g, want)
149+
}
150+
151+
var un S
152+
if err := unmarshalStructWithMap(got, &un, "M"); err != nil {
153+
t.Fatal(err)
154+
}
155+
if diff := cmp.Diff(s, un, cmpopts.IgnoreUnexported(S{})); diff != "" {
156+
t.Errorf("mismatch (-want, +got):\n%s", diff)
157+
}
158+
})
159+
t.Run("duplicate", func(t *testing.T) {
160+
s := S{A: 1, B: "two", M: map[string]any{"b": "dup"}}
161+
_, err := marshalStructWithMap(&s, "M")
162+
if err == nil || !strings.Contains(err.Error(), "duplicate") {
163+
t.Errorf("got %v, want error with 'duplicate'", err)
164+
}
165+
})
166+
t.Run("embedded", func(t *testing.T) {
167+
type Embedded struct {
168+
A int
169+
B int
170+
Extra map[string]any `json:"-"`
171+
}
172+
type S struct {
173+
C int
174+
Embedded
175+
}
176+
s := S{C: 1, Embedded: Embedded{A: 2, B: 3, Extra: map[string]any{"d": 4, "e": 5}}}
177+
got, err := marshalStructWithMap(&s, "Extra")
178+
if err != nil {
179+
t.Fatal(err)
180+
}
181+
want := `{"C":1,"A":2,"B":3,"d":4,"e":5}`
182+
if g := string(got); g != want {
183+
t.Errorf("got %v, want %v", g, want)
184+
}
185+
})
186+
}

0 commit comments

Comments
 (0)