Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1130,7 +1130,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
// ID at initialization time, by including it in an Mcp-Session-Id header
// on the HTTP response containing the InitializeResult.
go c.handleSSE(nil, true, nil)
go c.handleSSE("hanging GET", nil, true, nil)
}

// fail handles an asynchronous error while reading.
Expand Down Expand Up @@ -1224,17 +1224,27 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e
return nil
}

var requestSummary string
switch msg := msg.(type) {
case *jsonrpc.Request:
requestSummary = fmt.Sprintf("sending %q", msg.Method)
case *jsonrpc.Response:
requestSummary = fmt.Sprintf("sending jsonrpc response #%d", msg.ID)
default:
panic("unreachable")
}

switch ct := resp.Header.Get("Content-Type"); ct {
case "application/json":
go c.handleJSON(resp)
go c.handleJSON(requestSummary, resp)

case "text/event-stream":
jsonReq, _ := msg.(*jsonrpc.Request)
go c.handleSSE(resp, false, jsonReq)
go c.handleSSE(requestSummary, resp, false, jsonReq)

default:
resp.Body.Close()
return fmt.Errorf("unsupported content type %q", ct)
return fmt.Errorf("%s: unsupported content type %q", requestSummary, ct)
}
return nil
}
Expand All @@ -1258,16 +1268,16 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
}
}

