Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1018,7 +1018,7 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// Create a new cancellable context that will manage the connection's lifecycle.
// This is crucial for cleanly shutting down the background SSE listener by
// cancelling its blocking network operations, which prevents hangs on exit.
connCtx, cancel := context.WithCancel(context.Background())
connCtx, cancel := context.WithCancel(ctx)
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
Expand Down Expand Up @@ -1230,7 +1230,7 @@ func (c *streamableClientConn) Write(ctx context.Context, msg jsonrpc.Message) e

// testAuth controls whether a fake Authorization header is added to outgoing requests.
// TODO: replace with a better mechanism when client-side auth is in place.
var testAuth = false
var testAuth atomic.Bool

func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
c.mu.Lock()
Expand All @@ -1242,7 +1242,7 @@ func (c *streamableClientConn) setMCPHeaders(req *http.Request) {
if c.sessionID != "" {
req.Header.Set(sessionIDHeader, c.sessionID)
}
if testAuth {
if testAuth.Load() {
req.Header.Set("Authorization", "Bearer foo")
}
}
Expand Down Expand Up @@ -1394,14 +1394,10 @@ func (c *streamableClientConn) reconnect(lastEventID string) (*http.Response, er
// Close implements the [Connection] interface.
func (c *streamableClientConn) Close() error {
c.closeOnce.Do(func() {
// Cancel any hanging network requests.
c.cancel()
close(c.done)

if errors.Is(c.failure(), errSessionMissing) {
// If the session is missing, no need to delete it.
} else {
req, err := http.NewRequest(http.MethodDelete, c.url, nil)
req, err := http.NewRequestWithContext(c.ctx, http.MethodDelete, c.url, nil)
if err != nil {
c.closeErr = err
} else {
Expand All @@ -1411,6 +1407,10 @@ func (c *streamableClientConn) Close() error {
}
}
}

// Cancel any hanging network requests after cleanup.
c.cancel()
close(c.done)
})
return c.closeErr
}
Expand Down
59 changes: 57 additions & 2 deletions mcp/streamable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1310,8 +1310,9 @@ func textContent(t *testing.T, res *CallToolResult) string {
}

func TestTokenInfo(t *testing.T) {
defer func(b bool) { testAuth = b }(testAuth)
testAuth = true
oldAuth := testAuth.Load()
defer testAuth.Store(oldAuth)
testAuth.Store(true)
ctx := context.Background()

// Create a server with a tool that returns TokenInfo.
Expand Down Expand Up @@ -1430,3 +1431,57 @@ func TestStreamableGET(t *testing.T) {
t.Errorf("GET with session ID: got status %d, want %d", got, want)
}
}

func TestStreamableClientContextPropagation(t *testing.T) {
type contextKey string
const testKey = contextKey("test-key")
const testValue = "test-value"

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx2 := context.WithValue(ctx, testKey, testValue)

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
switch req.Method {
case "POST":
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Mcp-Session-Id", "test-session")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"test","version":"1.0"}}}`))
case "GET":
w.Header().Set("Content-Type", "text/event-stream")
w.WriteHeader(http.StatusOK)
case "DELETE":
w.WriteHeader(http.StatusNoContent)
}
}))
defer server.Close()

transport := &StreamableClientTransport{Endpoint: server.URL}
conn, err := transport.Connect(ctx2)
if err != nil {
t.Fatalf("Connect failed: %v", err)
}
defer conn.Close()

streamableConn, ok := conn.(*streamableClientConn)
if !ok {
t.Fatalf("Expected *streamableClientConn, got %T", conn)
}

if got := streamableConn.ctx.Value(testKey); got != testValue {
t.Errorf("Context value not propagated: got %v, want %v", got, testValue)
}

if streamableConn.ctx.Done() == nil {
t.Error("Connection context is not cancellable")
}

cancel()
select {
case <-streamableConn.ctx.Done():
case <-time.After(100 * time.Millisecond):
t.Error("Connection context was not cancelled when parent was cancelled")
}

}