Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 76 additions & 46 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ func (c *streamableServerConn) SessionID() string {
// A stream is a single logical stream of SSE events within a server session.
// A stream begins with a client request, or with a client GET that has
// no Last-Event-ID header.
//
// A stream ends only when its session ends; we cannot determine its end otherwise,
// since a client may send a GET with a Last-Event-ID that references the stream
// at any time.
Expand Down Expand Up @@ -529,6 +530,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques
}
c.mu.Unlock()
stream.signal.Store(signalChanPtr())
defer stream.signal.Store(nil)
}

// Publish incoming messages.
Expand Down Expand Up @@ -857,27 +859,27 @@ type StreamableReconnectOptions struct {
// MaxRetries is the maximum number of times to attempt a reconnect before giving up.
// A value of 0 or less means never retry.
MaxRetries int

// growFactor is the multiplicative factor by which the delay increases after each attempt.
// A value of 1.0 results in a constant delay, while a value of 2.0 would double it each time.
// It must be 1.0 or greater if MaxRetries is greater than 0.
growFactor float64

// initialDelay is the base delay for the first reconnect attempt.
initialDelay time.Duration

// maxDelay caps the backoff delay, preventing it from growing indefinitely.
maxDelay time.Duration
}

// DefaultReconnectOptions provides sensible defaults for reconnect logic.
var DefaultReconnectOptions = &StreamableReconnectOptions{
MaxRetries: 5,
growFactor: 1.5,
initialDelay: 1 * time.Second,
maxDelay: 30 * time.Second,
MaxRetries: 5,
}

// Backoff tuning for reconnect attempts. StreamableReconnectOptions does not
// (yet) surface these knobs, so they are fixed constants: a caller cannot see
// them, and should not have to start from DefaultReconnectOptions and mutate
// it just to obtain usable values.
const (
	// reconnectInitialDelay is the delay the backoff starts from, i.e. the
	// base delay for the first reconnect attempt.
	reconnectInitialDelay = time.Second
	// reconnectGrowFactor multiplies the delay after every attempt: 1.0 keeps
	// the delay constant, while 2.0 would double it each time. It must be at
	// least 1.0 whenever MaxRetries is greater than 0.
	reconnectGrowFactor = 1.5
	// reconnectMaxDelay is the upper bound on the backoff delay, so that the
	// delay cannot grow indefinitely.
	reconnectMaxDelay = 30 * time.Second
)

// StreamableClientTransportOptions provides options for the
// [NewStreamableClientTransport] constructor.
//
Expand Down Expand Up @@ -928,7 +930,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan []byte, 100),
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
ReconnectOptions: reconnOpts,
ctx: connCtx,
Expand All @@ -944,7 +946,7 @@ type streamableClientConn struct {
client *http.Client
ctx context.Context
cancel context.CancelFunc
incoming chan []byte
incoming chan jsonrpc.Message

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
Expand Down Expand Up @@ -988,7 +990,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
// ID at initialization time, by including it in an Mcp-Session-Id header
// on the HTTP response containing the InitializeResult.
go c.handleSSE(nil, true)
go c.handleSSE(nil, true, nil)
}

// fail handles an asynchronous error while reading.
Expand Down Expand Up @@ -1031,8 +1033,8 @@ func (c *streamableClientConn) Read(ctx context.Context) (jsonrpc.Message, error
return nil, c.failure()
case <-c.done:
return nil, io.EOF
case data := <-c.incoming:
return jsonrpc2.DecodeMessage(data)
case msg := <-c.incoming:
return msg, nil
}
}

Expand All @@ -1042,7 +1044,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
return err
}

data, err := jsonrpc2.EncodeMessage(msg)
data, err := jsonrpc.EncodeMessage(msg)
if err != nil {
return err
}
Expand Down Expand Up @@ -1088,7 +1090,8 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
go c.handleJSON(resp)

case "text/event-stream":
go c.handleSSE(resp, false)
jsonReq, _ := msg.(*jsonrpc.Request)
go c.handleSSE(resp, false, jsonReq)

default:
resp.Body.Close()
Expand Down Expand Up @@ -1116,30 +1119,40 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) {
c.fail(err)
return
}
msg, err := jsonrpc.DecodeMessage(body)
if err != nil {
c.fail(fmt.Errorf("failed to decode response: %v", err))
return
}
select {
case c.incoming <- body:
case c.incoming <- msg:
case <-c.done:
// The connection was closed by the client; exit gracefully.
}
}

// handleSSE manages the lifecycle of an SSE connection. It can be either
// persistent (for the main GET listener) or temporary (for a POST response).
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool) {
//
// If forReq is set, it is the request that initiated the stream, and the
// stream is complete when we receive its response.
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
resp := initialResp
var lastEventID string
for {
eventID, clientClosed := c.processStream(resp)
lastEventID = eventID
if resp != nil {
eventID, clientClosed := c.processStream(resp, forReq)
lastEventID = eventID

// If the connection was closed by the client, we're done.
if clientClosed {
return
}
// If the stream has ended, then do not reconnect if the stream is
// temporary (POST initiated SSE).
if lastEventID == "" && !persistent {
return
// If the connection was closed by the client, we're done.
if clientClosed {
return
}
// If the stream has ended, then do not reconnect if the stream is
// temporary (POST initiated SSE).
if lastEventID == "" && !persistent {
return
}
}

// The stream was interrupted or ended by the server. Attempt to reconnect.
Expand All @@ -1159,12 +1172,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
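// incoming channel. It returns the ID of the last processed event and a flag
// indicating if the connection was closed by the client. The flag is also
// reported as true when an event fails to decode, or when the response to
// forReq has been delivered (the stream is then complete). resp must be
// non-nil: its body is closed when processStream returns.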
func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID string, clientClosed bool) {
if resp == nil {
// TODO(rfindley): avoid this special handling.
return "", false
}

func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
defer resp.Body.Close()
for evt, err := range scanEvents(resp.Body) {
if err != nil {
Expand All @@ -1175,8 +1183,21 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
lastEventID = evt.ID
}

msg, err := jsonrpc.DecodeMessage(evt.Data)
if err != nil {
c.fail(fmt.Errorf("failed to decode event: %v", err))
return "", true
}

select {
case c.incoming <- evt.Data:
case c.incoming <- msg:
if jsonResp, ok := msg.(*jsonrpc.Response); ok && forReq != nil {
// TODO: we should never get a response when forReq is nil (the hanging GET).
// We should detect this case, and eliminate the 'persistent' flag arguments.
if jsonResp.ID == forReq.ID {
return "", true
}
}
case <-c.done:
// The connection was closed by the client; exit gracefully.
return "", true
Expand All @@ -1192,11 +1213,20 @@ func (c *streamableClientConn) processStream(resp *http.Response) (lastEventID s
func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, error) {
var finalErr error

for attempt := 0; attempt < c.ReconnectOptions.MaxRetries; attempt++ {
// We can reach the 'reconnect' path through the hanging GET, in which case
// lastEventID will be "".
//
// In this case, we need an initial attempt.
attempt := 0
if lastEventID != "" {
attempt = 1
}

for ; attempt <= c.ReconnectOptions.MaxRetries; attempt++ {
select {
case <-c.done:
return nil, fmt.Errorf("connection closed by client during reconnect")
case <-time.After(calculateReconnectDelay(c.ReconnectOptions, attempt)):
case <-time.After(calculateReconnectDelay(attempt)):
resp, err := c.establishSSE(lastEventID)
if err != nil {
finalErr = err // Store the error and try again.
Expand Down Expand Up @@ -1267,11 +1297,11 @@ func (c *streamableClientConn) establishSSE(lastEventID string) (*http.Response,
}

// calculateReconnectDelay calculates a delay using exponential backoff with full jitter.
func calculateReconnectDelay(opts *StreamableReconnectOptions, attempt int) time.Duration {
func calculateReconnectDelay(attempt int) time.Duration {
// Calculate the exponential backoff using the grow factor.
backoffDuration := time.Duration(float64(opts.initialDelay) * math.Pow(opts.growFactor, float64(attempt)))
backoffDuration := time.Duration(float64(reconnectInitialDelay) * math.Pow(reconnectGrowFactor, float64(attempt)))
// Cap the backoffDuration at reconnectMaxDelay.
backoffDuration = min(backoffDuration, opts.maxDelay)
backoffDuration = min(backoffDuration, reconnectMaxDelay)

// Use a full jitter using backoffDuration
jitter := rand.N(backoffDuration)
Expand Down
Loading