From 807b9c0a1f3fea1c1444e6228f10e0b4cdacb726 Mon Sep 17 00:00:00 2001 From: Sam Thanawalla Date: Tue, 1 Jul 2025 14:39:52 +0000 Subject: [PATCH] mcp: add retry and replay to the Streamable HTTP implementation Adds exponential backoff and jitter to the client-side POST. Implements replay using Last-Event-ID for resumability. Since these concepts are intertwined- I have added this in a single CL. --- mcp/streamable.go | 449 ++++++++++++++++++++++++++++++++++------- mcp/streamable_test.go | 287 ++++++++++++++++++++++++++ 2 files changed, 660 insertions(+), 76 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index da950fb2..963325c2 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -7,13 +7,17 @@ package mcp import ( "bytes" "context" + "errors" "fmt" "io" + "math/rand" + "net" "net/http" "strconv" "strings" "sync" "sync/atomic" + "time" "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" ) @@ -90,7 +94,7 @@ func (h *StreamableHTTPHandler) ServeHTTP(w http.ResponseWriter, req *http.Reque var session *StreamableServerTransport if id := req.Header.Get("Mcp-Session-Id"); id != "" { h.sessionsMu.Lock() - session, _ = h.sessions[id] + session = h.sessions[id] h.sessionsMu.Unlock() if session == nil { http.Error(w, "session not found", http.StatusNotFound) @@ -594,6 +598,15 @@ type StreamableClientTransportOptions struct { // HTTPClient is the client to use for making HTTP requests. If nil, // http.DefaultClient is used. HTTPClient *http.Client + // MaxRetries specifies the maximum number of retries for sending a message + // or re-establishing a hanging GET connection. If 0, no retries are performed + // beyond the initial attempt. + MaxRetries int + + // InitialBackoff is the initial duration to wait before the first retry + // attempt. Subsequent retries use exponential backoff. If 0, a default + // of 1 second is used. + InitialBackoff time.Duration } // NewStreamableClientTransport returns a new client transport that connects to @@ -602,6 +615,13 @@ func NewStreamableClientTransport(url string, opts *StreamableClientTransportOpt t := &StreamableClientTransport{url: url} if opts != nil { t.opts = *opts + } else { + t.opts = StreamableClientTransportOptions{} + } + + // Set default initial backoff if not specified. + if t.opts.InitialBackoff == 0 { + t.opts.InitialBackoff = time.Second } return t } @@ -619,33 +639,60 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er if client == nil { client = http.DefaultClient } - return &streamableClientConn{ - url: t.url, - client: client, - incoming: make(chan []byte, 100), - done: make(chan struct{}), - }, nil + conn := &streamableClientConn{ + url: t.url, + client: client, + incoming: make(chan []byte, 100), + done: make(chan struct{}), + pendingMessages: make(chan JSONRPCMessage, 100), // Buffer pending messages + maxRetries: t.opts.MaxRetries, + initialBackoff: t.opts.InitialBackoff, + randSource: rand.New(rand.NewSource(time.Now().UnixNano())), // Seed for jitter + } + conn.sessionID.Store("") + + // Start the goroutines that handle sending messages and receiving SSE events. + go conn.startMessageWriter() + go conn.startEventStreamReceiver() + + return conn, nil } type streamableClientConn struct { - url string - client *http.Client - incoming chan []byte - done chan struct{} + url string + // sessionID stores the current session ID. + sessionID atomic.Value + client *http.Client + incoming chan []byte + done chan struct{} closeOnce sync.Once closeErr error - mu sync.Mutex - _sessionID string + mu sync.Mutex // Protects lastEventID and err + // lastEventID stores the ID of the last successfully processed SSE event, + // used for resuming the stream. + lastEventID string // bodies map[*http.Response]io.Closer + // err stores the last error that caused the connection to be deemed unhealthy. err error + + // pendingMessages is a buffered channel for messages waiting to be sent. + pendingMessages chan JSONRPCMessage + + // Retry configuration + maxRetries int + initialBackoff time.Duration + randSource *rand.Rand // For adding jitter to backoff + + // cancelHangingGet is a context.CancelFunc for the currently active + // hanging GET request. Used to cancel the request if the connection needs + // to be closed or a new hanging GET is initiated. + cancelHangingGet context.CancelFunc } func (c *streamableClientConn) SessionID() string { - c.mu.Lock() - defer c.mu.Unlock() - return c._sessionID + return c.sessionID.Load().(string) } // Read implements the [Connection] interface. @@ -654,120 +701,370 @@ func (s *streamableClientConn) Read(ctx context.Context) (JSONRPCMessage, error) case <-ctx.Done(): return nil, ctx.Err() case <-s.done: + s.mu.Lock() + defer s.mu.Unlock() + if s.err != nil { + return nil, s.err // Return explicit error if connection closed due to error + } return nil, io.EOF case data := <-s.incoming: return jsonrpc2.DecodeMessage(data) } } -// Write implements the [Connection] interface. +// Write implements the [Connection] interface by enqueuing the message +// for an asynchronous send operation. The actual sending, including retries, +// is handled by the startMessageWriter goroutine. func (s *streamableClientConn) Write(ctx context.Context, msg JSONRPCMessage) error { - s.mu.Lock() - if s.err != nil { - s.mu.Unlock() - return s.err - } - - sessionID := s._sessionID - if sessionID == "" { - // Hold lock for the first request. + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.done: + s.mu.Lock() defer s.mu.Unlock() - } else { - s.mu.Unlock() - } - - gotSessionID, err := s.postMessage(ctx, sessionID, msg) - if err != nil { - if sessionID != "" { - // unlocked; lock to set err - s.mu.Lock() - defer s.mu.Unlock() - } if s.err != nil { - s.err = err + return s.err } - return err + return io.EOF // Connection closed + case s.pendingMessages <- msg: // Enqueue the message for sending + return nil } +} - if sessionID == "" { - // locked - s._sessionID = gotSessionID +// startMessageWriter continuously sends messages from the pendingMessages channel, +// applying retry logic for transient errors. +func (s *streamableClientConn) startMessageWriter() { + for { + select { + case <-s.done: + return // Connection is closed + case msg := <-s.pendingMessages: + // Use a new context for each send attempt to allow individual retries to be cancelled + // if the overall connection context is cancelled. + // This context is cancelled by the inner goroutine once the send attempt (including retries) is done. + ctx, cancel := context.WithCancel(context.Background()) + + go func(msgToSend JSONRPCMessage) { + defer cancel() // Ensure context is cancelled when this goroutine finishes + + currentSessionID := s.sessionID.Load().(string) + var lastErr error + for i := 0; i <= s.maxRetries; i++ { + // Check if the main connection has been closed during retries + select { + case <-s.done: + return + case <-ctx.Done(): // Check if the individual send context was cancelled + return + default: + // Continue + } + + gotSessionID, sendErr := s.postMessage(ctx, currentSessionID, msgToSend) + if sendErr == nil { + // If sessionID was not set and we got one, update it. + if currentSessionID == "" && gotSessionID != "" { + s.sessionID.Store(gotSessionID) + } + // Undefined behavior when currentSessionID != gotSessionID + return + } + + lastErr = sendErr // Store the latest error + if !isRetryable(sendErr) || i == s.maxRetries { + break // Not a retryable error or max retries reached + } + + // Apply exponential backoff with jitter + backoffDuration := s.initialBackoff * time.Duration(1<= 300 { - // TODO: do a best effort read of the body here, and format it in the error. + bodyBytes, _ := io.ReadAll(resp.Body) // Try to read body for more context + resp.Body.Close() + // Wrap the error with httpStatusError for easier status code checking + return "", &httpStatusError{ + StatusCode: resp.StatusCode, + Err: fmt.Errorf("POST request returned unexpected status %d %s: %s", resp.StatusCode, resp.Status, strings.TrimSpace(string(bodyBytes))), + } + } + + newSessionID := resp.Header.Get("Mcp-Session-Id") + if currentSessionID == "" && newSessionID == "" { resp.Body.Close() - return "", fmt.Errorf("broken session: %v", resp.Status) + // This should ideally not happen if server correctly sets session ID on first POST. + return "", fmt.Errorf("initial POST request did not return an Mcp-Session-Id") + } + if newSessionID == "" { + // If the server didn't explicitly send a new one, assume the existing one is still valid. + newSessionID = currentSessionID } - sessionID = resp.Header.Get("Mcp-Session-Id") if resp.Header.Get("Content-Type") == "text/event-stream" { go s.handleSSE(resp) } else { resp.Body.Close() } - return sessionID, nil + + return newSessionID, nil } -func (s *streamableClientConn) handleSSE(resp *http.Response) { - defer resp.Body.Close() +// startEventStreamReceiver continuously attempts to establish and maintain +// the hanging GET connection for receiving Server-Sent Events (SSE). +func (s *streamableClientConn) startEventStreamReceiver() { + backoffDuration := s.initialBackoff + retries := 0 - done := make(chan struct{}) - go func() { - defer close(done) - for evt, err := range scanEvents(resp.Body) { - if err != nil { - // TODO: surface this error; possibly break the stream - return + for { + select { + case <-s.done: + return // Connection is closed. + default: + // Continue + } + + sessionID := s.sessionID.Load().(string) + if sessionID == "" { + // Session ID not yet established (first POST hasn't completed). + // Wait and retry. + time.Sleep(100 * time.Millisecond) // Avoid busy-waiting + continue + } + + // Create a context for the current hanging GET request. + ctx, cancel := context.WithCancel(context.Background()) + s.mu.Lock() + s.cancelHangingGet = cancel // Store cancel function to allow external cancellation + lastEventID := s.lastEventID // Get the last processed event ID for replay + s.mu.Unlock() + + // Perform the hanging GET request + err := s.performHangingGet(ctx, sessionID, lastEventID) + + // Clean up after the hanging GET attempt + s.mu.Lock() + s.cancelHangingGet = nil // Clear the cancel function + s.mu.Unlock() + cancel() // Ensure the context for this specific GET is cancelled + + if err == nil { + // Successful hanging GET, reset retry state + retries = 0 + backoffDuration = s.initialBackoff + // Loop immediately to re-establish connection if it closed gracefully + continue + } + + // Error occurred during hanging GET, check for retry + if retries >= s.maxRetries { + s.mu.Lock() + s.err = fmt.Errorf("failed to maintain SSE connection after %d retries: %w", s.maxRetries, err) + s.mu.Unlock() + s.Close() // Close the connection if persistent failure + return + } + + // Apply exponential backoff with jitter + delay := backoffDuration + time.Duration(s.randSource.Int63n(int64(backoffDuration/2))) + select { + case <-s.done: + return // Connection closed during backoff + case <-time.After(delay): + retries++ + backoffDuration *= 2 // Exponential increase + if backoffDuration > 30*time.Second { // Cap backoff duration + backoffDuration = 30 * time.Second } - s.incoming <- evt.data } - }() + } +} - select { - case <-s.done: - case <-done: +// performHangingGet makes a single HTTP GET request for the SSE stream. +// It returns nil on graceful stream termination or an error on failure. +func (s *streamableClientConn) performHangingGet(ctx context.Context, sessionID, lastEventID string) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.url, nil) + if err != nil { + return fmt.Errorf("failed to create GET request: %w", err) + } + req.Header.Set("Mcp-Session-Id", sessionID) + req.Header.Set("Accept", "text/event-stream") + if lastEventID != "" { + req.Header.Set("Last-Event-ID", lastEventID) // Replay from this event + } + + resp, err := s.client.Do(req) + if err != nil { + return fmt.Errorf("GET request failed: %w", err) } + + if resp.StatusCode != http.StatusOK { + bodyBytes, _ := io.ReadAll(resp.Body) + resp.Body.Close() + // Wrap the error with httpStatusError for easier status code checking + return &httpStatusError{ + StatusCode: resp.StatusCode, + Err: fmt.Errorf("GET request returned unexpected status %d %s: %s", resp.StatusCode, resp.Status, strings.TrimSpace(string(bodyBytes))), + } + } + + // Handle the SSE stream from the response body. + return s.handleSSE(resp) +} + +// handleSSE processes Server-Sent Events from the provided HTTP response body. +// It pushes decoded messages to the incoming channel and updates the lastEventID. +func (s *streamableClientConn) handleSSE(resp *http.Response) error { + defer resp.Body.Close() + for evt, err := range scanEvents(resp.Body) { + if err != nil { + if err == io.EOF { + return nil // Stream ended gracefully + } + return fmt.Errorf("error scanning SSE events: %w", err) + } + // Update lastEventID on successful event receipt, crucial for replayability + if evt.id != "" { + s.mu.Lock() + s.lastEventID = evt.id + s.mu.Unlock() + } + select { + case s.incoming <- evt.data: + // Message successfully sent to incoming channel + case <-s.done: + // Connection closed while trying to send incoming message + return io.EOF + } + } + return nil // Stream finished without error +} + +// isRetryable checks if a given error indicates a transient condition +// that warrants a retry. +func isRetryable(err error) bool { + if err == nil { + return false + } + + // Check if the error is an httpStatusError and if its status code is retryable. + var httpErr *httpStatusError + if errors.As(err, &httpErr) { + switch httpErr.StatusCode { + case http.StatusRequestTimeout, // 408 + http.StatusTooEarly, // 425 + http.StatusTooManyRequests, // 429 + http.StatusInternalServerError, // 500 + http.StatusBadGateway, // 502 + http.StatusServiceUnavailable, // 503 + http.StatusGatewayTimeout: // 504 + return true + default: + return false // Non-retryable HTTP status code + } + } + + // Check for network-related errors + if netErr, ok := err.(net.Error); ok { + if netErr.Timeout() { + return true // Retry on timeout errors + } + } + + // Context cancellation should be non-retryable if it's explicitly from the caller. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + return false // Default to not retry for unknown errors } // Close implements the [Connection] interface. +// It ensures that all background goroutines are stopped and +// sends a DELETE request to the server to terminate the logical session. func (s *streamableClientConn) Close() error { s.closeOnce.Do(func() { - close(s.done) + close(s.done) // Signal all goroutines to stop - req, err := http.NewRequest(http.MethodDelete, s.url, nil) - if err != nil { - s.closeErr = err - } else { - req.Header.Set("Mcp-Session-Id", s._sessionID) - if _, err := s.client.Do(req); err != nil { - s.closeErr = err + // Cancel any ongoing hanging GET request + s.mu.Lock() + if s.cancelHangingGet != nil { + s.cancelHangingGet() + } + s.mu.Unlock() + close(s.pendingMessages) + + // Send DELETE request to terminate the session on the server + sessionID := s.sessionID.Load().(string) + if sessionID != "" { + req, err := http.NewRequest(http.MethodDelete, s.url, nil) + if err != nil { + s.closeErr = fmt.Errorf("failed to create DELETE request: %w", err) + } else { + req.Header.Set("Mcp-Session-Id", sessionID) + if _, err := s.client.Do(req); err != nil { + // Log the error but don't prevent close, as session termination is best effort. + s.closeErr = fmt.Errorf("failed to send DELETE request to terminate session: %w", err) + } } } }) return s.closeErr } + +// httpStatusError wraps an error and includes an HTTP status code. +type httpStatusError struct { + StatusCode int + Err error +} + +func (e *httpStatusError) Error() string { + if e.Err != nil { + return fmt.Sprintf("HTTP status %d: %v", e.StatusCode, e.Err) + } + return fmt.Sprintf("HTTP status %d", e.StatusCode) +} + +func (e *httpStatusError) Unwrap() error { + return e.Err +} diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index a8c916e8..ad7a4cdb 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -18,6 +18,7 @@ import ( "sync" "sync/atomic" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -97,6 +98,292 @@ func TestStreamableTransports(t *testing.T) { } } +// TestClientTransportRetriesPost simulates a server that fails a few POST requests +// before succeeding, verifying the client's retry mechanism. +func TestClientTransportRetriesPost(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var ( + postAttempt atomic.Int32 + expectedAttempts = 3 // Fail twice, succeed on third attempt + ) + + // Mock server that fails POST requests for a few attempts. + server := NewServer("mockServer", "v1.0.0", nil) + server.AddTools(NewServerTool("greet", "say hi", sayHi)) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodPost { + // Read body to inspect if it's the specific tool call being tested + bodyBytes, readErr := io.ReadAll(r.Body) + if readErr != nil { + t.Errorf("Failed to read request body: %v", readErr) + w.WriteHeader(http.StatusInternalServerError) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Restore body for handler.ServeHTTP + + var rawMsg json.RawMessage + if err := json.Unmarshal(bodyBytes, &rawMsg); err == nil { + var msg JSONRPCMessage + if m, decodeErr := jsonrpc2.DecodeMessage(rawMsg); decodeErr == nil { + msg = m + } + + if reqMsg, ok := msg.(*JSONRPCRequest); ok && reqMsg.Method == "tools/call" { + currentAttempt := postAttempt.Add(1) + if currentAttempt <= int32(expectedAttempts-1) { // Fail for expectedAttempts-1 times + t.Logf("Server: Failing POST attempt %d with 503 (tool call)", currentAttempt) + w.WriteHeader(http.StatusServiceUnavailable) + return + } + t.Logf("Server: Succeeding POST attempt %d (tool call)", currentAttempt) + // For the successful attempt, allow the normal handler to proceed. + // This will eventually return 200 OK after the JSON-RPC response is ready. + } + } + } + handler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ + MaxRetries: expectedAttempts - 1, // Allow 2 retries (total 3 attempts) + InitialBackoff: 1 * time.Millisecond, // Small backoff for faster test + }) + client := NewClient("testClient", "v1.0.0", nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + params := &CallToolParams{ + Name: "greet", + Arguments: map[string]any{"name": "retrytest"}, + } + + callErr := make(chan error, 1) + go func() { + _, err = session.CallTool(ctx, params) + callErr <- err + }() + + select { + case <-ctx.Done(): + t.Fatal("Test timed out before client could complete retries") + case err = <-callErr: + if err != nil { + t.Fatalf("CallTool() failed unexpectedly: %v", err) + } + if postAttempt.Load() != int32(expectedAttempts) { + t.Errorf("Expected %d POST attempts, got %d", expectedAttempts, postAttempt.Load()) + } + } +} + +// TestStreamableClientReplayEvents simulates a client reconnecting after a network +// interruption (simulated by server returning errors) and verifies that missed SSE events +// are replayed by the server upon successful reconnection. +func TestStreamableClientReplayEvents(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + var ( + counter atomic.Int32 + msgCount atomic.Int32 // Counts how many messages have been received + totalMessages = 10 + ) + + // Mock server that notifies progress to the client + server := NewServer("testServer", "v1.0.0", nil) + server.AddTools(NewServerTool("noop", "no operation", func(ctx context.Context, ss *ServerSession, params *CallToolParamsFor[any]) (*CallToolResultFor[any], error) { + // Send totalMessages notifications from the server + for i := range totalMessages { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + err := ss.NotifyProgress(ctx, &ProgressNotificationParams{ + ProgressToken: "test-token", + Message: fmt.Sprintf("Message %d", i), + Progress: float64(i), + }) + if err != nil { + // Connection might be closed, this is expected if client disconnects or server handler returns error + t.Logf("Server failed to send message %d: %v", i, err) + return nil, err + } + time.Sleep(10 * time.Millisecond) // Small delay between messages + } + } + return &CallToolResultFor[any]{}, nil + })) + + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Server received %s %s with headers: Mcp-Session-Id=%q, Last-Event-ID=%q", + r.Method, r.URL.Path, r.Header.Get("Mcp-Session-Id"), r.Header.Get("Last-Event-ID")) + + bodyBytes, readErr := io.ReadAll(r.Body) + if readErr != nil { + t.Errorf("Failed to read request body: %v", readErr) + w.WriteHeader(http.StatusInternalServerError) + return + } + r.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Restore body for handler.ServeHTTP + + var rawMsg json.RawMessage + if err := json.Unmarshal(bodyBytes, &rawMsg); err == nil { + var msg JSONRPCMessage + if m, decodeErr := jsonrpc2.DecodeMessage(rawMsg); decodeErr == nil { + msg = m + } + if reqMsg, ok := msg.(*JSONRPCRequest); ok && reqMsg.Method == "tools/call" { + count := counter.Load() + // simulate alternating failures + if count%2 == 0 { + counter.Add(1) + t.Logf("Server: Simulating connection failure (attempt %d) with 503 Service Unavailable", count) + w.WriteHeader(http.StatusServiceUnavailable) // Retryable error + return + } + } + } + + handler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, &StreamableClientTransportOptions{ + MaxRetries: 2, + InitialBackoff: 10 * time.Millisecond, // Small backoff for faster test + }) + client := NewClient("testClient", "v1.0.0", nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + defer session.Close() + + // Collect messages received by the client + receivedMessages := make(chan string, totalMessages) + client.opts.ProgressNotificationHandler = func(ctx context.Context, cs *ClientSession, params *ProgressNotificationParams) { + receivedMessages <- params.Message + msgCount.Add(1) + } + + // Trigger messages from the server by calling a noop tool. + // This will happen concurrently with the client's GET retries. + go func() { + _, callErr := session.CallTool(ctx, &CallToolParams{Name: "noop"}) + if callErr != nil && !strings.Contains(callErr.Error(), "context canceled") { + t.Errorf("CallTool returned unexpected error: %v", callErr) + } + }() + + // Wait for all messages to be received, or timeout + allMessages := []string{} + for len(allMessages) < totalMessages { + select { + case <-ctx.Done(): + t.Fatalf("Test timed out. Received %d messages, expected %d. Last messages: %v", len(allMessages), totalMessages, allMessages) + case msg := <-receivedMessages: + allMessages = append(allMessages, msg) + } + } + + // Verify all messages were received in order + expectedMessages := make([]string, totalMessages) + for i := range totalMessages { + expectedMessages[i] = fmt.Sprintf("Message %d", i) + } + + if diff := cmp.Diff(expectedMessages, allMessages); diff != "" { + t.Errorf("Received messages mismatch (-want +got):\n%s", diff) + } +} + +// TestStreamableClientSessionTermination verifies that the client correctly +// sends a DELETE request to terminate the session when Close() is called. +func TestStreamableClientSessionTermination(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + var establishedSessionID atomic.Value // Stores the session ID we expect to see deleted + establishedSessionID.Store("") + deleteReceived := sync.WaitGroup{} + deleteReceived.Add(1) + // Server that records session IDs and responds to DELETE + server := NewServer("testServer", "v1.0.0", nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Logf("Server received %s %s with headers: Mcp-Session-Id=%q, Last-Event-ID=%q", + r.Method, r.URL.Path, r.Header.Get("Mcp-Session-Id"), r.Header.Get("Last-Event-ID")) + if r.Method == http.MethodDelete { + if id := r.Header.Get("Mcp-Session-Id"); id != "" { + establishedSessionID.Store(id) + deleteReceived.Done() + } + } + handler.ServeHTTP(w, r) + })) + defer httpServer.Close() + + transport := NewStreamableClientTransport(httpServer.URL, nil) + client := NewClient("testClient", "v1.0.0", nil) + session, err := client.Connect(ctx, transport) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // Make a dummy call to ensure sessionID is established on the client side + // This also ensures the server handler sets the ID, which is picked up by the test hook. + session.CallTool(ctx, &CallToolParams{Name: "dummy", Arguments: map[string]any{}}) + + // Close the session + if err := session.Close(); err != nil { + t.Fatalf("session.Close() failed: %v", err) + } + deleteReceived.Wait() +} + +// TestStreamableServerDeleteWithoutSessionID verifies that a DELETE request +// without an Mcp-Session-Id header returns a 400 Bad Request. +func TestStreamableServerDeleteWithoutSessionID(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + server := NewServer("testServer", "v1.0.0", nil) + handler := NewStreamableHTTPHandler(func(req *http.Request) *Server { return server }, nil) + defer handler.closeAll() + + httpServer := httptest.NewServer(handler) + defer httpServer.Close() + + // Make a DELETE request without setting the Mcp-Session-Id header + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, httpServer.URL, nil) + if err != nil { + t.Fatalf("Failed to create DELETE request: %v", err) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("DELETE request failed: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Expected status %d (Bad Request) for DELETE without session ID, got %d", http.StatusBadRequest, resp.StatusCode) + } else { + t.Logf("Received expected status %d for DELETE without session ID.", resp.StatusCode) + } +} + func TestStreamableServerTransport(t *testing.T) { // This test checks detailed behavior of the streamable server transport, by // faking the behavior of a streamable client using a sequence of HTTP