Skip to content

Commit 9428e43

Browse files
committed
mcp: support "mcp" struct tags
Add the SchemaFor[T] function, which infers the schema for a struct and also uses the values of "mcp" struct tags to set descriptions. Fixes #47.
1 parent 6f25ba6 commit 9428e43

File tree

2 files changed

+103
-1
lines changed

2 files changed

+103
-1
lines changed

mcp/tool.go

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"reflect"
1313

14+
"github.com/modelcontextprotocol/go-sdk/internal/util"
1415
"github.com/modelcontextprotocol/go-sdk/jsonschema"
1516
)
1617

@@ -89,7 +90,7 @@ func newServerTool[In, Out any](t *Tool, h ToolHandlerFor[In, Out]) (*serverTool
8990
func setSchema[T any](sfield **jsonschema.Schema, rfield **jsonschema.Resolved) error {
9091
var err error
9192
if *sfield == nil {
92-
*sfield, err = jsonschema.For[T]()
93+
*sfield, err = SchemaFor[T]()
9394
}
9495
if err != nil {
9596
return err
@@ -125,6 +126,61 @@ func unmarshalSchema(data json.RawMessage, resolved *jsonschema.Resolved, v any)
125126
return nil
126127
}
127128

129+
// SchemaFor returns a JSON Schema for type T.
130+
// It is like [jsonschema.For], but also uses "mcp" struct field tags
131+
// for property descriptions.
132+
//
133+
// For example, the call
134+
//
135+
// SchemaFor[struct{ B int `mcp:"desc"` }]()
136+
//
137+
// returns a schema with this value for "properties":
138+
//
139+
// {"B": {"type": "integer", "description": "desc"}}
140+
func SchemaFor[T any]() (*jsonschema.Schema, error) {
141+
// Infer the schema based on "json" tags alone.
142+
s, err := jsonschema.For[T]()
143+
if err != nil {
144+
return nil, err
145+
}
146+
147+
// Add descriptions from "mcp" tags.
148+
if err := addDescriptions(reflect.TypeFor[T](), s); err != nil {
149+
return nil, err
150+
}
151+
return s, nil
152+
}
153+
154+
func addDescriptions(t reflect.Type, s *jsonschema.Schema) error {
155+
for t.Kind() == reflect.Pointer {
156+
t = t.Elem()
157+
}
158+
if t.Kind() != reflect.Struct {
159+
return nil
160+
}
161+
162+
for i := range t.NumField() {
163+
f := t.Field(i)
164+
info := util.FieldJSONInfo(f)
165+
ps := s.Properties[info.Name]
166+
if tag, ok := f.Tag.Lookup("mcp"); ok {
167+
if ps == nil {
168+
return fmt.Errorf("mcp tag on struct field %s.%s, which is not in schema", t, f.Name)
169+
}
170+
if tag == "" {
171+
return fmt.Errorf("empty mcp tag on struct field %s.%s", t, f.Name)
172+
}
173+
ps.Description = tag
174+
}
175+
// Recurse on sub-schemas.
176+
if ps != nil {
177+
addDescriptions(f.Type, ps)
178+
}
179+
180+
}
181+
return nil
182+
}
183+
128184
// schemaJSON returns the JSON value for s as a string, or a string indicating an error.
129185
func schemaJSON(s *jsonschema.Schema) string {
130186
m, err := json.Marshal(s)

mcp/tool_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,49 @@ func TestUnmarshalSchema(t *testing.T) {
132132

133133
}
134134
}
135+
136+
func TestSchemaFor(t *testing.T) {
137+
type S1 struct {
138+
G int `mcp:"gdesc"`
139+
}
140+
type S2 struct {
141+
A int
142+
B int `json:"b"`
143+
C int `mcp:"cdesc"`
144+
D int `json:"d" mcp:"ddesc"`
145+
E int `json:"-"`
146+
F S1 `json:"f"`
147+
S1
148+
}
149+
150+
got, err := SchemaFor[S2]()
151+
if err != nil {
152+
t.Fatal(err)
153+
}
154+
i := "integer"
155+
s1 := &jsonschema.Schema{
156+
Type: "object",
157+
Required: []string{"G"},
158+
AdditionalProperties: falseSchema(),
159+
Properties: map[string]*jsonschema.Schema{
160+
"G": {Type: i, Description: "gdesc"},
161+
},
162+
}
163+
want := &jsonschema.Schema{
164+
Type: "object",
165+
Properties: map[string]*jsonschema.Schema{
166+
"A": {Type: i},
167+
"b": {Type: i},
168+
"C": {Type: i, Description: "cdesc"},
169+
"d": {Type: i, Description: "ddesc"},
170+
"f": s1,
171+
"S1": s1,
172+
},
173+
Required: []string{"A", "b", "C", "d", "f", "S1"},
174+
AdditionalProperties: falseSchema(),
175+
}
176+
if diff := cmp.Diff(want, got, cmp.AllowUnexported(jsonschema.Schema{})); diff != "" {
177+
t.Errorf("mismatch (-want, +got):\n%s", diff)
178+
t.Log(schemaJSON(got))
179+
}
180+
}

0 commit comments

Comments
 (0)