diff --git a/examples/rate-limiting/go.mod b/examples/rate-limiting/go.mod index 5ec49ddc..f3cf7aa1 100644 --- a/examples/rate-limiting/go.mod +++ b/examples/rate-limiting/go.mod @@ -1,6 +1,8 @@ module github.com/modelcontextprotocol/go-sdk/examples/rate-limiting -go 1.25 +go 1.23.0 + +toolchain go1.24.4 require ( github.com/modelcontextprotocol/go-sdk v0.0.0-20250625185707-09181c2c2e89 diff --git a/jsonschema/infer.go b/jsonschema/infer.go index 1334bdf1..e69a775e 100644 --- a/jsonschema/infer.go +++ b/jsonschema/infer.go @@ -9,6 +9,7 @@ package jsonschema import ( "fmt" "reflect" + "sync" "github.com/modelcontextprotocol/go-sdk/internal/util" ) @@ -39,7 +40,8 @@ import ( // The types must not have cycles. func For[T any]() (*Schema, error) { // TODO: consider skipping incompatible fields, instead of failing. - s, err := forType(reflect.TypeFor[T]()) + seen := make(map[reflect.Type]bool) + s, err := forType(reflect.TypeFor[T](), seen) if err != nil { var z T return nil, fmt.Errorf("For[%T](): %w", z, err) @@ -47,7 +49,9 @@ func For[T any]() (*Schema, error) { return s, nil } -func forType(t reflect.Type) (*Schema, error) { +var typeSchema sync.Map // map[reflect.Type]*Schema + +func forType(t reflect.Type, seen map[reflect.Type]bool) (*Schema, error) { // Follow pointers: the schema for *T is almost the same as for T, except that // an explicit JSON "null" is allowed for the pointer. allowNull := false @@ -56,11 +60,23 @@ func forType(t reflect.Type) (*Schema, error) { t = t.Elem() } + if cachedS, ok := typeSchema.Load(t); ok { + s := deepCopySchema(cachedS.(*Schema)) + adjustTypesForPointer(s, allowNull) + return s, nil + } + var ( s = new(Schema) err error ) + if seen[t] { + return nil, fmt.Errorf("cycle detected for type %v", t) + } + seen[t] = true + defer delete(seen, t) + switch t.Kind() { case reflect.Bool: s.Type = "boolean" @@ -81,14 +97,14 @@ func forType(t reflect.Type) (*Schema, error) { return nil, fmt.Errorf("unsupported map key type %v", t.Key().Kind()) } s.Type = "object" - s.AdditionalProperties, err = forType(t.Elem()) + s.AdditionalProperties, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing map value schema: %v", err) } case reflect.Slice, reflect.Array: s.Type = "array" - s.Items, err = forType(t.Elem()) + s.Items, err = forType(t.Elem(), seen) if err != nil { return nil, fmt.Errorf("computing element schema: %v", err) } @@ -114,7 +130,7 @@ func forType(t reflect.Type) (*Schema, error) { if s.Properties == nil { s.Properties = make(map[string]*Schema) } - s.Properties[info.Name], err = forType(field.Type) + s.Properties[info.Name], err = forType(field.Type, seen) if err != nil { return nil, err } @@ -126,9 +142,56 @@ func forType(t reflect.Type) (*Schema, error) { default: return nil, fmt.Errorf("type %v is unsupported by jsonschema", t) } + typeSchema.Store(t, deepCopySchema(s)) + adjustTypesForPointer(s, allowNull) + return s, nil +} + +func adjustTypesForPointer(s *Schema, allowNull bool) { if allowNull && s.Type != "" { s.Types = []string{"null", s.Type} s.Type = "" } - return s, nil +} + +// deepCopySchema makes a deep copy of a Schema. +// Only fields that are pointers and modified by forType are copied. +func deepCopySchema(s *Schema) *Schema { + if s == nil { + return nil + } + + clone := new(Schema) + clone.Type = s.Type + + if s.Items != nil { + clone.Items = deepCopySchema(s.Items) + } + if s.AdditionalProperties != nil { + clone.AdditionalProperties = deepCopySchema(s.AdditionalProperties) + } + if s.MinItems != nil { + minItems := *s.MinItems + clone.MinItems = &minItems + } + if s.MaxItems != nil { + maxItems := *s.MaxItems + clone.MaxItems = &maxItems + } + if s.Types != nil { + clone.Types = make([]string, len(s.Types)) + copy(clone.Types, s.Types) + } + if s.Required != nil { + clone.Required = make([]string, len(s.Required)) + copy(clone.Required, s.Required) + } + if s.Properties != nil { + clone.Properties = make(map[string]*Schema) + for k, v := range s.Properties { + clone.Properties[k] = deepCopySchema(v) + } + } + + return clone } diff --git a/jsonschema/infer_test.go b/jsonschema/infer_test.go index 9325b832..52cc5648 100644 --- a/jsonschema/infer_test.go +++ b/jsonschema/infer_test.go @@ -91,3 +91,54 @@ func TestForType(t *testing.T) { }) } } + +func TestForWithMutation(t *testing.T) { + // This test ensures that the cached schema is not mutated when the caller + // mutates the returned schema. + type S struct { + A int + } + type T struct { + A int `json:"A"` + B map[string]int + C []S + D [3]S + E *bool + } + s, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + s.Required[0] = "mutated" + s.Properties["A"].Type = "mutated" + s.Properties["C"].Items.Type = "mutated" + s.Properties["D"].MaxItems = jsonschema.Ptr(10) + s.Properties["D"].MinItems = jsonschema.Ptr(10) + s.Properties["E"].Types[0] = "mutated" + + s2, err := jsonschema.For[T]() + if err != nil { + t.Fatalf("For: %v", err) + } + if s2.Properties["A"].Type == "mutated" { + t.Fatalf("ForWithMutation: expected A.Type to not be mutated") + } + if s2.Properties["B"].AdditionalProperties.Type == "mutated" { + t.Fatalf("ForWithMutation: expected B.AdditionalProperties.Type to not be mutated") + } + if s2.Properties["C"].Items.Type == "mutated" { + t.Fatalf("ForWithMutation: expected C.Items.Type to not be mutated") + } + if *s2.Properties["D"].MaxItems == 10 { + t.Fatalf("ForWithMutation: expected D.MaxItems to not be mutated") + } + if *s2.Properties["D"].MinItems == 10 { + t.Fatalf("ForWithMutation: expected D.MinItems to not be mutated") + } + if s2.Properties["E"].Types[0] == "mutated" { + t.Fatalf("ForWithMutation: expected E.Types[0] to not be mutated") + } + if s2.Required[0] == "mutated" { + t.Fatalf("ForWithMutation: expected Required[0] to not be mutated") + } +} diff --git a/mcp/content.go b/mcp/content.go index ed7f6f99..fd027cf8 100644 --- a/mcp/content.go +++ b/mcp/content.go @@ -2,6 +2,9 @@ // Use of this source code is governed by an MIT-style // license that can be found in the LICENSE file. +// TODO(findleyr): update JSON marshalling of all content types to preserve required fields. +// (See [TextContent.MarshalJSON], which handles this for text content). + package mcp import ( @@ -25,12 +28,19 @@ type TextContent struct { } func (c *TextContent) MarshalJSON() ([]byte, error) { - return json.Marshal(&wireContent{ + // Custom wire format to ensure the required "text" field is always included, even when empty. + wire := struct { + Type string `json:"type"` + Text string `json:"text"` + Meta Meta `json:"_meta,omitempty"` + Annotations *Annotations `json:"annotations,omitempty"` + }{ Type: "text", Text: c.Text, Meta: c.Meta, Annotations: c.Annotations, - }) + } + return json.Marshal(wire) } func (c *TextContent) fromWire(wire *wireContent) { @@ -177,10 +187,12 @@ func (r ResourceContents) MarshalJSON() ([]byte, error) { URI string `json:"uri,omitempty"` MIMEType string `json:"mimeType,omitempty"` Blob []byte `json:"blob"` + Meta Meta `json:"_meta,omitempty"` }{ URI: r.URI, MIMEType: r.MIMEType, Blob: r.Blob, + Meta: r.Meta, } return json.Marshal(br) } diff --git a/mcp/content_test.go b/mcp/content_test.go index 5ee6f66c..7a549bea 100644 --- a/mcp/content_test.go +++ b/mcp/content_test.go @@ -22,6 +22,14 @@ func TestContent(t *testing.T) { &mcp.TextContent{Text: "hello"}, `{"type":"text","text":"hello"}`, }, + { + &mcp.TextContent{Text: ""}, + `{"type":"text","text":""}`, + }, + { + &mcp.TextContent{}, + `{"type":"text","text":""}`, + }, { &mcp.TextContent{ Text: "hello", @@ -146,6 +154,10 @@ func TestEmbeddedResource(t *testing.T) { &mcp.ResourceContents{URI: "u", Blob: []byte{1}}, `{"uri":"u","blob":"AQ=="}`, }, + { + &mcp.ResourceContents{URI: "u", MIMEType: "m", Blob: []byte{1}, Meta: mcp.Meta{"key": "value"}}, + `{"uri":"u","mimeType":"m","blob":"AQ==","_meta":{"key":"value"}}`, + }, } { data, err := json.Marshal(tt.rc) if err != nil { diff --git a/mcp/streamable.go b/mcp/streamable.go index da950fb2..db3add85 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -223,13 +223,13 @@ type StreamableServerTransport struct { // TODO(rfindley): clean up once requests are handled. requestStreams map[JSONRPCID]streamID - // outstandingRequests tracks the set of unanswered incoming RPCs for each logical + // streamRequests tracks the set of unanswered incoming RPCs for each logical // stream. // // When the server has responded to each request, the stream should be // closed. // - // Lifecycle: outstandingRequests values persist as until the requests have been + // Lifecycle: streamRequests values persist as until the requests have been // replied to by the server. Notably, NOT until they are sent to an HTTP // response, as delivery is not guaranteed. streamRequests map[streamID]map[JSONRPCID]struct{}