diff --git a/mcp/streamable.go b/mcp/streamable.go index 4ab343b2..8072a637 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1018,7 +1018,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // Create a new cancellable context that will manage the connection's lifecycle. // This is crucial for cleanly shutting down the background SSE listener by // cancelling its blocking network operations, which prevents hangs on exit. - connCtx, cancel := context.WithCancel(context.Background()) + connCtx, cancel := context.WithCancel(ctx) conn := &streamableClientConn{ url: t.Endpoint, client: client, @@ -1230,7 +1230,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e // testAuth controls whether a fake Authorization header is added to outgoing requests. // TODO: replace with a better mechanism when client-side auth is in place. -var testAuth = false +var testAuth atomic.Bool func (c *streamableClientConn) setMCPHeaders(req *http.Request) { c.mu.Lock() @@ -1242,7 +1242,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) { if c.sessionID != "" { req.Header.Set(sessionIDHeader, c.sessionID) } - if testAuth { + if testAuth.Load() { req.Header.Set("Authorization", "Bearer foo") } } @@ -1394,14 +1394,10 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er // Close implements the [Connection] interface. func (c *streamableClientConn) Close() error { c.closeOnce.Do(func() { - // Cancel any hanging network requests. - c.cancel() - close(c.done) - if errors.Is(c.failure(), errSessionMissing) { // If the session is missing, no need to delete it. } else { - req, err := http.NewRequest(http.MethodDelete, c.url, nil) + req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil) if err != nil { c.closeErr = err } else { @@ -1411,6 +1407,10 @@ func (c *streamableClientConn) Close() error { } } } + + // Cancel any hanging network requests after cleanup. + c.cancel() + close(c.done) }) return c.closeErr } diff --git a/mcp/streamable_test.go b/mcp/streamable_test.go index 3b967f8f..3576d2b5 100644 --- a/mcp/streamable_test.go +++ b/mcp/streamable_test.go @@ -1310,8 +1310,9 @@ func textContent(t *testing.T, res *CallToolResult) string { } func TestTokenInfo(t *testing.T) { - defer func(b bool) { testAuth = b }(testAuth) - testAuth = true + oldAuth := testAuth.Load() + defer testAuth.Store(oldAuth) + testAuth.Store(true) ctx := context.Background() // Create a server with a tool that returns TokenInfo. @@ -1430,3 +1431,57 @@ func TestStreamableGET(t *testing.T) { t.Errorf("GET with session ID: got status %d, want %d", got, want) } } + +func TestStreamableClientContextPropagation(t *testing.T) { + type contextKey string + const testKey = contextKey("test-key") + const testValue = "test-value" + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + ctx2 := context.WithValue(ctx, testKey, testValue) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + switch req.Method { + case "POST": + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Mcp-Session-Id", "test-session") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}`)) + case "GET": + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + case "DELETE": + w.WriteHeader(http.StatusNoContent) + } + })) + defer server.Close() + + transport := &StreamableClientTransport{Endpoint: server.URL} + conn, err := transport.Connect(ctx2) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + defer conn.Close() + + streamableConn, ok := conn.(*streamableClientConn) + if !ok { + t.Fatalf("Expected *streamableClientConn, got %T", conn) + } + + if got := streamableConn.ctx.Value(testKey); got != testValue { + t.Errorf("Context value not propagated: got %v, want %v", got, testValue) + } + + if streamableConn.ctx.Done() == nil { + t.Error("Connection context is not cancellable") + } + + cancel() + select { + case <-streamableConn.ctx.Done(): + case <-time.After(100 * time.Millisecond): + t.Error("Connection context was not cancelled when parent was cancelled") + } + +}