Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions acp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,166 @@ func TestConnectionHandlesNotifications(t *testing.T) {
}
}

func TestConnectionDoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
const n = 25

incomingR, incomingW := io.Pipe()

var (
wg sync.WaitGroup
canceledCount atomic.Int64
)
wg.Add(n)

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
defer wg.Done()
// Slow down processing so some notifications are handled after the receive
// loop observes EOF and signals disconnect.
time.Sleep(10 * time.Millisecond)
if ctx.Err() != nil {
canceledCount.Add(1)
}
return nil, nil
}, io.Discard, incomingR)

// Write notifications quickly and then close the stream to simulate a peer disconnect.
for i := 0; i < n; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
_ = incomingW.Close()

select {
case <-c.Done():
// Expected: peer disconnect observed promptly.
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting for notification handlers")
}

if got := canceledCount.Load(); got != 0 {
t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
}
}

func TestConnectionCancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
const numNotifications = 200

incomingR, incomingW := io.Pipe()

reqDone := make(chan struct{})

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
switch method {
case "test/notify":
// Slow down to create a backlog of queued notifications.
time.Sleep(5 * time.Millisecond)
return nil, nil
case "test/request":
// Requests should be canceled promptly on disconnect (uses c.ctx).
<-ctx.Done()
close(reqDone)
return nil, NewInternalError(map[string]any{"error": "canceled"})
default:
return nil, nil
}
}, io.Discard, incomingR)

for i := 0; i < numNotifications; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
t.Fatalf("write request: %v", err)
}
_ = incomingW.Close()

// Disconnect should be observed quickly.
select {
case <-c.Done():
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

// Even with a big notification backlog, the request handler should be canceled promptly.
select {
case <-reqDone:
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for request handler cancellation")
}
}

// TestConnectionFailsFastOnNotificationQueueOverflow parks the first
// notification handler, overfills the bounded notification queue by one
// message, and expects the connection to be canceled with
// errNotificationQueueOverflow while the notification WaitGroup still drains.
func TestConnectionFailsFastOnNotificationQueueOverflow(t *testing.T) {
	const notification = `{"jsonrpc":"2.0","method":"test/notify","params":{}}` + "\n"

	peerOut, peerIn := io.Pipe()

	// The first handler invocation blocks until released, so nothing is
	// dequeued while the queue is being filled.
	var calls atomic.Int64
	firstStarted := make(chan struct{})
	releaseFirst := make(chan struct{})

	conn := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) {
		if calls.Add(1) != 1 {
			return nil, nil
		}
		close(firstStarted)
		<-releaseFirst
		return nil, nil
	}, io.Discard, peerOut)

	if _, err := io.WriteString(peerIn, notification); err != nil {
		t.Fatalf("write first notification: %v", err)
	}
	select {
	case <-firstStarted:
	case <-time.After(1 * time.Second):
		t.Fatalf("timeout waiting for first notification handler to start")
	}

	// Exactly capacity+1 messages: the queue fills, and the last one overflows.
	for i := 0; i <= defaultMaxQueuedNotifications; i++ {
		if _, err := io.WriteString(peerIn, notification); err != nil {
			t.Fatalf("write overflow notification %d: %v", i, err)
		}
	}

	select {
	case <-conn.Done():
	case <-time.After(1 * time.Second):
		t.Fatalf("timeout waiting for connection cancellation on queue overflow")
	}

	if cause := context.Cause(conn.ctx); !errors.Is(cause, errNotificationQueueOverflow) {
		t.Fatalf("expected overflow cancellation cause, got %v", cause)
	}

	// Unblock the parked handler so queued work can finish, then confirm the
	// WaitGroup accounting stayed balanced across the overflow path.
	close(releaseFirst)

	drained := make(chan struct{})
	go func() {
		defer close(drained)
		conn.notificationWg.Wait()
	}()

	select {
	case <-drained:
	case <-time.After(1 * time.Second):
		t.Fatalf("notification waitgroup did not drain after overflow")
	}
}

