diff --git a/acp_test.go b/acp_test.go
index c5a7c37..e936fbe 100644
--- a/acp_test.go
+++ b/acp_test.go
@@ -2,6 +2,7 @@ package acp
 
 import (
 	"context"
+	"encoding/json"
 	"io"
 	"slices"
 	"sync"
@@ -467,6 +468,107 @@ func TestConnectionHandlesNotifications(t *testing.T) {
 	}
 }
 
+func TestConnection_DoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
+	const n = 25
+
+	incomingR, incomingW := io.Pipe()
+
+	var (
+		wg            sync.WaitGroup
+		canceledCount atomic.Int64
+	)
+	wg.Add(n)
+
+	c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
+		defer wg.Done()
+		// Slow down processing so some notifications are handled after the receive
+		// loop observes EOF and signals disconnect.
+		time.Sleep(10 * time.Millisecond)
+		if ctx.Err() != nil {
+			canceledCount.Add(1)
+		}
+		return nil, nil
+	}, io.Discard, incomingR)
+
+	// Write notifications quickly and then close the stream to simulate a peer disconnect.
+	for i := 0; i < n; i++ {
+		if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
+			t.Fatalf("write notification: %v", err)
+		}
+	}
+	_ = incomingW.Close()
+
+	select {
+	case <-c.Done():
+		// Expected: peer disconnect observed promptly.
+	case <-time.After(2 * time.Second):
+		t.Fatalf("timeout waiting for connection Done()")
+	}
+
+	done := make(chan struct{})
+	go func() {
+		wg.Wait()
+		close(done)
+	}()
+	select {
+	case <-done:
+	case <-time.After(3 * time.Second):
+		t.Fatalf("timeout waiting for notification handlers")
+	}
+
+	if got := canceledCount.Load(); got != 0 {
+		t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
+	}
+}
+
+func TestConnection_CancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
+	const numNotifications = 200
+
+	incomingR, incomingW := io.Pipe()
+
+	reqDone := make(chan struct{})
+
+	c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
+		switch method {
+		case "test/notify":
+			// Slow down to create a backlog of queued notifications.
+			time.Sleep(5 * time.Millisecond)
+			return nil, nil
+		case "test/request":
+			// Requests should be canceled promptly on disconnect (uses c.ctx).
+			<-ctx.Done()
+			close(reqDone)
+			return nil, NewInternalError(map[string]any{"error": "canceled"})
+		default:
+			return nil, nil
+		}
+	}, io.Discard, incomingR)
+
+	for i := 0; i < numNotifications; i++ {
+		if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
+			t.Fatalf("write notification: %v", err)
+		}
+	}
+	if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
+		t.Fatalf("write request: %v", err)
+	}
+	_ = incomingW.Close()
+
+	// Disconnect should be observed quickly.
+	select {
+	case <-c.Done():
+	case <-time.After(2 * time.Second):
+		t.Fatalf("timeout waiting for connection Done()")
+	}
+
+	// Even with a big notification backlog, the request handler should be canceled promptly.
+	select {
+	case <-reqDone:
+	case <-time.After(1 * time.Second):
+		t.Fatalf("timeout waiting for request handler cancellation")
+	}
+}
+
 // Test initialize method behavior
 func TestConnectionHandlesInitialize(t *testing.T) {
 	c2aR, c2aW := io.Pipe()
diff --git a/connection.go b/connection.go
index fc114e5..81e94e1 100644
--- a/connection.go
+++ b/connection.go
@@ -10,6 +10,11 @@ import (
 	"log/slog"
 	"sync"
 	"sync/atomic"
+	"time"
+)
+
+const (
+	notificationQueueDrainTimeout = 5 * time.Second
 )
 
 type anyMessage struct {
@@ -37,27 +42,45 @@ type Connection struct {
 	nextID  atomic.Uint64
 	pending map[string]*pendingResponse
 
+	// ctx/cancel govern connection lifetime and are used for Done() and for canceling
+	// callers waiting on responses when the peer disconnects.
 	ctx    context.Context
 	cancel context.CancelCauseFunc
 
+	// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
+	// This ctx is intentionally kept alive long enough to process notifications
+	// that were successfully received and queued just before a peer disconnect.
+	// Otherwise, handlers that respect context cancellation may drop end-of-connection
+	// messages that we already read off the wire.
+	inboundCtx    context.Context
+	inboundCancel context.CancelCauseFunc
+
 	logger *slog.Logger
 
 	// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
 	// for all notifications received before the response to complete processing.
 	notificationWg sync.WaitGroup
+
+	// notificationQueue serializes notification processing to maintain order
+	notificationQueue *unboundedQueue[*anyMessage]
 }
 
 func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
 	ctx, cancel := context.WithCancelCause(context.Background())
+	inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
 	c := &Connection{
-		w:       peerInput,
-		r:       peerOutput,
-		handler: handler,
-		pending: make(map[string]*pendingResponse),
-		ctx:     ctx,
-		cancel:  cancel,
+		w:                 peerInput,
+		r:                 peerOutput,
+		handler:           handler,
+		pending:           make(map[string]*pendingResponse),
+		ctx:               ctx,
+		cancel:            cancel,
+		inboundCtx:        inboundCtx,
+		inboundCancel:     inboundCancel,
+		notificationQueue: newUnboundedQueue[*anyMessage](),
 	}
 	go c.receive()
+	go c.processNotifications()
 	return c
 }
 
@@ -98,27 +121,68 @@ func (c *Connection) receive() {
 		case msg.ID != nil && msg.Method == "":
 			c.handleResponse(&msg)
 		case msg.Method != "":
-			// Only track notifications (no ID) in the WaitGroup, not requests (with ID).
-			// This prevents deadlock when a request handler makes another request.
-			isNotification := msg.ID == nil
-			if isNotification {
-				c.notificationWg.Add(1)
+			// Requests (method+id) must not be serialized behind notifications, otherwise
+			// a long-running request (e.g. session/prompt) can deadlock cancellation
+			// notifications (session/cancel) that are required to stop it.
+			if msg.ID != nil {
+				m := msg
+				go c.handleInbound(&m)
+				continue
 			}
-			go func(m *anyMessage, isNotif bool) {
-				if isNotif {
-					defer c.notificationWg.Done()
-				}
-				c.handleInbound(m)
-			}(&msg, isNotification)
+
+			c.notificationWg.Add(1)
+
+			// Queue the notification for sequential processing.
+			// The unbounded queue never blocks, preserving ordering while
+			// ensuring the receive loop can always read responses promptly.
+			m := msg
+			c.notificationQueue.push(&m)
 		default:
 			c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
 		}
 	}
 
-	c.cancel(errors.New("peer connection closed"))
+	cause := errors.New("peer connection closed")
+
+	// First, signal disconnect to callers waiting on responses.
+	c.cancel(cause)
+
+	// Then close the notification queue so already-received messages can drain.
+	// IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
+	// notification handlers may legitimately block until their context is canceled.
+	c.notificationQueue.close()
+
+	// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
+	// handler blocks waiting for cancellation.
+	go func() {
+		done := make(chan struct{})
+		go func() {
+			c.notificationWg.Wait()
+			close(done)
+		}()
+		select {
+		case <-done:
+		case <-time.After(notificationQueueDrainTimeout):
+		}
+		c.inboundCancel(cause)
+	}()
+
 	c.loggerOrDefault().Info("peer connection closed")
 }
 
