102 changes: 102 additions & 0 deletions acp_test.go
@@ -2,6 +2,7 @@ package acp

import (
"context"
"encoding/json"
"io"
"slices"
"sync"
@@ -467,6 +468,107 @@ func TestConnectionHandlesNotifications(t *testing.T) {
}
}

func TestConnection_DoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
const n = 25

incomingR, incomingW := io.Pipe()

var (
wg sync.WaitGroup
canceledCount atomic.Int64
)
wg.Add(n)

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
defer wg.Done()
// Slow down processing so some notifications are handled after the receive
// loop observes EOF and signals disconnect.
time.Sleep(10 * time.Millisecond)
if ctx.Err() != nil {
canceledCount.Add(1)
}
return nil, nil
}, io.Discard, incomingR)

// Write notifications quickly and then close the stream to simulate a peer disconnect.
for i := 0; i < n; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
_ = incomingW.Close()

select {
case <-c.Done():
// Expected: peer disconnect observed promptly.
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(3 * time.Second):
t.Fatalf("timeout waiting for notification handlers")
}

if got := canceledCount.Load(); got != 0 {
t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
}
}

func TestConnection_CancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
const numNotifications = 200

incomingR, incomingW := io.Pipe()

reqDone := make(chan struct{})

c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
switch method {
case "test/notify":
// Slow down to create a backlog of queued notifications.
time.Sleep(5 * time.Millisecond)
return nil, nil
case "test/request":
// Requests should be canceled promptly on disconnect (uses c.ctx).
<-ctx.Done()
close(reqDone)
return nil, NewInternalError(map[string]any{"error": "canceled"})
default:
return nil, nil
}
}, io.Discard, incomingR)

for i := 0; i < numNotifications; i++ {
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
t.Fatalf("write notification: %v", err)
}
}
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
t.Fatalf("write request: %v", err)
}
_ = incomingW.Close()

// Disconnect should be observed quickly.
select {
case <-c.Done():
case <-time.After(2 * time.Second):
t.Fatalf("timeout waiting for connection Done()")
}

// Even with a big notification backlog, the request handler should be canceled promptly.
select {
case <-reqDone:
case <-time.After(1 * time.Second):
t.Fatalf("timeout waiting for request handler cancellation")
}
}

// Test initialize method behavior
func TestConnectionHandlesInitialize(t *testing.T) {
c2aR, c2aW := io.Pipe()
111 changes: 92 additions & 19 deletions connection.go
@@ -10,6 +10,11 @@ import (
"log/slog"
"sync"
"sync/atomic"
"time"
)

const (
notificationQueueDrainTimeout = 5 * time.Second
)

type anyMessage struct {
@@ -37,27 +42,45 @@ type Connection struct {
nextID atomic.Uint64
pending map[string]*pendingResponse

// ctx/cancel govern connection lifetime and are used for Done() and for canceling
// callers waiting on responses when the peer disconnects.
ctx context.Context
cancel context.CancelCauseFunc

// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
// This ctx is intentionally kept alive long enough to process notifications
// that were successfully received and queued just before a peer disconnect.
// Otherwise, handlers that respect context cancellation may drop end-of-connection
// messages that we already read off the wire.
inboundCtx context.Context
inboundCancel context.CancelCauseFunc

logger *slog.Logger

// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
// until all notifications received before the response have finished processing.
notificationWg sync.WaitGroup

// notificationQueue serializes notification processing to maintain order
notificationQueue *unboundedQueue[*anyMessage]
}

func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
ctx, cancel := context.WithCancelCause(context.Background())
inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
c := &Connection{
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
ctx: ctx,
cancel: cancel,
w: peerInput,
r: peerOutput,
handler: handler,
pending: make(map[string]*pendingResponse),
ctx: ctx,
cancel: cancel,
inboundCtx: inboundCtx,
inboundCancel: inboundCancel,
notificationQueue: newUnboundedQueue[*anyMessage](),
}
go c.receive()
go c.processNotifications()
return c
}

@@ -98,27 +121,68 @@ func (c *Connection) receive() {
case msg.ID != nil && msg.Method == "":
c.handleResponse(&msg)
case msg.Method != "":
// Only track notifications (no ID) in the WaitGroup, not requests (with ID).
// This prevents deadlock when a request handler makes another request.
isNotification := msg.ID == nil
if isNotification {
c.notificationWg.Add(1)
// Requests (method+id) must not be serialized behind notifications; otherwise
// a long-running request (e.g. session/prompt) would block the cancellation
// notification (session/cancel) that is needed to stop it, deadlocking both.
if msg.ID != nil {
m := msg
go c.handleInbound(&m)
continue
}
go func(m *anyMessage, isNotif bool) {
if isNotif {
defer c.notificationWg.Done()
}
c.handleInbound(m)
}(&msg, isNotification)

c.notificationWg.Add(1)

// Queue the notification for sequential processing.
// The unbounded queue never blocks, preserving ordering while
// ensuring the receive loop can always read responses promptly.
m := msg
c.notificationQueue.push(&m)
default:
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
}
}

c.cancel(errors.New("peer connection closed"))
cause := errors.New("peer connection closed")

// First, signal disconnect to callers waiting on responses.
c.cancel(cause)

// Then close the notification queue so already-received messages can drain.
// IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
// notification handlers may legitimately block until their context is canceled.
c.notificationQueue.close()

// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
// handler blocks waiting for cancellation.
go func() {
done := make(chan struct{})
go func() {
c.notificationWg.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(notificationQueueDrainTimeout):
}
c.inboundCancel(cause)
}()

c.loggerOrDefault().Info("peer connection closed")
}

// processNotifications processes notifications sequentially to maintain order.
// It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
func (c *Connection) processNotifications() {
for {
msg, ok := c.notificationQueue.pop()
if !ok {
return
}
c.handleInbound(msg)
c.notificationWg.Done()
}
}

func (c *Connection) handleResponse(msg *anyMessage) {
idStr := string(*msg.ID)

@@ -136,6 +200,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {

func (c *Connection) handleInbound(req *anyMessage) {
res := anyMessage{JSONRPC: "2.0"}

// Notifications are allowed a slightly longer-lived context during disconnect so we can
// process already-received end-of-connection messages. Requests, however, should be
// canceled promptly when the peer disconnects to avoid doing unnecessary work after
// the caller is gone.
ctx := c.ctx
if req.ID == nil {
ctx = c.inboundCtx
}
// copy ID if present
if req.ID != nil {
res.ID = req.ID
@@ -148,7 +221,7 @@
return
}

result, err := c.handler(c.ctx, req.Method, req.Params)
result, err := c.handler(ctx, req.Method, req.Params)
if req.ID == nil {
// Notification: no response is sent; log handler errors to surface decode failures.
if err != nil {
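Taken together, here is a minimal sketch (outside the diff) of what the two contexts mean for a MethodHandler. The method names "example/notify" and "example/work", the 30-second work duration, and the exampleHandlerSetup name are illustrative only; MethodHandler, NewConnection, and NewInternalError come from this package, and the snippet assumes it sits alongside it with the context, encoding/json, io, and time imports.

func exampleHandlerSetup(peerInput io.Writer, peerOutput io.Reader) *Connection {
	handler := func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
		switch method {
		case "example/notify":
			// Notification: ctx is inboundCtx, so a notification that was already
			// queued before the peer disconnected can still run to completion
			// (up to the drain timeout).
			return nil, nil
		case "example/work":
			// Request: ctx is the connection's ctx, so it is canceled promptly on
			// disconnect and the handler can stop instead of finishing its work.
			select {
			case <-ctx.Done():
				return nil, NewInternalError(map[string]any{"error": "canceled"})
			case <-time.After(30 * time.Second):
				return "done", nil
			}
		default:
			return nil, nil
		}
	}
	return NewConnection(handler, peerInput, peerOutput)
}

On disconnect, a queued "example/notify" still sees a live context while the backlog drains (bounded by notificationQueueDrainTimeout), whereas an in-flight "example/work" observes ctx.Done() promptly.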
60 changes: 60 additions & 0 deletions unboundedqueue.go
@@ -0,0 +1,60 @@
package acp

import "sync"

// unboundedQueue is a thread-safe FIFO queue that never blocks on push.
// This ensures the receive loop can always enqueue notifications without
// stalling, while preserving strict ordering for the consumer.
type unboundedQueue[T any] struct {
mu sync.Mutex
cond *sync.Cond
items []T
closed bool
}

func newUnboundedQueue[T any]() *unboundedQueue[T] {
q := &unboundedQueue[T]{}
q.cond = sync.NewCond(&q.mu)
return q
}

// push appends an item to the queue. Never blocks.
func (q *unboundedQueue[T]) push(item T) {
q.mu.Lock()
q.items = append(q.items, item)
q.mu.Unlock()
q.cond.Signal()
}

// pop removes and returns the next item, blocking until one is available.
// Returns the zero value and false if the queue is closed and empty.
func (q *unboundedQueue[T]) pop() (T, bool) {
q.mu.Lock()
defer q.mu.Unlock()
for len(q.items) == 0 && !q.closed {
q.cond.Wait()
}
if len(q.items) == 0 {
var zero T
return zero, false
}
item := q.items[0]
q.items = q.items[1:]
return item, true
}

// close signals that no more items will be pushed.
// The consumer will drain remaining items before pop returns false.
func (q *unboundedQueue[T]) close() {
q.mu.Lock()
q.closed = true
q.mu.Unlock()
q.cond.Broadcast()
}

// len returns the current number of items in the queue.
func (q *unboundedQueue[T]) len() int {
q.mu.Lock()
defer q.mu.Unlock()
return len(q.items)
}
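A minimal usage sketch of the queue (outside the diff), assuming it sits in the same package and imports fmt; the function name and string payload are made up for illustration.

func exampleQueue() {
	q := newUnboundedQueue[string]()

	done := make(chan struct{})
	go func() {
		defer close(done)
		// Consumer: pop blocks until an item arrives and returns ok == false
		// only once close() has been called and the backlog is fully drained,
		// so queued items are never lost and FIFO order is preserved.
		for {
			item, ok := q.pop()
			if !ok {
				return
			}
			fmt.Println("got", item)
		}
	}()

	// Producer: push never blocks, so a hot loop (like the receive loop in
	// connection.go) is never stalled by a slow consumer.
	for i := 0; i < 3; i++ {
		q.push(fmt.Sprintf("item-%d", i))
	}
	q.close()
	<-done
}

The design trades bounded memory for a non-blocking producer, which fits the receive loop's requirement that reading from the peer never stalls behind slow notification handlers.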