Skip to content

Commit 5efcbbf

Browse files
committed
internal/mcp: apply JSON schema defaults
Before validating a tool input, apply the defaults from the input schema. Change-Id: I23fae4087f898d51b3d1f0578d3203b954cb5f2d Reviewed-on: https://go-review.googlesource.com/c/tools/+/680395 Reviewed-by: Robert Findley <[email protected]> LUCI-TryBot-Result: Go LUCI <[email protected]>
1 parent b4768b8 commit 5efcbbf

File tree

3 files changed

+59
-16
lines changed

3 files changed

+59
-16
lines changed

internal/mcp/server.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919

2020
jsonrpc2 "golang.org/x/tools/internal/jsonrpc2_v2"
2121
"golang.org/x/tools/internal/mcp/internal/util"
22+
"golang.org/x/tools/internal/mcp/jsonschema"
2223
)
2324

2425
const DefaultPageSize = 1000
@@ -137,7 +138,7 @@ func (s *Server) addToolsErr(tools ...*ServerTool) error {
137138
// Resolve the schemas, with no base URI. We don't expect tool schemas to
138139
// refer outside of themselves.
139140
if st.Tool.InputSchema != nil {
140-
r, err := st.Tool.InputSchema.Resolve(nil)
141+
r, err := st.Tool.InputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
141142
if err != nil {
142143
return err
143144
}
@@ -146,7 +147,7 @@ func (s *Server) addToolsErr(tools ...*ServerTool) error {
146147

147148
// TODO: uncomment when output schemas drop.
148149
// if st.Tool.OutputSchema != nil {
149-
// st.outputResolved, err := st.Tool.OutputSchema.Resolve(nil)
150+
// st.outputResolved, err := st.Tool.OutputSchema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
150151
// if err != nil {
151152
// return err
152153
// }

internal/mcp/tool.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,9 +151,11 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any)
151151
if err := dec.Decode(v); err != nil {
152152
return fmt.Errorf("unmarshaling: %w", err)
153153
}
154-
// TODO(jba): apply defaults.
155154
// TODO: test with nil args.
156155
if resolved != nil {
156+
if err := resolved.ApplyDefaults(v); err != nil {
157+
return fmt.Errorf("applying defaults from \n\t%s\nto\n\t%s:\n%w", schemaJSON(resolved.Schema()), data, err)
158+
}
157159
if err := resolved.Validate(v); err != nil {
158160
return fmt.Errorf("validating\n\t%s\nagainst\n\t %s:\n %w", data, schemaJSON(resolved.Schema()), err)
159161
}

internal/mcp/tool_test.go

Lines changed: 53 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,31 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44

5-
package mcp_test
5+
package mcp
66

77
import (
88
"context"
9+
"encoding/json"
10+
"reflect"
911
"testing"
1012

1113
"github.com/google/go-cmp/cmp"
1214
"github.com/google/go-cmp/cmp/cmpopts"
13-
"golang.org/x/tools/internal/mcp"
1415
"golang.org/x/tools/internal/mcp/jsonschema"
1516
)
1617

1718
// testToolHandler is used for type inference in TestNewTool.
18-
func testToolHandler[T any](context.Context, *mcp.ServerSession, *mcp.CallToolParamsFor[T]) (*mcp.CallToolResultFor[any], error) {
19+
func testToolHandler[T any](context.Context, *ServerSession, *CallToolParamsFor[T]) (*CallToolResultFor[any], error) {
1920
panic("not implemented")
2021
}
2122

2223
func TestNewTool(t *testing.T) {
2324
tests := []struct {
24-
tool *mcp.ServerTool
25+
tool *ServerTool
2526
want *jsonschema.Schema
2627
}{
2728
{
28-
mcp.NewTool("basic", "", testToolHandler[struct {
29+
NewTool("basic", "", testToolHandler[struct {
2930
Name string `json:"name"`
3031
}]),
3132
&jsonschema.Schema{
@@ -38,8 +39,8 @@ func TestNewTool(t *testing.T) {
3839
},
3940
},
4041
{
41-
mcp.NewTool("enum", "", testToolHandler[struct{ Name string }], mcp.Input(
42-
mcp.Property("Name", mcp.Enum("x", "y", "z")),
42+
NewTool("enum", "", testToolHandler[struct{ Name string }], Input(
43+
Property("Name", Enum("x", "y", "z")),
4344
)),
4445
&jsonschema.Schema{
4546
Type: "object",
@@ -51,13 +52,13 @@ func TestNewTool(t *testing.T) {
5152
},
5253
},
5354
{
54-
mcp.NewTool("required", "", testToolHandler[struct {
55+
NewTool("required", "", testToolHandler[struct {
5556
Name string `json:"name"`
5657
Language string `json:"language"`
5758
X int `json:"x,omitempty"`
5859
Y int `json:"y,omitempty"`
59-
}], mcp.Input(
60-
mcp.Property("x", mcp.Required(true)))),
60+
}], Input(
61+
Property("x", Required(true)))),
6162
&jsonschema.Schema{
6263
Type: "object",
6364
Required: []string{"name", "language", "x"},
@@ -71,11 +72,11 @@ func TestNewTool(t *testing.T) {
7172
},
7273
},
7374
{
74-
mcp.NewTool("set_schema", "", testToolHandler[struct {
75+
NewTool("set_schema", "", testToolHandler[struct {
7576
X int `json:"x,omitempty"`
7677
Y int `json:"y,omitempty"`
77-
}], mcp.Input(
78-
mcp.Schema(&jsonschema.Schema{Type: "object"})),
78+
}], Input(
79+
Schema(&jsonschema.Schema{Type: "object"})),
7980
),
8081
&jsonschema.Schema{
8182
Type: "object",
@@ -88,3 +89,42 @@ func TestNewTool(t *testing.T) {
8889
}
8990
}
9091
}
92+
93+
func TestUnmarshalSchema(t *testing.T) {
94+
schema := &jsonschema.Schema{
95+
Type: "object",
96+
Properties: map[string]*jsonschema.Schema{
97+
"x": {Type: "integer", Default: json.RawMessage("3")},
98+
},
99+
}
100+
resolved, err := schema.Resolve(&jsonschema.ResolveOptions{ValidateDefaults: true})
101+
if err != nil {
102+
t.Fatal(err)
103+
}
104+
105+
type S struct {
106+
X int `json:"x"`
107+
}
108+
109+
for _, tt := range []struct {
110+
data string
111+
v any
112+
want any
113+
}{
114+
{`{"x": 1}`, new(S), &S{X: 1}},
115+
{`{}`, new(S), &S{X: 3}}, // default applied
116+
{`{"x": 0}`, new(S), &S{X: 3}}, // FAIL: should be 0. (requires double unmarshal)
117+
{`{"x": 1}`, new(map[string]any), &map[string]any{"x": 1.0}},
118+
{`{}`, new(map[string]any), &map[string]any{"x": 3.0}}, // default applied
119+
{`{"x": 0}`, new(map[string]any), &map[string]any{"x": 0.0}},
120+
} {
121+
raw := json.RawMessage(tt.data)
122+
if err := unmarshalSchema(raw, resolved, tt.v); err != nil {
123+
t.Fatal(err)
124+
}
125+
if !reflect.DeepEqual(tt.v, tt.want) {
126+
t.Errorf("got %#v, want %#v", tt.v, tt.want)
127+
}
128+
129+
}
130+
}

0 commit comments

Comments
 (0)