+// processNotifications processes notifications sequentially to maintain order.
+// It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
+func (c *Connection) processNotifications() {
+	for {
+		msg, ok := c.notificationQueue.pop()
+		if !ok {
+			return
+		}
+		c.handleInbound(msg)
+		c.notificationWg.Done()
+	}
+}
+
 func (c *Connection) handleResponse(msg *anyMessage) {
 	idStr := string(*msg.ID)
@@ -136,6 +200,15 @@ func (c *Connection) handleInbound(req *anyMessage) {
 	res := anyMessage{JSONRPC: "2.0"}
+
+	// Notifications are allowed a slightly longer-lived context during disconnect so we can
+	// process already-received end-of-connection messages. Requests, however, should be
+	// canceled promptly when the peer disconnects to avoid doing unnecessary work after
+	// the caller is gone.
+	ctx := c.ctx
+	if req.ID == nil {
+		ctx = c.inboundCtx
+	}
 	// copy ID if present
 	if req.ID != nil {
 		res.ID = req.ID
@@ -148,7 +221,7 @@
 		return
 	}
 
-	result, err := c.handler(c.ctx, req.Method, req.Params)
+	result, err := c.handler(ctx, req.Method, req.Params)
 	if req.ID == nil {
 		// Notification: no response is sent; log handler errors to surface decode failures.
 		if err != nil {
diff --git a/unboundedqueue.go b/unboundedqueue.go
new file mode 100644
index 0000000..d902ace
--- /dev/null
+++ b/unboundedqueue.go
@@ -0,0 +1,60 @@
+package acp
+
+import "sync"
+
+// unboundedQueue is a thread-safe FIFO queue that never blocks on push.
+// This ensures the receive loop can always enqueue notifications without
+// stalling, while preserving strict ordering for the consumer.
+type unboundedQueue[T any] struct {
+	mu     sync.Mutex
+	cond   *sync.Cond
+	items  []T
+	closed bool
+}
+
+func newUnboundedQueue[T any]() *unboundedQueue[T] {
+	q := &unboundedQueue[T]{}
+	q.cond = sync.NewCond(&q.mu)
+	return q
+}
+
+// push appends an item to the queue. Never blocks.
+func (q *unboundedQueue[T]) push(item T) {
+	q.mu.Lock()
+	q.items = append(q.items, item)
+	q.mu.Unlock()
+	q.cond.Signal()
+}
+
+// pop removes and returns the next item, blocking until one is available.
+// Returns the zero value and false if the queue is closed and empty.
+func (q *unboundedQueue[T]) pop() (T, bool) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	for len(q.items) == 0 && !q.closed {
+		q.cond.Wait()
+	}
+	if len(q.items) == 0 {
+		var zero T
+		return zero, false
+	}
+	item := q.items[0]
+	q.items = q.items[1:]
+	return item, true
+}
+
+// close signals that no more items will be pushed.
+// The consumer will drain remaining items before pop returns false.
+func (q *unboundedQueue[T]) close() {
+	q.mu.Lock()
+	q.closed = true
+	q.mu.Unlock()
+	q.cond.Broadcast()
+}
+
+// len returns the current number of items in the queue.
+func (q *unboundedQueue[T]) len() int {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	return len(q.items)
+}
diff --git a/unboundedqueue_test.go b/unboundedqueue_test.go
new file mode 100644
index 0000000..f758889
--- /dev/null
+++ b/unboundedqueue_test.go
@@ -0,0 +1,216 @@
+package acp
+
+import (
+	"sync"
+	"testing"
+	"time"
+)
+
+func TestUnboundedQueuePushPop(t *testing.T) {
+	q := newUnboundedQueue[int]()
+
+	q.push(1)
+	q.push(2)
+	q.push(3)
+
+	v, ok := q.pop()
+	if !ok || v != 1 {
+		t.Errorf("expected (1, true), got (%d, %v)", v, ok)
+	}
+
+	v, ok = q.pop()
+	if !ok || v != 2 {
+		t.Errorf("expected (2, true), got (%d, %v)", v, ok)
+	}
+
+	v, ok = q.pop()
+	if !ok || v != 3 {
+		t.Errorf("expected (3, true), got (%d, %v)", v, ok)
+	}
+}
+
+func TestUnboundedQueuePopBlocksUntilPush(t *testing.T) {
+	q := newUnboundedQueue[string]()
+
+	done := make(chan string)
+	go func() {
+		v, ok := q.pop()
+		if ok {
+			done <- v
+		}
+	}()
+
+	// Give the goroutine time to block
+	time.Sleep(10 * time.Millisecond)
+
+	select {
+	case <-done:
+		t.Fatal("pop should have blocked")
+	default:
+	}
+
+	q.push("hello")
+
+	select {
+	case v := <-done:
+		if v != "hello" {
+			t.Errorf("expected 'hello', got %q", v)
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("pop did not unblock after push")
+	}
+}
+
+func TestUnboundedQueueCloseUnblocksPop(t *testing.T) {
+	q := newUnboundedQueue[int]()
+
+	done := make(chan bool)
+	go func() {
+		_, ok := q.pop()
+		done <- ok
+	}()
+
+	// Give the goroutine time to block
+	time.Sleep(10 * time.Millisecond)
+
+	q.close()
+
+	select {
+	case ok := <-done:
+		if ok {
+			t.Error("expected ok=false after close on empty queue")
+		}
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("pop did not unblock after close")
+	}
+}
+
+func TestUnboundedQueueDrainAfterClose(t *testing.T) {
+	q := newUnboundedQueue[int]()
+
+	q.push(1)
+	q.push(2)
+	q.close()
+
+	// Should still be able to drain existing items
+	v, ok := q.pop()
+	if !ok || v != 1 {
+		t.Errorf("expected (1, true), got (%d, %v)", v, ok)
+	}
+
+	v, ok = q.pop()
+	if !ok || v != 2 {
+		t.Errorf("expected (2, true), got (%d, %v)", v, ok)
+	}
+
+	// Now should return false
+	_, ok = q.pop()
+	if ok {
+		t.Error("expected ok=false after draining closed queue")
+	}
+}
+
+func TestUnboundedQueueLen(t *testing.T) {
+	q := newUnboundedQueue[int]()
+
+	if q.len() != 0 {
+		t.Errorf("expected len 0, got %d", q.len())
+	}
+
+	q.push(1)
+	q.push(2)
+
+	if q.len() != 2 {
+		t.Errorf("expected len 2, got %d", q.len())
+	}
+
+	q.pop()
+
+	if q.len() != 1 {
+		t.Errorf("expected len 1, got %d", q.len())
+	}
+}
+
+func TestUnboundedQueueConcurrentPushPop(t *testing.T) {
+	q := newUnboundedQueue[int]()
+	const n = 1000
+
+	var wg sync.WaitGroup
+
+	// Producer
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for i := 0; i < n; i++ {
+			q.push(i)
+		}
+		q.close()
+	}()
+
+	// Consumer - verify ordering
+	received := make([]int, 0, n)
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		for {
+			v, ok := q.pop()
+			if !ok {
+				break
+			}
+			received = append(received, v)
+		}
+	}()
+
+	wg.Wait()
+
+	if len(received) != n {
+		t.Fatalf("expected %d items, got %d", n, len(received))
+	}
+
+	for i, v := range received {
+		if v != i {
+			t.Errorf("ordering broken: expected %d at index %d, got %d", i, i, v)
+			break
+		}
+	}
+}
+
+func TestUnboundedQueueMultipleProducers(t *testing.T) {
+	q := newUnboundedQueue[int]()
+	const producers = 10
+	const itemsPerProducer = 100
+
+	var wg sync.WaitGroup
+
+	// Multiple producers
+	for p := 0; p < producers; p++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for i := 0; i < itemsPerProducer; i++ {
+				q.push(i)
+			}
+		}()
+	}
+
+	// Wait for all producers, then close
+	go func() {
+		wg.Wait()
+		q.close()
+	}()
+
+	// Consumer
+	count := 0
+	for {
+		_, ok := q.pop()
+		if !ok {
+			break
+		}
+		count++
+	}
+
+	expected := producers * itemsPerProducer
+	if count != expected {
+		t.Errorf("expected %d items, got %d", expected, count)
+	}
+}