func (c *streamableClientConn) handleJSON(resp *http.Response) {
func (c *streamableClientConn) handleJSON(requestSummary string, resp *http.Response) {
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
c.fail(err)
c.fail(fmt.Errorf("%s: failed to read body: %v", requestSummary, err))
return
}
msg, err := jsonrpc.DecodeMessage(body)
if err != nil {
c.fail(fmt.Errorf("failed to decode response: %v", err))
c.fail(fmt.Errorf("%s: failed to decode response: %v", requestSummary, err))
return
}
select {
Expand All @@ -1282,12 +1292,12 @@ func (c *streamableClientConn) handleJSON(resp *http.Response) {
//
// If forReq is set, it is the request that initiated the stream, and the
// stream is complete when we receive its response.
func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
func (c *streamableClientConn) handleSSE(requestSummary string, initialResp *http.Response, persistent bool, forReq *jsonrpc2.Request) {
resp := initialResp
var lastEventID string
for {
if resp != nil {
eventID, clientClosed := c.processStream(resp, forReq)
eventID, clientClosed := c.processStream(requestSummary, resp, forReq)
lastEventID = eventID

// If the connection was closed by the client, we're done.
Expand All @@ -1305,7 +1315,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
newResp, err := c.reconnect(lastEventID)
if err != nil {
// All reconnection attempts failed: fail the connection.
c.fail(err)
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, err))
return
}
resp = newResp
Expand All @@ -1316,7 +1326,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
resp.Body.Close()
c.fail(fmt.Errorf("failed to reconnect: %v", http.StatusText(resp.StatusCode)))
c.fail(fmt.Errorf("%s: failed to reconnect: %v", requestSummary, http.StatusText(resp.StatusCode)))
return
}
// Reconnection was successful. Continue the loop with the new response.
Expand All @@ -1327,7 +1337,7 @@ func (c *streamableClientConn) handleSSE(initialResp *http.Response, persistent
// incoming channel. It returns the ID of the last processed event and a flag
// indicating if the connection was closed by the client. If resp is nil, it
// returns "", false.
func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
func (c *streamableClientConn) processStream(requestSummary string, resp *http.Response, forReq *jsonrpc.Request) (lastEventID string, clientClosed bool) {
defer resp.Body.Close()
for evt, err := range scanEvents(resp.Body) {
if err != nil {
Expand All @@ -1340,7 +1350,7 @@ func (c *streamableClientConn) processStream(resp *http.Response, forReq *jsonrp

msg, err := jsonrpc.DecodeMessage(evt.Data)
if err != nil {
c.fail(fmt.Errorf("failed to decode event: %v", err))
c.fail(fmt.Errorf("%s: failed to decode event: %v", requestSummary, err))
return "", true
}

Expand Down
275 changes: 275 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
// Copyright 2025 The Go MCP SDK Authors. All rights reserved.
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file.

package mcp

import (
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2"
"github.com/modelcontextprotocol/go-sdk/jsonrpc"
)

type streamableRequestKey struct {
httpMethod string // http method
sessionID string // session ID header
jsonrpcMethod string // jsonrpc method, or "" for non-requests
}

type header map[string]string

type streamableResponse struct {
header header
status int // or http.StatusOK
body string // or ""
optional bool // if set, request need not be sent
wantProtocolVersion string // if "", unchecked
callback func() // if set, called after the request is handled
}

type fakeResponses map[streamableRequestKey]*streamableResponse

type fakeStreamableServer struct {
t *testing.T
responses fakeResponses

callMu sync.Mutex
calls map[streamableRequestKey]int
}

func (s *fakeStreamableServer) missingRequests() []streamableRequestKey {
s.callMu.Lock()
defer s.callMu.Unlock()

var unused []streamableRequestKey
for k, resp := range s.responses {
if s.calls[k] == 0 && !resp.optional {
unused = append(unused, k)
}
}
return unused
}

func (s *fakeStreamableServer) ServeHTTP(w http.ResponseWriter, req *http.Request) {
key := streamableRequestKey{
httpMethod: req.Method,
sessionID: req.Header.Get(sessionIDHeader),
}
if req.Method == http.MethodPost {
body, err := io.ReadAll(req.Body)
if err != nil {
s.t.Errorf("failed to read body: %v", err)
http.Error(w, "failed to read body", http.StatusInternalServerError)
return
}
msg, err := jsonrpc.DecodeMessage(body)
if err != nil {
s.t.Errorf("invalid body: %v", err)
http.Error(w, "invalid body", http.StatusInternalServerError)
return
}
if r, ok := msg.(*jsonrpc.Request); ok {
key.jsonrpcMethod = r.Method
}
}

s.callMu.Lock()
if s.calls == nil {
s.calls = make(map[streamableRequestKey]int)
}
s.calls[key]++
s.callMu.Unlock()

resp, ok := s.responses[key]
if !ok {
s.t.Errorf("missing response for %v", key)
http.Error(w, "no response", http.StatusInternalServerError)
return
}
if resp.callback != nil {
defer resp.callback()
}
for k, v := range resp.header {
w.Header().Set(k, v)
}
status := resp.status
if status == 0 {
status = http.StatusOK
}
w.WriteHeader(status)

if v := req.Header.Get(protocolVersionHeader); v != resp.wantProtocolVersion && resp.wantProtocolVersion != "" {
s.t.Errorf("%v: bad protocol version header: got %q, want %q", key, v, resp.wantProtocolVersion)
}
w.Write([]byte(resp.body))
}

var (
initResult = &InitializeResult{
Capabilities: &ServerCapabilities{
Completions: &CompletionCapabilities{},
Logging: &LoggingCapabilities{},
Tools: &ToolCapabilities{ListChanged: true},
},
ProtocolVersion: latestProtocolVersion,
ServerInfo: &Implementation{Name: "testServer", Version: "v1.0.0"},
}
initResp = resp(1, initResult, nil)
)

func jsonBody(t *testing.T, msg jsonrpc2.Message) string {
data, err := jsonrpc2.EncodeMessage(msg)
if err != nil {
t.Fatalf("encoding failed: %v", err)
}
return string(data)
}

func TestStreamableClientTransportLifecycle(t *testing.T) {
ctx := context.Background()

// The lifecycle test verifies various behavior of the streamable client
// initialization:
// - check that it can handle application/json responses
// - check that it sends the negotiated protocol version
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", ""}: {
header: header{
"Content-Type": "text/event-stream",
},
optional: true,
wantProtocolVersion: latestProtocolVersion,
},
{"DELETE", "123", ""}: {},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests: %v", missing)
}
if diff := cmp.Diff(initResult, session.state.InitializeResult); diff != "" {
t.Errorf("mismatch (-want, +got):\n%s", diff)
}
}

func TestStreamableClientGETHandling(t *testing.T) {
ctx := context.Background()

tests := []struct {
status int
wantErrorContaining string
}{
{http.StatusOK, ""},
{http.StatusMethodNotAllowed, ""},
{http.StatusBadRequest, "hanging GET"},
}

for _, test := range tests {
t.Run(fmt.Sprintf("status=%d", test.status), func(t *testing.T) {
fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
{"GET", "123", ""}: {
header: header{
"Content-Type": "text/event-stream",
},
status: test.status,
wantProtocolVersion: latestProtocolVersion,
},
{"POST", "123", methodListTools}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, resp(2, &ListToolsResult{Tools: []*Tool{}}, nil)),
optional: true,
},
{"DELETE", "123", ""}: {optional: true},
},
}
httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{Endpoint: httpServer.URL}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

// wait for all required requests to be handled, with exponential
// backoff.
start := time.Now()
delay := 1 * time.Millisecond
for range 10 {
if len(fake.missingRequests()) == 0 {
break
}
time.Sleep(delay)
delay *= 2
}
if missing := fake.missingRequests(); len(missing) > 0 {
t.Errorf("did not receive expected requests after %s: %v", time.Since(start), missing)
}

_, err = session.ListTools(ctx, nil)
if (err != nil) != (test.wantErrorContaining != "") {
t.Errorf("After initialization, got error %v, want %v", err, test.wantErrorContaining)
} else if err != nil {
if !strings.Contains(err.Error(), test.wantErrorContaining) {
t.Errorf("After initialization, got error %s, want containing %q", err, test.wantErrorContaining)
}
}

if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}
})
}
}
Loading