Skip to content

Commit a205a5a

Browse files
committed
fix toolschemas
1 parent bd8c842 commit a205a5a

File tree

2 files changed

+78
-26
lines changed

2 files changed

+78
-26
lines changed

examples/server/toolschemas/main.go

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ func (t *manualGreeter) greet(_ context.Context, req *mcp.CallToolRequest) (*mcp
5858
if err := json.Unmarshal(req.Params.Arguments, &input); err != nil {
5959
return errf("failed to unmarshal arguments: %v", err), nil
6060
}
61-
if err := t.inputSchema.Validate(input); err != nil {
61+
if err := validateStruct(input, t.inputSchema); err != nil {
6262
return errf("invalid input: %v", err), nil
6363
}
6464
output := Output{Greeting: "Hi " + input.Name}
65-
if err := t.outputSchema.Validate(output); err != nil {
65+
if err := validateStruct(output, t.outputSchema); err != nil {
6666
return errf("tool produced invalid output: %v", err), nil
6767
}
6868
outputJSON, err := json.Marshal(output)
@@ -75,6 +75,50 @@ func (t *manualGreeter) greet(_ context.Context, req *mcp.CallToolRequest) (*mcp
7575
}, nil
7676
}
7777

78+
// validateStruct validates x against schema by first changing the struct to
79+
// a map[string]any, then validating that.
80+
func validateStruct(x any, res *jsonschema.Resolved) error {
81+
data, err := json.Marshal(x)
82+
if err != nil {
83+
return err
84+
}
85+
var m map[string]any
86+
if err := json.Unmarshal(data, &m); err != nil {
87+
return err
88+
}
89+
return res.Validate(m)
90+
}
91+
92+
var (
93+
inputSchema = &jsonschema.Schema{
94+
Type: "object",
95+
Properties: map[string]*jsonschema.Schema{
96+
"name": {Type: "string", MaxLength: jsonschema.Ptr(10)},
97+
},
98+
}
99+
outputSchema = &jsonschema.Schema{
100+
Type: "object",
101+
Properties: map[string]*jsonschema.Schema{
102+
"greeting": {Type: "string"},
103+
},
104+
}
105+
)
106+
107+
func newManualGreeter() (*manualGreeter, error) {
108+
resIn, err := inputSchema.Resolve(nil)
109+
if err != nil {
110+
return nil, err
111+
}
112+
resOut, err := outputSchema.Resolve(nil)
113+
if err != nil {
114+
return nil, err
115+
}
116+
return &manualGreeter{
117+
inputSchema: resIn,
118+
outputSchema: resOut,
119+
}, nil
120+
}
121+
78122
func main() {
79123
server := mcp.NewServer(&mcp.Implementation{Name: "greeter"}, nil)
80124

@@ -90,30 +134,7 @@ func main() {
90134
//
91135
// We don't need to do all this work: below, we use jsonschema.For to start
92136
// from the default schema.
93-
var (
94-
manual manualGreeter
95-
err error
96-
)
97-
inputSchema := &jsonschema.Schema{
98-
Type: "object",
99-
Properties: map[string]*jsonschema.Schema{
100-
"name": {Type: "string", MaxLength: jsonschema.Ptr(10)},
101-
},
102-
}
103-
manual.inputSchema, err = inputSchema.Resolve(nil)
104-
if err != nil {
105-
log.Fatal(err)
106-
}
107-
outputSchema := &jsonschema.Schema{
108-
Type: "object",
109-
Properties: map[string]*jsonschema.Schema{
110-
"greeting": {Type: "string"},
111-
},
112-
}
113-
manual.outputSchema, err = outputSchema.Resolve(nil)
114-
if err != nil {
115-
log.Fatal(err)
116-
}
137+
manual, err := newManualGreeter()
117138
server.AddTool(&mcp.Tool{
118139
Name: "manual greeting",
119140
InputSchema: inputSchema,
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
2+
// Use of this source code is governed by an MIT-style
3+
// license that can be found in the LICENSE file.
4+
5+
package main
6+
7+
import (
8+
"context"
9+
"encoding/json"
10+
"testing"
11+
12+
"github.com/modelcontextprotocol/go-sdk/mcp"
13+
)
14+
15+
func TestGreet(t *testing.T) {
16+
manual, err := newManualGreeter()
17+
if err != nil {
18+
t.Fatal(err)
19+
}
20+
res, err := manual.greet(context.Background(), &mcp.CallToolRequest{
21+
Params: &mcp.CallToolParamsRaw{
22+
Arguments: json.RawMessage(`{"name": "Bob"}`),
23+
},
24+
})
25+
if err != nil {
26+
t.Fatal(err)
27+
}
28+
if res.IsError {
29+
t.Fatalf("tool error: %q", res.Content[0].(*mcp.TextContent).Text)
30+
}
31+
}

0 commit comments

Comments
 (0)