Skip to content

Commit 3fe0e15

Browse files
committed
mcp: improvements for 'stateless' streamable servers; 'distributed' mode
Several improvements for the stateless streamable mode, plus support for a 'distributed' (or rather, distributable) version of the stateless server. - Add a 'Stateless' option to StreamableHTTPOptions and StreamableServerTransport, which controls stateless behavior. GetSessionID may still return a non-empty session ID. - Audit validation of stateless mode to allow requests with a session id. Propagate this session ID to the temporary connection. - Peek at requests to allow 'initialize' requests to go through to the session, so that version negotiation can occur (FIXME: add tests). Fixes #284 For #148
1 parent 9abb850 commit 3fe0e15

File tree

5 files changed

+203
-37
lines changed

5 files changed

+203
-37
lines changed

internal/jsonrpc2/conn.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
725725
// write is used by all things that write outgoing messages, including replies.
726726
// it makes sure that writes are atomic
727727
func (c *Connection) write(ctx context.Context, msg Message) error {
728-
err := c.writer.Write(ctx, msg)
728+
var err error
729+
// Fail writes immediately if the connection is shutting down.
730+
//
731+
// TODO(rfindley): should we allow cancellation notifations through? It could
732+
// be the case that writes can still succeed.
733+
c.updateInFlight(func(s *inFlightState) {
734+
err = s.shuttingDown(ErrServerClosing)
735+
})
736+
if err == nil {
737+
err = c.writer.Write(ctx, msg)
738+
}
739+
740+
// For rejected requests, we don't set the writeErr (which would break the
741+
// connection). They can just be returned to the caller.
742+
if errors.Is(err, ErrRejected) {
743+
return err
744+
}
729745

730746
if err != nil && ctx.Err() == nil {
731747
// The call to Write failed, and since ctx.Err() is nil we can't attribute

internal/jsonrpc2/wire.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ var (
3737
ErrServerClosing = NewError(-32004, "server is closing")
3838
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
3939
ErrClientClosing = NewError(-32003, "client is closing")
40+
41+
// The following errors have special semantics for MCP transports
42+
43+
// ErrRejected may be wrapped to return errors from calls to Writer.Write
44+
// that signal that the request was rejected by the transport layer as
45+
// invalid.
46+
//
47+
// Such failures do not indicate that the connection is broken, but rather
48+
// should be returned to the caller to indicate that the specific request is
49+
// invalid in the current context.
50+
ErrRejected = NewError(-32004, "rejected by transport")
4051
)
4152

4253
const wireVersion = "2.0"

mcp/streamable.go

Lines changed: 99 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,29 @@ type StreamableHTTPHandler struct {
3838
getServer func(*http.Request) *Server
3939
opts StreamableHTTPOptions
4040

41-
mu sync.Mutex
41+
mu sync.Mutex
42+
// TODO: we should store the ServerSession along with the transport, because
43+
// we need to cancel keepalive requests when closing the transport.
4244
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
4345
}
4446

4547
// StreamableHTTPOptions configures the StreamableHTTPHandler.
4648
type StreamableHTTPOptions struct {
4749
// GetSessionID provides the next session ID to use for an incoming request.
50+
// If nil, a default randomly generated ID will be used.
4851
//
49-
// If GetSessionID returns an empty string, the session is 'stateless',
50-
// meaning it is not persisted and no session validation is performed.
52+
// As a special case, if GetSessionID returns the empty string, the
53+
// Mcp-Session-Id header will not be set.
5154
GetSessionID func() string
5255

56+
// Stateless controls whether the session is 'stateless'.
57+
//
58+
// A stateless server does not validate the Mcp-Session-Id header, and uses a
59+
// temporary session with default initialization parameters. Any
60+
// server->client request is rejected immediately as there's no way for the
61+
// client to respond.
62+
Stateless bool
63+
5364
// TODO: support session retention (?)
5465

5566
// jsonResponse is forwarded to StreamableServerTransport.jsonResponse.
@@ -118,36 +129,39 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
118129
return
119130
}
120131

132+
sessionID := req.Header.Get(sessionIDHeader)
121133
var transport *StreamableServerTransport
122-
if id := req.Header.Get(sessionIDHeader); id != "" {
134+
if sessionID != "" {
123135
h.mu.Lock()
124-
transport, _ = h.transports[id]
136+
transport, _ = h.transports[sessionID]
125137
h.mu.Unlock()
126-
if transport == nil {
138+
if transport == nil && !h.opts.Stateless {
139+
// In stateless mode we allow a missing transport.
140+
//
141+
// A synthetic transport will be created below for the transient session.
127142
http.Error(w, "session not found", http.StatusNotFound)
128143
return
129144
}
130145
}
131146

132-
// TODO(rfindley): simplify the locking so that each request has only one
133-
// critical section.
134147
if req.Method == http.MethodDelete {
135-
if transport == nil {
136-
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
148+
if sessionID == "" {
137149
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
138150
return
139151
}
140-
h.mu.Lock()
141-
delete(h.transports, transport.SessionID)
142-
h.mu.Unlock()
143-
transport.connection.Close()
152+
if transport != nil { // transport may be nil in stateless mode
153+
h.mu.Lock()
154+
delete(h.transports, transport.SessionID)
155+
h.mu.Unlock()
156+
transport.connection.Close()
157+
}
144158
w.WriteHeader(http.StatusNoContent)
145159
return
146160
}
147161

148162
switch req.Method {
149163
case http.MethodPost, http.MethodGet:
150-
if req.Method == http.MethodGet && transport == nil {
164+
if req.Method == http.MethodGet && sessionID == "" {
151165
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
152166
return
153167
}
@@ -164,37 +178,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
164178
http.Error(w, "no server available", http.StatusBadRequest)
165179
return
166180
}
167-
sessionID := h.opts.GetSessionID()
168-
s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse}
181+
if sessionID == "" {
182+
// In stateless mode, sessionID may be nonempty even if there's no
183+
// existing transport.
184+
sessionID = h.opts.GetSessionID()
185+
}
186+
transport = &StreamableServerTransport{
187+
SessionID: sessionID,
188+
Stateless: h.opts.Stateless,
189+
jsonResponse: h.opts.jsonResponse,
190+
}
169191

170192
// To support stateless mode, we initialize the session with a default
171193
// state, so that it doesn't reject subsequent requests.
172194
var connectOpts *ServerSessionOptions
173-
if sessionID == "" {
195+
if h.opts.Stateless {
196+
// Peek at the body to see if it is initialize or initialized.
197+
// We want those to be handled as usual.
198+
var hasInitialize, hasInitialized bool
199+
{
200+
// TODO: verify that this allows protocol version negotiation for
201+
// stateless servers.
202+
body, err := io.ReadAll(req.Body)
203+
if err != nil {
204+
http.Error(w, "failed to read body", http.StatusBadRequest)
205+
return
206+
}
207+
req.Body.Close()
208+
209+
// Reset the body so that it can be read later.
210+
req.Body = io.NopCloser(bytes.NewBuffer(body))
211+
212+
msgs, _, err := readBatch(body)
213+
if err == nil {
214+
for _, msg := range msgs {
215+
if req, ok := msg.(*jsonrpc.Request); ok {
216+
switch req.Method {
217+
case methodInitialize:
218+
hasInitialize = true
219+
case notificationInitialized:
220+
hasInitialized = true
221+
}
222+
}
223+
}
224+
}
225+
}
226+
227+
// If we don't have InitializeParams or InitializedParams in the request,
228+
// set the initial state to a default value.
229+
state := new(ServerSessionState)
230+
if !hasInitialize {
231+
state.InitializeParams = new(InitializeParams)
232+
}
233+
if !hasInitialized {
234+
state.InitializedParams = new(InitializedParams)
235+
}
174236
connectOpts = &ServerSessionOptions{
175-
State: &ServerSessionState{
176-
InitializeParams: new(InitializeParams),
177-
InitializedParams: new(InitializedParams),
178-
},
237+
State: state,
179238
}
180239
}
240+
181241
// Pass req.Context() here, to allow middleware to add context values.
182242
// The context is detached in the jsonrpc2 library when handling the
183243
// long-running stream.
184-
ss, err := server.Connect(req.Context(), s, connectOpts)
244+
ss, err := server.Connect(req.Context(), transport, connectOpts)
185245
if err != nil {
186246
http.Error(w, "failed connection", http.StatusInternalServerError)
187247
return
188248
}
189-
if sessionID == "" {
249+
if h.opts.Stateless {
190250
// Stateless mode: close the session when the request exits.
191251
defer ss.Close() // close the fake session after handling the request
192252
} else {
253+
// Otherwise, save the transport so that it can be reused
193254
h.mu.Lock()
194-
h.transports[s.SessionID] = s
255+
h.transports[transport.SessionID] = transport
195256
h.mu.Unlock()
196257
}
197-
transport = s
198258
}
199259

200260
transport.ServeHTTP(w, req)
@@ -225,6 +285,13 @@ type StreamableServerTransport struct {
225285
// generator to produce one, as with [crypto/rand.Text].)
226286
SessionID string
227287

288+
// Stateless controls whether the eventstore is 'Stateless'. Servers sessions
289+
// connected to a stateless transport are disallowed from making outgoing
290+
// requests.
291+
//
292+
// See also [StreamableHTTPOptions.Stateless].
293+
Stateless bool
294+
228295
// Storage for events, to enable stream resumption.
229296
// If nil, a [MemoryEventStore] with the default maximum size will be used.
230297
EventStore EventStore
@@ -265,6 +332,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
265332
}
266333
t.connection = &streamableServerConn{
267334
sessionID: t.SessionID,
335+
stateless: t.Stateless,
268336
eventStore: t.EventStore,
269337
jsonResponse: t.jsonResponse,
270338
incoming: make(chan jsonrpc.Message, 10),
@@ -285,6 +353,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
285353

286354
type streamableServerConn struct {
287355
sessionID string
356+
stateless bool
288357
jsonResponse bool
289358
eventStore EventStore
290359

@@ -756,6 +825,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error
756825

757826
// Write implements the [Connection] interface.
758827
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
828+
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") {
829+
// Requests aren't possible with stateless servers, or when there's no session ID.
830+
return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
831+
}
759832
// Find the incoming request that this write relates to, if any.
760833
var forRequest jsonrpc.ID
761834
isResponse := false

mcp/streamable_test.go

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream
706706
if !request.ignoreResponse {
707707
transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
708708
if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" {
709-
t.Errorf("received unexpected messages (-want +got):\n%s", diff)
709+
t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff)
710710
}
711711
}
712712
sessionID.CompareAndSwap("", gotSessionID)
@@ -953,19 +953,18 @@ func TestEventID(t *testing.T) {
953953
}
954954

955955
func TestStreamableStateless(t *testing.T) {
956-
// This version of sayHi doesn't make a ping request (we can't respond to
956+
// This version of sayHi expects
957957
// that request from our client).
958958
sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) {
959+
if err := req.Session.Ping(ctx, nil); err == nil {
960+
// ping should fail, but not break the connection
961+
t.Errorf("ping succeeded unexpectedly")
962+
}
959963
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil
960964
}
961965
server := NewServer(testImpl, nil)
962966
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
963967

964-
// Test stateless mode.
965-
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
966-
GetSessionID: func() string { return "" },
967-
})
968-
969968
requests := []streamableRequest{
970969
{
971970
method: "POST",
@@ -985,7 +984,74 @@ func TestStreamableStateless(t *testing.T) {
985984
},
986985
wantSessionID: false,
987986
},
987+
{
988+
method: "POST",
989+
wantStatusCode: http.StatusOK,
990+
messages: []jsonrpc.Message{
991+
req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}),
992+
},
993+
wantMessages: []jsonrpc.Message{
994+
resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil),
995+
},
996+
wantSessionID: false,
997+
},
998+
}
999+
1000+
testClientCompatibility := func(t *testing.T, handler http.Handler) {
1001+
ctx := context.Background()
1002+
httpServer := httptest.NewServer(handler)
1003+
defer httpServer.Close()
1004+
cs, err := NewClient(testImpl, nil).Connect(ctx, &StreamableClientTransport{Endpoint: httpServer.URL}, nil)
1005+
if err != nil {
1006+
t.Fatal(err)
1007+
}
1008+
res, err := cs.CallTool(ctx, &CallToolParams{Name: "greet", Arguments: hiParams{Name: "bar"}})
1009+
if err != nil {
1010+
t.Fatal(err)
1011+
}
1012+
if got, want := textContent(t, res), "hi bar"; got != want {
1013+
t.Errorf("Result = %q, want %q", got, want)
1014+
}
9881015
}
9891016