// Test initialize method behavior
func TestConnectionHandlesInitialize(t *testing.T) {
c2aR, c2aW := io.Pipe()
Expand Down
129 changes: 109 additions & 20 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,16 @@ import (
"strings"
"sync"
"sync/atomic"
"time"
)

// errNotificationQueueOverflow is the cancellation cause used when an inbound
// notification cannot be queued because the notification queue is already full.
var errNotificationQueueOverflow = errors.New("notification queue overflow")

const (
	// defaultMaxQueuedNotifications is the buffer capacity NewConnection gives
	// to the connection's notification queue.
	defaultMaxQueuedNotifications = 1024

	// notificationQueueDrainTimeout bounds how long shutdownReceive waits for
	// in-flight notification handlers before canceling the inbound context.
	notificationQueueDrainTimeout = 5 * time.Second
)

type anyMessage struct {
JSONRPC string `json:"jsonrpc"`
ID *json.RawMessage `json:"id,omitempty"`
Expand All @@ -38,27 +46,46 @@ type Connection struct {
nextID atomic.Uint64
pending map[string]*pendingResponse

// ctx/cancel govern connection lifetime and are used for Done() and for canceling
// callers waiting on responses when the peer disconnects.
ctx context.Context
cancel context.CancelCauseFunc

// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
// This ctx is intentionally kept alive long enough to process notifications
// that were successfully received and queued just before a peer disconnect.
// Otherwise, handlers that respect context cancellation may drop end-of-connection
// messages that we already read off the wire.
inboundCtx context.Context
inboundCancel context.CancelCauseFunc

logger *slog.Logger

// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
// for all notifications received before the response to complete processing.
notificationWg sync.WaitGroup

// notificationQueue serializes notification processing to maintain order.
// It is bounded to keep memory usage predictable.
notificationQueue chan *anyMessage
}

// NewConnection builds a Connection that writes outgoing messages to peerInput,
// reads incoming messages from peerOutput, and dispatches inbound methods to
// handler.
//
// Two independent cancelable contexts are created from context.Background():
// ctx governs the connection lifetime, while inboundCtx is the context handed
// to notification handlers so it can outlive ctx during shutdown. The
// notification queue is buffered with defaultMaxQueuedNotifications slots.
//
// Before returning, NewConnection starts two goroutines: receive, which reads
// from peerOutput, and processNotifications, which drains the notification
// queue sequentially.
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
	ctx, cancel := context.WithCancelCause(context.Background())
	inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
	// Each field is set exactly once; a struct literal with repeated field
	// names does not compile.
	c := &Connection{
		w:                 peerInput,
		r:                 peerOutput,
		handler:           handler,
		pending:           make(map[string]*pendingResponse),
		ctx:               ctx,
		cancel:            cancel,
		inboundCtx:        inboundCtx,
		inboundCancel:     inboundCancel,
		notificationQueue: make(chan *anyMessage, defaultMaxQueuedNotifications),
	}
	go c.receive()
	go c.processNotifications()
	return c
}

Expand Down Expand Up @@ -99,25 +126,78 @@ func (c *Connection) receive() {
case msg.ID != nil && msg.Method == "":
c.handleResponse(&msg)
case msg.Method != "":
// Only track notifications (no ID) in the WaitGroup, not requests (with ID).
// This prevents deadlock when a request handler makes another request.
isNotification := msg.ID == nil
if isNotification {
c.notificationWg.Add(1)
// Requests (method+id) must not be serialized behind notifications, otherwise
// a long-running request (e.g. session/prompt) can deadlock cancellation
// notifications (session/cancel) that are required to stop it.
if msg.ID != nil {
m := msg
go c.handleInbound(&m)
continue
}

c.notificationWg.Add(1)

// Queue the notification for sequential processing.
m := msg
select {
case c.notificationQueue <- &m:
default:
// Balance Add above when the message was not accepted.
c.notificationWg.Done()
c.loggerOrDefault().Error("failed to queue notification; closing connection", "err", errNotificationQueueOverflow, "capacity", cap(c.notificationQueue), "queued", len(c.notificationQueue))
c.shutdownReceive(errNotificationQueueOverflow)
return
}
go func(m *anyMessage, isNotif bool) {
if isNotif {
defer c.notificationWg.Done()
}
c.handleInbound(m)
}(&msg, isNotification)
default:
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
}
}

c.cancel(errors.New("peer connection closed"))
c.loggerOrDefault().Info("peer connection closed")
cause := errors.New("peer connection closed")
if err := scanner.Err(); err != nil {
cause = err
}
c.shutdownReceive(cause)
}

// shutdownReceive tears the connection down from the receive side. A nil cause
// is replaced with a generic "connection closed" error.
//
// The steps run in a fixed order:
//  1. cancel the connection context so callers waiting on responses are released;
//  2. close the notification queue so already-queued messages can still drain;
//  3. in the background, cancel the inbound context once all notification
//     handlers finish or notificationQueueDrainTimeout elapses, whichever
//     comes first.
//
// The method never waits for the drain itself, because a notification handler
// may legitimately block until its context is canceled.
func (c *Connection) shutdownReceive(cause error) {
	reason := cause
	if reason == nil {
		reason = errors.New("connection closed")
	}

	c.cancel(reason)
	close(c.notificationQueue)

	// drained is closed once every tracked notification handler has returned.
	drained := make(chan struct{})
	go func() {
		defer close(drained)
		c.notificationWg.Wait()
	}()

	// The timeout guarantees inboundCtx is eventually canceled even if a
	// handler is itself waiting for that cancellation.
	go func() {
		defer c.inboundCancel(reason)
		select {
		case <-drained:
		case <-time.After(notificationQueueDrainTimeout):
		}
	}()

	c.loggerOrDefault().Info("connection closed", "cause", reason.Error())
}

// processNotifications is the single consumer of notificationQueue. It handles
// queued notifications one at a time, in arrival order, and marks each one done
// on notificationWg after its handler returns. The loop ends once the queue has
// been closed and emptied.
func (c *Connection) processNotifications() {
	for {
		next, open := <-c.notificationQueue
		if !open {
			return
		}
		c.handleInbound(next)
		c.notificationWg.Done()
	}
}

func (c *Connection) handleResponse(msg *anyMessage) {
Expand All @@ -137,6 +217,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {

func (c *Connection) handleInbound(req *anyMessage) {
res := anyMessage{JSONRPC: "2.0"}

// Notifications are allowed a slightly longer-lived context during disconnect so we can
// process already-received end-of-connection messages. Requests, however, should be
// canceled promptly when the peer disconnects to avoid doing unnecessary work after
// the caller is gone.
ctx := c.ctx
if req.ID == nil {
ctx = c.inboundCtx
}
// copy ID if present
if req.ID != nil {
res.ID = req.ID
Expand All @@ -149,7 +238,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
return
}

result, err := c.handler(c.ctx, req.Method, req.Params)
result, err := c.handler(ctx, req.Method, req.Params)
if req.ID == nil {
// Notification: no response is sent; log handler errors to surface decode failures.
if err != nil {
Expand Down