Skip to content

Commit 07b9cee

Browse files
committed
mcp: flush headers immediately for the hanging GET
Flush headers immediately for the persistent hanging GET of the streamable transport; otherwise, clients may time out. Fixes #410
1 parent 203792d commit 07b9cee

File tree

2 files changed

+96
-15
lines changed

2 files changed

+96
-15
lines changed

mcp/streamable.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -743,17 +743,28 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter
743743

744744
// lastIndex is the index of the last seen event if resuming, else -1.
745745
func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) {
746-
writes := 0
747-
748-
// Accept checked in [StreamableHTTPHandler]
746+
// Accept was checked in [StreamableHTTPHandler]
749747
w.Header().Set("Cache-Control", "no-cache, no-transform")
750748
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
751749
w.Header().Set("Connection", "keep-alive")
752750
if c.sessionID != "" {
753751
w.Header().Set(sessionIDHeader, c.sessionID)
754752
}
753+
if persistent {
754+
// Issue #410: the hanging GET is likely not to receive messages for a long
755+
// time. Ensure that headers are flushed.
756+
//
757+
// For non-persistent requests, delay the writing of the header in case we
758+
// may want to set an error status.
759+
// (see the TODO: this probably isn't worth it).
760+
w.WriteHeader(http.StatusOK)
761+
if f, ok := w.(http.Flusher); ok {
762+
f.Flush()
763+
}
764+
}
755765

756766
// write one event containing data.
767+
writes := 0
757768
write := func(data []byte) bool {
758769
lastIndex++
759770
e := Event{
@@ -770,23 +781,19 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
770781
return true
771782
}
772783

773-
errorf := func(code int, format string, args ...any) {
774-
if writes == 0 {
775-
http.Error(w, fmt.Sprintf(format, args...), code)
776-
} else {
777-
// TODO(#170): log when we add server-side logging
778-
}
779-
}
780-
781784
// Repeatedly collect pending outgoing events and send them.
782785
ctx := req.Context()
783786
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
784787
if err != nil {
785-
if ctx.Err() != nil && writes == 0 {
786-
// This probably doesn't matter, but respond with NoContent if the client disconnected.
787-
w.WriteHeader(http.StatusNoContent)
788+
if ctx.Err() == nil && writes == 0 && !persistent {
789+
// If we haven't yet written the header, we have an opportunity to
790+
// promote an error to an HTTP error.
791+
//
792+
// TODO: This may not matter in practice, in which case we should
793+
// simplify.
794+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
788795
} else {
789-
errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
796+
// TODO(#170): log when we add server-side logging
790797
}
791798
return
792799
}

mcp/streamable_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,3 +1312,77 @@ func TestTokenInfo(t *testing.T) {
13121312
t.Errorf("got %q, want %q", g, w)
13131313
}
13141314
}
1315+
1316+
func TestStreamableGET(t *testing.T) {
1317+
// This test checks the fix for problematic behavior described in #410:
1318+
// Hanging GET headers should be written immediately, even if there are no
1319+
// messages.
1320+
server := NewServer(testImpl, nil)
1321+
1322+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1323+
httpServer := httptest.NewServer(handler)
1324+
defer httpServer.Close()
1325+
1326+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1327+
defer cancel()
1328+
1329+
newReq := func(method string, msg jsonrpc.Message) *http.Request {
1330+
var body io.Reader
1331+
if msg != nil {
1332+
data, err := jsonrpc2.EncodeMessage(msg)
1333+
if err != nil {
1334+
t.Fatal(err)
1335+
}
1336+
body = bytes.NewReader(data)
1337+
}
1338+
req, err := http.NewRequestWithContext(ctx, method, httpServer.URL, body)
1339+
if err != nil {
1340+
t.Fatal(err)
1341+
}
1342+
req.Header.Set("Accept", "application/json, text/event-stream")
1343+
if msg != nil {
1344+
req.Header.Set("Content-Type", "application/json")
1345+
}
1346+
return req
1347+
}
1348+
1349+
get1 := newReq(http.MethodGet, nil)
1350+
resp, err := http.DefaultClient.Do(get1)
1351+
if err != nil {
1352+
t.Fatal(err)
1353+
}
1354+
if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want {
1355+
t.Errorf("initial GET: got status %d, want %d", got, want)
1356+
}
1357+
defer resp.Body.Close()
1358+
1359+
post1 := newReq(http.MethodPost, req(1, methodInitialize, &InitializeParams{}))
1360+
resp, err = http.DefaultClient.Do(post1)
1361+
if err != nil {
1362+
t.Fatal(err)
1363+
}
1364+
defer resp.Body.Close()
1365+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1366+
body, err := io.ReadAll(resp.Body)
1367+
if err != nil {
1368+
t.Fatal(err)
1369+
}
1370+
t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body))
1371+
}
1372+
1373+
sessionID := resp.Header.Get(sessionIDHeader)
1374+
if sessionID == "" {
1375+
t.Fatalf("initialized missing session ID")
1376+
}
1377+
1378+
get2 := newReq("GET", nil)
1379+
get2.Header.Set(sessionIDHeader, sessionID)
1380+
resp, err = http.DefaultClient.Do(get2)
1381+
if err != nil {
1382+
t.Fatal(err)
1383+
}
1384+
defer resp.Body.Close()
1385+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1386+
t.Errorf("GET with session ID: got status %d, want %d", got, want)
1387+
}
1388+
}

0 commit comments

Comments
 (0)