990-
testStreamableHandler(t, handler, requests)
1017+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1018+
GetSessionID: func() string { return "" },
1019+
Stateless: true,
1020+
})
1021+
1022+
// Test the default stateless mode.
1023+
t.Run("stateless", func(t *testing.T) {
1024+
testStreamableHandler(t, handler, requests)
1025+
testClientCompatibility(t, handler)
1026+
})
1027+
1028+
// Test a "distributed" variant of stateless mode, where it has non-empty
1029+
// session IDs, but is otherwise stateless.
1030+
//
1031+
// This can be used by tools to look up application state preserved across
1032+
// subsequent requests.
1033+
for i, req := range requests {
1034+
// Now, we want a session for all requests.
1035+
req.wantSessionID = true
1036+
requests[i] = req
1037+
}
1038+
distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1039+
Stateless: true,
1040+
})
1041+
t.Run("distributed", func(t *testing.T) {
1042+
testStreamableHandler(t, distributableHandler, requests)
1043+
testClientCompatibility(t, handler)
1044+
})
1045+
}
1046+
1047+
func textContent(t *testing.T, res *CallToolResult) string {
1048+
t.Helper()
1049+
if len(res.Content) != 1 {
1050+
t.Fatalf("len(Content) = %d, want 1", len(res.Content))
1051+
}
1052+
text, ok := res.Content[0].(*TextContent)
1053+
if !ok {
1054+
t.Fatalf("Content[0] is %T, want *TextContent", res.Content[0])
1055+
}
1056+
return text.Text
9911057
}

mcp/transport.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ type Transport interface {
4040
type Connection interface {
4141
// Read reads the next message to process off the connection.
4242
//
43-
// Read need not be safe for concurrent use: Read is called in a
44-
// concurrency-safe manner by the JSON-RPC library.
43+
// Connections must allow Read to be called concurrently with Close. In
44+
// particular, calling Close should unblock a Read waiting for input.
4545
Read(context.Context) (jsonrpc.Message, error)
4646

4747
// Write writes a new message to the connection.

0 commit comments

Comments
 (0)