Skip to content

Commit 9c995c2

Browse files
committed
mcp: flush headers immediately for the hanging GET
Flush headers immediately for the persistent hanging GET of the streamable transport; otherwise, clients may time out. Fixes #410
1 parent a1eb484 commit 9c995c2

File tree

3 files changed

+104
-16
lines changed

3 files changed

+104
-16
lines changed

examples/server/basic/main.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"context"
99
"fmt"
1010
"log"
11+
"os"
1112

1213
"github.com/modelcontextprotocol/go-sdk/mcp"
1314
)
@@ -31,7 +32,13 @@ func main() {
3132
server := mcp.NewServer(&mcp.Implementation{Name: "greeter", Version: "v0.0.1"}, nil)
3233
mcp.AddTool(server, &mcp.Tool{Name: "greet", Description: "say hi"}, SayHi)
3334

34-
serverSession, err := server.Connect(ctx, serverTransport, nil)
35+
f, err := os.Create("/tmp/mcp.log")
36+
if err != nil {
37+
log.Fatal(err)
38+
}
39+
defer f.Close()
40+
transport := &mcp.LoggingTransport{Transport: serverTransport, Writer: f}
41+
serverSession, err := server.Connect(ctx, transport, nil)
3542
if err != nil {
3643
log.Fatal(err)
3744
}

mcp/streamable.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -742,17 +742,28 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter
742742

743743
// lastIndex is the index of the last seen event if resuming, else -1.
744744
func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) {
745-
writes := 0
746-
747-
// Accept checked in [StreamableHTTPHandler]
745+
// Accept was checked in [StreamableHTTPHandler]
748746
w.Header().Set("Cache-Control", "no-cache, no-transform")
749747
w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler]
750748
w.Header().Set("Connection", "keep-alive")
751749
if c.sessionID != "" {
752750
w.Header().Set(sessionIDHeader, c.sessionID)
753751
}
752+
if persistent {
753+
// Issue #410: the hanging GET is likely not to receive messages for a long
754+
// time. Ensure that headers are flushed.
755+
//
756+
// For non-persistent requests, delay the writing of the header in case we
757+
// may want to set an error status.
758+
// (see the TODO: this probably isn't worth it).
759+
w.WriteHeader(http.StatusOK)
760+
if f, ok := w.(http.Flusher); ok {
761+
f.Flush()
762+
}
763+
}
754764

755765
// write one event containing data.
766+
writes := 0
756767
write := func(data []byte) bool {
757768
lastIndex++
758769
e := Event{
@@ -769,23 +780,19 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter,
769780
return true
770781
}
771782

772-
errorf := func(code int, format string, args ...any) {
773-
if writes == 0 {
774-
http.Error(w, fmt.Sprintf(format, args...), code)
775-
} else {
776-
// TODO(#170): log when we add server-side logging
777-
}
778-
}
779-
780783
// Repeatedly collect pending outgoing events and send them.
781784
ctx := req.Context()
782785
for msg, err := range c.messages(ctx, stream, persistent, lastIndex) {
783786
if err != nil {
784-
if ctx.Err() != nil && writes == 0 {
785-
// This probably doesn't matter, but respond with NoContent if the client disconnected.
786-
w.WriteHeader(http.StatusNoContent)
787+
if ctx.Err() == nil && writes == 0 && !persistent {
788+
// If we haven't yet written the header, we have an opportunity to
789+
// promote an error to an HTTP error.
790+
//
791+
// TODO: This may not matter in practice, in which case we should
792+
// simplify.
793+
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
787794
} else {
788-
errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError))
795+
// TODO(#170): log when we add server-side logging
789796
}
790797
return
791798
}

mcp/streamable_test.go

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1312,3 +1312,77 @@ func TestTokenInfo(t *testing.T) {
13121312
t.Errorf("got %q, want %q", g, w)
13131313
}
13141314
}
1315+
1316+
func TestStreamableGET(t *testing.T) {
1317+
// This test checks the fix for problematic behavior described in #410:
1318+
// Hanging GET headers should be written immediately, even if there are no
1319+
// messages.
1320+
server := NewServer(testImpl, nil)
1321+
1322+
handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil)
1323+
httpServer := httptest.NewServer(handler)
1324+
defer httpServer.Close()
1325+
1326+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
1327+
defer cancel()
1328+
1329+
newReq := func(method string, msg jsonrpc.Message) *http.Request {
1330+
var body io.Reader
1331+
if msg != nil {
1332+
data, err := jsonrpc2.EncodeMessage(msg)
1333+
if err != nil {
1334+
t.Fatal(err)
1335+
}
1336+
body = bytes.NewReader(data)
1337+
}
1338+
req, err := http.NewRequestWithContext(ctx, method, httpServer.URL, body)
1339+
if err != nil {
1340+
t.Fatal(err)
1341+
}
1342+
req.Header.Set("Accept", "application/json, text/event-stream")
1343+
if msg != nil {
1344+
req.Header.Set("Content-Type", "application/json")
1345+
}
1346+
return req
1347+
}
1348+
1349+
get1 := newReq(http.MethodGet, nil)
1350+
resp, err := http.DefaultClient.Do(get1)
1351+
if err != nil {
1352+
t.Fatal(err)
1353+
}
1354+
if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want {
1355+
t.Errorf("initial GET: got status %d, want %d", got, want)
1356+
}
1357+
defer resp.Body.Close()
1358+
1359+
post1 := newReq(http.MethodPost, req(1, methodInitialize, &InitializeParams{}))
1360+
resp, err = http.DefaultClient.Do(post1)
1361+
if err != nil {
1362+
t.Fatal(err)
1363+
}
1364+
defer resp.Body.Close()
1365+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1366+
body, err := io.ReadAll(resp.Body)
1367+
if err != nil {
1368+
t.Fatal(err)
1369+
}
1370+
t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body))
1371+
}
1372+
1373+
sessionID := resp.Header.Get(sessionIDHeader)
1374+
if sessionID == "" {
1375+
t.Fatalf("initialized missing session ID")
1376+
}
1377+
1378+
get2 := newReq("GET", nil)
1379+
get2.Header.Set(sessionIDHeader, sessionID)
1380+
resp, err = http.DefaultClient.Do(get2)
1381+
if err != nil {
1382+
t.Fatal(err)
1383+
}
1384+
defer resp.Body.Close()
1385+
if got, want := resp.StatusCode, http.StatusOK; got != want {
1386+
t.Errorf("GET with session ID: got status %d, want %d", got, want)
1387+
}
1388+
}

0 commit comments

Comments
 (0)