diff --git a/mcp/streamable.go b/mcp/streamable.go index 0469613b..c9e31fe7 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -742,17 +742,28 @@ func (c *streamableServerConn) respondJSON(stream *stream, w http.ResponseWriter // lastIndex is the index of the last seen event if resuming, else -1. func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, req *http.Request, lastIndex int, persistent bool) { - writes := 0 - - // Accept checked in [StreamableHTTPHandler] + // Accept was checked in [StreamableHTTPHandler] w.Header().Set("Cache-Control", "no-cache, no-transform") w.Header().Set("Content-Type", "text/event-stream") // Accept checked in [StreamableHTTPHandler] w.Header().Set("Connection", "keep-alive") if c.sessionID != "" { w.Header().Set(sessionIDHeader, c.sessionID) } + if persistent { + // Issue #410: the hanging GET is likely not to receive messages for a long + // time. Ensure that headers are flushed. + // + // For non-persistent requests, delay the writing of the header in case we + // may want to set an error status. + // (see the TODO: this probably isn't worth it). + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + f.Flush() + } + } // write one event containing data. + writes := 0 write := func(data []byte) bool { lastIndex++ e := Event{ @@ -769,23 +780,19 @@ func (c *streamableServerConn) respondSSE(stream *stream, w http.ResponseWriter, return true } - errorf := func(code int, format string, args ...any) { - if writes == 0 { - http.Error(w, fmt.Sprintf(format, args...), code) - } else { - // TODO(#170): log when we add server-side logging - } - } - // Repeatedly collect pending outgoing events and send them. ctx := req.Context() for msg, err := range c.messages(ctx, stream, persistent, lastIndex) { if err != nil { - if ctx.Err() != nil && writes == 0 { - // This probably doesn't matter, but respond with NoContent if the client disconnected. - w.WriteHeader(http.StatusNoContent) + if ctx.Err() == nil && writes == 0 && !persistent { + // If we haven't yet written the header, we have an opportunity to + // promote an error to an HTTP error. + // + // TODO: This may not matter in practice, in which case we should + // simplify. + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } else { - errorf(http.StatusInternalServerError, http.StatusText(http.StatusInternalServerError)) + // TODO(#170): log when we add server-side logging } return } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 6b69b210..0d171d83 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1312,3 +1312,77 @@ func TestTokenInfo(t *testing.T) { t.Errorf("got %q, want %q", g, w) } } + +func TestStreamableGET(t *testing.T) { + // This test checks the fix for problematic behavior described in #410: + // Hanging GET headers should be written immediately, even if there are no + // messages. + server := NewServer(testImpl, nil) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + newReq := func(method string, msg jsonrpc.Message) *http.Request { + var body io.Reader + if msg != nil { + data, err := jsonrpc2.EncodeMessage(msg) + if err != nil { + t.Fatal(err) + } + body = bytes.NewReader(data) + } + req, err := http.NewRequestWithContext(ctx, method, httpServer.URL, body) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Accept", "application/json, text/event-stream") + if msg != nil { + req.Header.Set("Content-Type", "application/json") + } + return req + } + + get1 := newReq(http.MethodGet, nil) + resp, err := http.DefaultClient.Do(get1) + if err != nil { + t.Fatal(err) + } + if got, want := resp.StatusCode, http.StatusMethodNotAllowed; got != want { + t.Errorf("initial GET: got status %d, want %d", got, want) + } + defer resp.Body.Close() + + post1 := newReq(http.MethodPost, req(1, methodInitialize, &InitializeParams{})) + resp, err = http.DefaultClient.Do(post1) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + t.Errorf("initialize POST: got status %d, want %d; body:\n%s", got, want, string(body)) + } + + sessionID := resp.Header.Get(sessionIDHeader) + if sessionID == "" { + t.Fatalf("initialized missing session ID") + } + + get2 := newReq("GET", nil) + get2.Header.Set(sessionIDHeader, sessionID) + resp, err = http.DefaultClient.Do(get2) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + if got, want := resp.StatusCode, http.StatusOK; got != want { + t.Errorf("GET with session ID: got status %d, want %d", got, want) + } +}