Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion internal/jsonrpc2/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -739,7 +739,23 @@ func (c *Connection) processResult(from any, req *incomingRequest, result any, e
// write is used by all things that write outgoing messages, including replies.
// it makes sure that writes are atomic
func (c *Connection) write(ctx context.Context, msg Message) error {
err := c.writer.Write(ctx, msg)
var err error
// Fail writes immediately if the connection is shutting down.
//
// TODO(rfindley): should we allow cancellation notifications through? It
// could be the case that writes can still succeed.
c.updateInFlight(func(s *inFlightState) {
err = s.shuttingDown(ErrServerClosing)
})
if err == nil {
err = c.writer.Write(ctx, msg)
}

// For rejected requests, we don't set the writeErr (which would break the
// connection). They can just be returned to the caller.
if errors.Is(err, ErrRejected) {
return err
}

if err != nil && ctx.Err() == nil {
// The call to Write failed, and since ctx.Err() is nil we can't attribute
Expand Down
11 changes: 11 additions & 0 deletions internal/jsonrpc2/wire.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,17 @@ var (
ErrServerClosing = NewError(-32004, "server is closing")
// ErrClientClosing is a dummy error returned for calls initiated while the client is closing.
ErrClientClosing = NewError(-32003, "client is closing")

// The following errors have special semantics for MCP transports

// ErrRejected may be wrapped to return errors from calls to Writer.Write
// that signal that the request was rejected by the transport layer as
// invalid.
//
// Such failures do not indicate that the connection is broken, but rather
// should be returned to the caller to indicate that the specific request is
// invalid in the current context.
ErrRejected = NewError(-32004, "rejected by transport")
)

const wireVersion = "2.0"
Expand Down
2 changes: 2 additions & 0 deletions mcp/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ func (s *Server) callTool(ctx context.Context, req *ServerRequest[*CallToolParam
if !ok {
return nil, fmt.Errorf("%s: unknown tool %q", jsonrpc2.ErrInvalidParams, req.Params.Name)
}
// TODO: if handler returns nil content, it will serialize as null.
// Add a test and fix.
return st.handler(ctx, req)
}

Expand Down
181 changes: 134 additions & 47 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,34 @@ type StreamableHTTPHandler struct {
getServer func(*http.Request) *Server
opts StreamableHTTPOptions

mu sync.Mutex
mu sync.Mutex
// TODO: we should store the ServerSession along with the transport, because
// we need to cancel keepalive requests when closing the transport.
transports map[string]*StreamableServerTransport // keyed by IDs (from Mcp-Session-Id header)
}

// StreamableHTTPOptions configures the StreamableHTTPHandler.
type StreamableHTTPOptions struct {
// GetSessionID provides the next session ID to use for an incoming request.
// If nil, a default randomly generated ID will be used.
//
// Session IDs should be globally unique across the scope of the server,
// which may span multiple processes in the case of distributed servers.
//
// If GetSessionID returns an empty string, the session is 'stateless',
// meaning it is not persisted and no session validation is performed.
// As a special case, if GetSessionID returns the empty string, the
// Mcp-Session-Id header will not be set.
GetSessionID func() string

// Stateless controls whether the session is 'stateless'.
//
// A stateless server does not validate the Mcp-Session-Id header, and uses a
// temporary session with default initialization parameters. Any
// server->client request is rejected immediately as there's no way for the
// client to respond. Server->Client notifications may reach the client if
// they are made in the context of an incoming request, as described in the
// documentation for [StreamableServerTransport].
Stateless bool

// TODO: support session retention (?)

// jsonResponse is forwarded to StreamableServerTransport.jsonResponse.
Expand Down Expand Up @@ -118,36 +134,40 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
return
}

sessionID := req.Header.Get(sessionIDHeader)
var transport *StreamableServerTransport
if id := req.Header.Get(sessionIDHeader); id != "" {
if sessionID != "" {
h.mu.Lock()
transport = h.transports[id]
transport = h.transports[sessionID]
h.mu.Unlock()
if transport == nil {
if transport == nil && !h.opts.Stateless {
// Unless we're in 'stateless' mode, which doesn't perform any Session-ID
// validation, we require that the session ID matches a known session.
//
// In stateless mode, a temporary transport is be created below.
http.Error(w, "session not found", http.StatusNotFound)
return
}
}

// TODO(rfindley): simplify the locking so that each request has only one
// critical section.
if req.Method == http.MethodDelete {
if transport == nil {
// => Mcp-Session-Id was not set; else we'd have returned NotFound above.
if sessionID == "" {
http.Error(w, "DELETE requires an Mcp-Session-Id header", http.StatusBadRequest)
return
}
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
if transport != nil { // transport may be nil in stateless mode
h.mu.Lock()
delete(h.transports, transport.SessionID)
h.mu.Unlock()
transport.connection.Close()
}
w.WriteHeader(http.StatusNoContent)
return
}

switch req.Method {
case http.MethodPost, http.MethodGet:
if req.Method == http.MethodGet && transport == nil {
if req.Method == http.MethodGet && sessionID == "" {
http.Error(w, "GET requires an active session", http.StatusMethodNotAllowed)
return
}
Expand All @@ -164,37 +184,83 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque
http.Error(w, "no server available", http.StatusBadRequest)
return
}
sessionID := h.opts.GetSessionID()
s := &StreamableServerTransport{SessionID: sessionID, jsonResponse: h.opts.jsonResponse}
if sessionID == "" {
// In stateless mode, sessionID may be nonempty even if there's no
// existing transport.
sessionID = h.opts.GetSessionID()
}
transport = &StreamableServerTransport{
SessionID: sessionID,
Stateless: h.opts.Stateless,
jsonResponse: h.opts.jsonResponse,
}

// To support stateless mode, we initialize the session with a default
// state, so that it doesn't reject subsequent requests.
var connectOpts *ServerSessionOptions
if sessionID == "" {
if h.opts.Stateless {
// Peek at the body to see if it is initialize or initialized.
// We want those to be handled as usual.
var hasInitialize, hasInitialized bool
{
// TODO: verify that this allows protocol version negotiation for
// stateless servers.
body, err := io.ReadAll(req.Body)
if err != nil {
http.Error(w, "failed to read body", http.StatusInternalServerError)
return
}
req.Body.Close()

// Reset the body so that it can be read later.
req.Body = io.NopCloser(bytes.NewBuffer(body))

msgs, _, err := readBatch(body)
if err == nil {
for _, msg := range msgs {
if req, ok := msg.(*jsonrpc.Request); ok {
switch req.Method {
case methodInitialize:
hasInitialize = true
case notificationInitialized:
hasInitialized = true
}
}
}
}
}

// If we don't have InitializeParams or InitializedParams in the request,
// set the initial state to a default value.
state := new(ServerSessionState)
if !hasInitialize {
state.InitializeParams = new(InitializeParams)
}
if !hasInitialized {
state.InitializedParams = new(InitializedParams)
}
connectOpts = &ServerSessionOptions{
State: &ServerSessionState{
InitializeParams: new(InitializeParams),
InitializedParams: new(InitializedParams),
},
State: state,
}
}

// Pass req.Context() here, to allow middleware to add context values.
// The context is detached in the jsonrpc2 library when handling the
// long-running stream.
ss, err := server.Connect(req.Context(), s, connectOpts)
ss, err := server.Connect(req.Context(), transport, connectOpts)
if err != nil {
http.Error(w, "failed connection", http.StatusInternalServerError)
return
}
if sessionID == "" {
if h.opts.Stateless {
// Stateless mode: close the session when the request exits.
defer ss.Close() // close the fake session after handling the request
} else {
// Otherwise, save the transport so that it can be reused
h.mu.Lock()
h.transports[s.SessionID] = s
h.transports[transport.SessionID] = transport
h.mu.Unlock()
}
transport = s
}

transport.ServeHTTP(w, req)
Expand All @@ -212,9 +278,22 @@ type StreamableServerTransportOptions struct {
// A StreamableServerTransport implements the server side of the MCP streamable
// transport.
//
// Each StreamableServerTransport may be connected (via [Server.Connect]) at
// Each StreamableServerTransport must be connected (via [Server.Connect]) at
// most once, since [StreamableServerTransport.ServeHTTP] serves messages to
// the connected session.
//
// Reads from the streamable server connection receive messages from http POST
// requests from the client. Writes to the streamable server connection are
// sent either to the hanging POST response, or to the hanging GET, according
// to the following rules:
// - JSON-RPC responses to incoming requests are always routed to the
// appropriate HTTP response.
// - Requests or notifications made with a context.Context value derived from
// an incoming request handler, are routed to the HTTP response
// corresponding to that request, unless it has already terminated, in
// which case they are routed to the hanging GET.
// - Requests or notifications made with a detached context.Context value are
// routed to the hanging GET.
type StreamableServerTransport struct {
// SessionID is the ID of this session.
//
Expand All @@ -225,6 +304,13 @@ type StreamableServerTransport struct {
// generator to produce one, as with [crypto/rand.Text].)
SessionID string

// Stateless controls whether the eventstore is 'Stateless'. Server sessions
// connected to a stateless transport are disallowed from making outgoing
// requests.
//
// See also [StreamableHTTPOptions.Stateless].
Stateless bool

// Storage for events, to enable stream resumption.
// If nil, a [MemoryEventStore] with the default maximum size will be used.
EventStore EventStore
Expand Down Expand Up @@ -265,6 +351,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)
}
t.connection = &streamableServerConn{
sessionID: t.SessionID,
stateless: t.Stateless,
eventStore: t.EventStore,
jsonResponse: t.jsonResponse,
incoming: make(chan jsonrpc.Message, 10),
Expand All @@ -285,6 +372,7 @@ func (t *StreamableServerTransport) Connect(context.Context) (Connection, error)

type streamableServerConn struct {
sessionID string
stateless bool
jsonResponse bool
eventStore EventStore

Expand Down Expand Up @@ -755,6 +843,10 @@ func (c *streamableServerConn) Read(ctx context.Context) (jsonrpc.Message, error

// Write implements the [Connection] interface.
func (c *streamableServerConn) Write(ctx context.Context, msg jsonrpc.Message) error {
if req, ok := msg.(*jsonrpc.Request); ok && req.ID.IsValid() && (c.stateless || c.sessionID == "") {
// Requests aren't possible with stateless servers, or when there's no session ID.
return fmt.Errorf("%w: stateless servers cannot make requests", jsonrpc2.ErrRejected)
}
// Find the incoming request that this write relates to, if any.
var forRequest jsonrpc.ID
isResponse := false
Expand Down Expand Up @@ -1152,9 +1244,18 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
c.fail(err)
return
}

// Reconnection was successful. Continue the loop with the new response.
resp = newResp
if resp.StatusCode == http.StatusMethodNotAllowed && persistent {
// The server doesn't support the hanging GET.
resp.Body.Close()
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode)))
return
}
// Reconnection was successful. Continue the loop with the new response.
}
}

Expand Down Expand Up @@ -1222,13 +1323,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
finalErr = err // Store the error and try again.
continue
}

if !isResumable(resp) {
// The server indicated we should not continue.
resp.Body.Close()
return nil, fmt.Errorf("reconnection failed with unresumable status: %s", resp.Status)
}

return resp, nil
}
}
Expand All @@ -1239,16 +1333,6 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
return nil, fmt.Errorf("connection failed after %d attempts", c.maxRetries)
}

// isResumable checks if an HTTP response indicates a valid SSE stream that can be processed.
func isResumable(resp *http.Response) bool {
// Per the spec, a 405 response means the server doesn't support SSE streams at this endpoint.
if resp.StatusCode == http.StatusMethodNotAllowed {
return false
}

return strings.Contains(resp.Header.Get("Content-Type"), "text/event-stream")
}

// Close implements the [Connection] interface.
func (c *streamableClientConn) Close() error {
c.closeOnce.Do(func() {
Expand Down Expand Up @@ -1288,8 +1372,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,

// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
func calculateReconnectDelay(attempt int) time.Duration {
if attempt == 0 {
return 0
}
// Calculate the exponential backoff using the grow factor.
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt-1)))
// Cap the backoffDuration at maxDelay.
backoffDuration = min(backoffDuration, reconnectMaxDelay)

Expand Down
Loading