Skip to content

Commit f02f47f

Browse files
committed
mcp: improvements for 'stateless' streamable servers; 'distributed' mode
WIP: needs self review and more tests. Several improvements for the stateless streamable mode, plus support for a 'distributed' (or rather, distributable) version of the stateless server. - Add a 'Stateless' option to StreamableHTTPOptions and StreamableServerTransport, which controls stateless behavior. GetSessionID may still return a non-empty session ID. - Audit validation of stateless mode to allow requests with a session id. Propagate this session ID to the temporary connection. - Peek at requests to allow 'initialize' requests to go through to the session, so that version negotiation can occur (FIXME: add tests). Fixes #284 For #148
1 parent 3c0a062 commit f02f47f

File tree

4 files changed

+143
-30
lines changed

4 files changed

+143
-30
lines changed

internal/jsonrpc2/conn.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,19 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
725725
// write is used by all things that write outgoing messages, including replies.
726726
// it makes sure that writes are atomic
727727
func (c *Connection) write(ctx context.Context, msg Message) error {
728-
err := c.writer.Write(ctx, msg)
728+
var err error
729+
c.updateInFlight(func(s *inFlightState) {
730+
err = s.shuttingDown(ErrServerClosing)
731+
})
732+
if err == nil {
733+
err = c.writer.Write(ctx, msg)
734+
}
735+
736+
// For rejected requests, we don't set the writeErr (which would break the
737+
// connection). They can just be returned to the caller.
738+
if errors.Is(err, ErrRejected) {
739+
return err
740+
}
729741

730742
if err != nil && ctx.Err() == nil {
731743
// The call to Write failed, and since ctx.Err() is nil we can't attribute

internal/jsonrpc2/wire.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ var (
3737
ErrServerClosing = NewError(-32004, "server is closing")
3838
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
3939
ErrClientClosing = NewError(-32003, "client is closing")
40+
41+
// The following errors have special semantics for MCP transports
42+
43+
// ErrRejected may be wrapped to return errors from calls to Writer.Write
44+
// that signal that the request was rejected by the transport layer as
45+
// invalid.
46+
//
47+
// Such failures do not indicate that the connection is broken, but rather
48+
// should be returned to the caller to indicate that the specific request is
49+
// invalid in the current context.
50+
ErrRejected = NewError(-32004, "rejected by transport")
4051
)
4152

4253
const wireVersion = "2.0"

mcp/streamable.go

Lines changed: 76 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,14 @@ type StreamableHTTPHandler struct {
4646
type StreamableHTTPOptions struct {
4747
// GetSessionID provides the next session ID to use for an incoming request.
4848
//
49-
// If GetSessionID returns an empty string, the session is 'stateless',
50-
// meaning it is not persisted and no session validation is performed.
49+
// FIXME: update doc.
5150
GetSessionID func() string
5251

52+
// Stateless controls whether the session is 'stateless'.
53+
//
54+
// FIXME: update doc.
55+
Stateless bool
56+
5357
// TODO: support session retention (?)
5458

5559
// jsonResponse is forwarded to StreamableServerTransport.jsonResponse.
@@ -119,11 +123,12 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
119123
}
120124

121125
var transport *StreamableServerTransport
122-
if id := req.Header.Get(sessionIDHeader); id != "" {
126+
sessionID := req.Header.Get(sessionIDHeader)
127+
if id := sessionID; id != "" {
123128
h.mu.Lock()
124129
transport, _ = h.transports[id]
125130
h.mu.Unlock()
126-
if transport == nil {
131+
if transport == nil && !h.opts.Stateless {
127132
http.Error(w, "session not found", http.StatusNotFound)
128133
return
129134
}
@@ -132,22 +137,24 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
132137
// TODO(rfindley): simplify the locking so that each request has only one
133138
// critical section.
134139
if req.Method == http.MethodDelete {
135-
if transport == nil {
140+
if sessionID == "" {
136141
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
137142
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
138143
return
139144
}
140-
h.mu.Lock()
141-
delete(h.transports, transport.SessionID)
142-
h.mu.Unlock()
143-
transport.connection.Close()
145+
if transport != nil { // transport may be nil in stateless mode
146+
h.mu.Lock()
147+
delete(h.transports, transport.SessionID)
148+
h.mu.Unlock()
149+
transport.connection.Close()
150+
}
144151
w.WriteHeader(http.StatusNoContent)
145152
return
146153
}
147154

148155
switch req.Method {
149156
case http.MethodPost, http.MethodGet:
150-
if req.Method == http.MethodGet && transport == nil {
157+
if req.Method == http.MethodGet && sessionID == "" {
151158
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
152159
return
153160
}
@@ -164,37 +171,76 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
164171
http.Error(w, "no server available", http.StatusBadRequest)
165172
return
166173
}
167-
sessionID := h.opts.GetSessionID()
168-
s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse}
174+
if sessionID == "" {
175+
// In stateless mode, sessionID may be nonempty even if there's no
176+
// existing transport.
177+
sessionID = h.opts.GetSessionID()
178+
}
179+
transport = &StreamableServerTransport{
180+
SessionID: sessionID,
181+
Stateless: h.opts.Stateless,
182+
jsonResponse: h.opts.jsonResponse,
183+
}
169184

170185
// To support stateless mode, we initialize the session with a default
171186
// state, so that it doesn't reject subsequent requests.
172187
var connectOpts *ServerSessionOptions
173-
if sessionID == "" {
188+
if h.opts.Stateless {
189+
// Peek at the body to see if it is an initialize request.
190+
// We want that to be handled as usual.
191+
var hasInitialize, hasInitialized bool
192+
193+
// TODO: verify that this allows protocol version negotiation for
194+
// stateless servers.
195+
body, err := io.ReadAll(req.Body)
196+
if err != nil {
197+
http.Error(w, "failed to read body", http.StatusBadRequest)
198+
return
199+
}
200+
// Reset the body to be read later.
201+
req.Body = io.NopCloser(bytes.NewBuffer(body))
202+
203+
msgs, _, err := readBatch(body)
204+
if err == nil {
205+
for _, msg := range msgs {
206+
if req, ok := msg.(*jsonrpc.Request); ok {
207+
switch req.Method {
208+
case methodInitialize:
209+
hasInitialize = true
210+
case notificationInitialized:
211+
hasInitialized = true
212+
}
213+
}
214+
}
215+
}
216+
state := new(ServerSessionState)
217+
if !hasInitialize {
218+
state.InitializeParams = new(InitializeParams)
219+
}
220+
if !hasInitialized {
221+
state.InitializedParams = new(InitializedParams)
222+
}
174223
connectOpts = &ServerSessionOptions{
175-
State: &ServerSessionState{
176-
InitializeParams: new(InitializeParams),
177-
InitializedParams: new(InitializedParams),
178-
},
224+
State: state,
179225
}
180226
}
227+
181228
// Pass req.Context() here, to allow middleware to add context values.
182229
// The context is detached in the jsonrpc2 library when handling the
183230
// long-running stream.
184-
ss, err := server.Connect(req.Context(), s, connectOpts)
231+
ss, err := server.Connect(req.Context(), transport, connectOpts)
185232
if err != nil {
186233
http.Error(w, "failed connection", http.StatusInternalServerError)
187234
return
188235
}
189-
if sessionID == "" {
236+
if h.opts.Stateless {
190237
// Stateless mode: close the session when the request exits.
191238
defer ss.Close() // close the fake session after handling the request
192239
} else {
193240
h.mu.Lock()
194-
h.transports[s.SessionID] = s
241+
h.transports[transport.SessionID] = transport
195242
h.mu.Unlock()
196243
}
197-
transport = s
198244
}
199245

200246
transport.ServeHTTP(w, req)
@@ -225,6 +271,9 @@ type StreamableServerTransport struct {
225271
// generator to produce one, as with [crypto/rand.Text].)
226272
SessionID string
227273

274+
// FIXME: doc
275+
Stateless bool
276+
228277
// Storage for events, to enable stream resumption.
229278
// If nil, a [MemoryEventStore] with the default maximum size will be used.
230279
EventStore EventStore
@@ -265,6 +314,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
265314
}
266315
t.connection = &streamableServerConn{
267316
sessionID: t.SessionID,
317+
stateless: t.Stateless,
268318
eventStore: t.EventStore,
269319
jsonResponse: t.jsonResponse,
270320
incoming: make(chan jsonrpc.Message, 10),
@@ -285,6 +335,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
285335

286336
type streamableServerConn struct {
287337
sessionID string
338+
stateless bool
288339
jsonResponse bool
289340
eventStore EventStore
290341

@@ -759,6 +810,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error
759810

760811
// Write implements the [Connection] interface.
761812
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
813+
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") {
814+
// Requests aren't possible with stateless servers.
815+
return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
816+
}
762817
// Find the incoming request that this write relates to, if any.
763818
var forRequest jsonrpc.ID
764819
isResponse := false

mcp/streamable_test.go

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,7 @@ func testStreamableHandler(t *testing.T, handler http.Handler, requests []stream
706706
if !request.ignoreResponse {
707707
transform := cmpopts.AcyclicTransformer("jsonrpcid", func(id jsonrpc.ID) any { return id.Raw() })
708708
if diff := cmp.Diff(request.wantMessages, got, transform); diff != "" {
709-
t.Errorf("received unexpected messages (-want +got):\n%s", diff)
709+
t.Errorf("request #%d: received unexpected messages (-want +got):\n%s", i, diff)
710710
}
711711
}
712712
sessionID.CompareAndSwap("", gotSessionID)
@@ -955,19 +955,18 @@ func TestEventID(t *testing.T) {
955955
}
956956

957957
func TestStreamableStateless(t *testing.T) {
958-
// This version of sayHi doesn't make a ping request (we can't respond to
958+
// This version of sayHi expects
959959
// that request from our client).
960960
sayHi := func(ctx context.Context, req *ServerRequest[*CallToolParamsFor[hiParams]]) (*CallToolResult, error) {
961+
if err := req.Session.Ping(ctx, nil); err == nil {
962+
// ping should fail, but not break the connection
963+
t.Errorf("ping succeeded unexpectedly")
964+
}
961965
return &CallToolResult{Content: []Content{&TextContent{Text: "hi " + req.Params.Arguments.Name}}}, nil
962966
}
963967
server := NewServer(testImpl, nil)
964968
AddTool(server, &Tool{Name: "greet", Description: "say hi"}, sayHi)
965969

966-
// Test stateless mode.
967-
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
968-
GetSessionID: func() string { return "" },
969-
})
970-
971970
requests := []streamableRequest{
972971
{
973972
method: "POST",
@@ -987,7 +986,43 @@ func TestStreamableStateless(t *testing.T) {
987986
},
988987
wantSessionID: false,
989988
},
989+
{
990+
method: "POST",
991+
wantStatusCode: http.StatusOK,
992+
messages: []jsonrpc.Message{
993+
req(2, "tools/call", &CallToolParams{Name: "greet", Arguments: hiParams{Name: "foo"}}),
994+
},
995+
wantMessages: []jsonrpc.Message{
996+
resp(2, &CallToolResult{Content: []Content{&TextContent{Text: "hi foo"}}}, nil),
997+
},
998+
wantSessionID: false,
999+
},
9901000
}
9911001

992-
testStreamableHandler(t, handler, requests)
1002+
handler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1003+
GetSessionID: func() string { return "" },
1004+
Stateless: true,
1005+
})
1006+
1007+
// Test the default stateless mode.
1008+
t.Run("stateless", func(t *testing.T) {
1009+
testStreamableHandler(t, handler, requests)
1010+
})
1011+
1012+
// Test a "distributed" variant of stateless mode, where it has non-empty
1013+
// session IDs, but is otherwise stateless.
1014+
//
1015+
// This can be used by tools to look up application state preserved across
1016+
// subsequent requests.
1017+
for i, req := range requests {
1018+
// Now, we want a session for all requests.
1019+
req.wantSessionID = true
1020+
requests[i] = req
1021+
}
1022+
distributableHandler := NewStreamableHTTPHandler(func(*http.Request) *Server { return server }, &StreamableHTTPOptions{
1023+
Stateless: true,
1024+
})
1025+
t.Run("distributed", func(t *testing.T) {
1026+
testStreamableHandler(t, distributableHandler, requests)
1027+
})
9931028
}

0 commit comments

Comments
 (0)