Skip to content

Commit 042111e

Browse files
author
e.s.prilutskiy
committed
Fix bug with notification messages order
1 parent 434e635 commit 042111e

File tree

4 files changed

+470
-19
lines changed

4 files changed

+470
-19
lines changed

acp_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package acp
22

33
import (
44
"context"
5+
"encoding/json"
56
"io"
67
"slices"
78
"sync"
@@ -467,6 +468,107 @@ func TestConnectionHandlesNotifications(t *testing.T) {
467468
}
468469
}
469470

471+
func TestConnection_DoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
472+
const n = 25
473+
474+
incomingR, incomingW := io.Pipe()
475+
476+
var (
477+
wg sync.WaitGroup
478+
canceledCount atomic.Int64
479+
)
480+
wg.Add(n)
481+
482+
c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
483+
defer wg.Done()
484+
// Slow down processing so some notifications are handled after the receive
485+
// loop observes EOF and signals disconnect.
486+
time.Sleep(10 * time.Millisecond)
487+
if ctx.Err() != nil {
488+
canceledCount.Add(1)
489+
}
490+
return nil, nil
491+
}, io.Discard, incomingR)
492+
493+
// Write notifications quickly and then close the stream to simulate a peer disconnect.
494+
for i := 0; i < n; i++ {
495+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
496+
t.Fatalf("write notification: %v", err)
497+
}
498+
}
499+
_ = incomingW.Close()
500+
501+
select {
502+
case <-c.Done():
503+
// Expected: peer disconnect observed promptly.
504+
case <-time.After(2 * time.Second):
505+
t.Fatalf("timeout waiting for connection Done()")
506+
}
507+
508+
done := make(chan struct{})
509+
go func() {
510+
wg.Wait()
511+
close(done)
512+
}()
513+
select {
514+
case <-done:
515+
case <-time.After(3 * time.Second):
516+
t.Fatalf("timeout waiting for notification handlers")
517+
}
518+
519+
if got := canceledCount.Load(); got != 0 {
520+
t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
521+
}
522+
}
523+
524+
func TestConnection_CancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
525+
const numNotifications = 200
526+
527+
incomingR, incomingW := io.Pipe()
528+
529+
reqDone := make(chan struct{})
530+
531+
c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
532+
switch method {
533+
case "test/notify":
534+
// Slow down to create a backlog of queued notifications.
535+
time.Sleep(5 * time.Millisecond)
536+
return nil, nil
537+
case "test/request":
538+
// Requests should be canceled promptly on disconnect (uses c.ctx).
539+
<-ctx.Done()
540+
close(reqDone)
541+
return nil, NewInternalError(map[string]any{"error": "canceled"})
542+
default:
543+
return nil, nil
544+
}
545+
}, io.Discard, incomingR)
546+
547+
for i := 0; i < numNotifications; i++ {
548+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
549+
t.Fatalf("write notification: %v", err)
550+
}
551+
}
552+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
553+
t.Fatalf("write request: %v", err)
554+
}
555+
_ = incomingW.Close()
556+
557+
// Disconnect should be observed quickly.
558+
select {
559+
case <-c.Done():
560+
case <-time.After(2 * time.Second):
561+
t.Fatalf("timeout waiting for connection Done()")
562+
}
563+
564+
// Even with a big notification backlog, the request handler should be canceled promptly.
565+
select {
566+
case <-reqDone:
567+
case <-time.After(1 * time.Second):
568+
t.Fatalf("timeout waiting for request handler cancellation")
569+
}
570+
}
571+
470572
// Test initialize method behavior
471573
func TestConnectionHandlesInitialize(t *testing.T) {
472574
c2aR, c2aW := io.Pipe()

connection.go

Lines changed: 92 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@ import (
1010
"log/slog"
1111
"sync"
1212
"sync/atomic"
13+
"time"
14+
)
15+
16+
const (
17+
notificationQueueDrainTimeout = 5 * time.Second
1318
)
1419

1520
type anyMessage struct {
@@ -37,27 +42,45 @@ type Connection struct {
3742
nextID atomic.Uint64
3843
pending map[string]*pendingResponse
3944

45+
// ctx/cancel govern connection lifetime and are used for Done() and for canceling
46+
// callers waiting on responses when the peer disconnects.
4047
ctx context.Context
4148
cancel context.CancelCauseFunc
4249

50+
// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
51+
// This ctx is intentionally kept alive long enough to process notifications
52+
// that were successfully received and queued just before a peer disconnect.
53+
// Otherwise, handlers that respect context cancellation may drop end-of-connection
54+
// messages that we already read off the wire.
55+
inboundCtx context.Context
56+
inboundCancel context.CancelCauseFunc
57+
4358
logger *slog.Logger
4459

4560
// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
4661
// for all notifications received before the response to complete processing.
4762
notificationWg sync.WaitGroup
63+
64+
// notificationQueue serializes notification processing to maintain order
65+
notificationQueue *unboundedQueue[*anyMessage]
4866
}
4967

5068
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
5169
ctx, cancel := context.WithCancelCause(context.Background())
70+
inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
5271
c := &Connection{
53-
w: peerInput,
54-
r: peerOutput,
55-
handler: handler,
56-
pending: make(map[string]*pendingResponse),
57-
ctx: ctx,
58-
cancel: cancel,
72+
w: peerInput,
73+
r: peerOutput,
74+
handler: handler,
75+
pending: make(map[string]*pendingResponse),
76+
ctx: ctx,
77+
cancel: cancel,
78+
inboundCtx: inboundCtx,
79+
inboundCancel: inboundCancel,
80+
notificationQueue: newUnboundedQueue[*anyMessage](),
5981
}
6082
go c.receive()
83+
go c.processNotifications()
6184
return c
6285
}
6386

@@ -98,27 +121,68 @@ func (c *Connection) receive() {
98121
case msg.ID != nil && msg.Method == "":
99122
c.handleResponse(&msg)
100123
case msg.Method != "":
101-
// Only track notifications (no ID) in the WaitGroup, not requests (with ID).
102-
// This prevents deadlock when a request handler makes another request.
103-
isNotification := msg.ID == nil
104-
if isNotification {
105-
c.notificationWg.Add(1)
124+
// Requests (method+id) must not be serialized behind notifications, otherwise
125+
// a long-running request (e.g. session/prompt) can deadlock cancellation
126+
// notifications (session/cancel) that are required to stop it.
127+
if msg.ID != nil {
128+
m := msg
129+
go c.handleInbound(&m)
130+
continue
106131
}
107-
go func(m *anyMessage, isNotif bool) {
108-
if isNotif {
109-
defer c.notificationWg.Done()
110-
}
111-
c.handleInbound(m)
112-
}(&msg, isNotification)
132+
133+
c.notificationWg.Add(1)
134+
135+
// Queue the notification for sequential processing.
136+
// The unbounded queue never blocks, preserving ordering while
137+
// ensuring the receive loop can always read responses promptly.
138+
m := msg
139+
c.notificationQueue.push(&m)
113140
default:
114141
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
115142
}
116143
}
117144

118-
c.cancel(errors.New("peer connection closed"))
145+
cause := errors.New("peer connection closed")
146+
147+
// First, signal disconnect to callers waiting on responses.
148+
c.cancel(cause)
149+
150+
// Then close the notification queue so already-received messages can drain.
151+
// IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
152+
// notification handlers may legitimately block until their context is canceled.
153+
c.notificationQueue.close()
154+
155+
// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
156+
// handler blocks waiting for cancellation.
157+
go func() {
158+
done := make(chan struct{})
159+
go func() {
160+
c.notificationWg.Wait()
161+
close(done)
162+
}()
163+
select {
164+
case <-done:
165+
case <-time.After(notificationQueueDrainTimeout):
166+
}
167+
c.inboundCancel(cause)
168+
}()
169+
119170
c.loggerOrDefault().Info("peer connection closed")
120171
}
121172

173+
// processNotifications processes notifications sequentially to maintain order.
174+
// It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
175+
func (c *Connection) processNotifications() {
176+
for {
177+
msg, ok := c.notificationQueue.pop()
178+
if !ok {
179+
return
180+
}
181+
c.handleInbound(msg)
182+
c.notificationWg.Done()
183+
}
184+
}
185+
122186
func (c *Connection) handleResponse(msg *anyMessage) {
123187
idStr := string(*msg.ID)
124188

@@ -136,6 +200,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {
136200

137201
func (c *Connection) handleInbound(req *anyMessage) {
138202
res := anyMessage{JSONRPC: "2.0"}
203+
204+
// Notifications are allowed a slightly longer-lived context during disconnect so we can
205+
// process already-received end-of-connection messages. Requests, however, should be
206+
// canceled promptly when the peer disconnects to avoid doing unnecessary work after
207+
// the caller is gone.
208+
ctx := c.ctx
209+
if req.ID == nil {
210+
ctx = c.inboundCtx
211+
}
139212
// copy ID if present
140213
if req.ID != nil {
141214
res.ID = req.ID
@@ -148,7 +221,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
148221
return
149222
}
150223

151-
result, err := c.handler(c.ctx, req.Method, req.Params)
224+
result, err := c.handler(ctx, req.Method, req.Params)
152225
if req.ID == nil {
153226
// Notification: no response is sent; log handler errors to surface decode failures.
154227
if err != nil {

unboundedqueue.go

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
package acp
2+
3+
import "sync"
4+
5+
// unboundedQueue is a thread-safe FIFO queue that never blocks on push.
6+
// This ensures the receive loop can always enqueue notifications without
7+
// stalling, while preserving strict ordering for the consumer.
8+
type unboundedQueue[T any] struct {
9+
mu sync.Mutex
10+
cond *sync.Cond
11+
items []T
12+
closed bool
13+
}
14+
15+
func newUnboundedQueue[T any]() *unboundedQueue[T] {
16+
q := &unboundedQueue[T]{}
17+
q.cond = sync.NewCond(&q.mu)
18+
return q
19+
}
20+
21+
// push appends an item to the queue. Never blocks.
22+
func (q *unboundedQueue[T]) push(item T) {
23+
q.mu.Lock()
24+
q.items = append(q.items, item)
25+
q.mu.Unlock()
26+
q.cond.Signal()
27+
}
28+
29+
// pop removes and returns the next item, blocking until one is available.
30+
// Returns the zero value and false if the queue is closed and empty.
31+
func (q *unboundedQueue[T]) pop() (T, bool) {
32+
q.mu.Lock()
33+
defer q.mu.Unlock()
34+
for len(q.items) == 0 && !q.closed {
35+
q.cond.Wait()
36+
}
37+
if len(q.items) == 0 {
38+
var zero T
39+
return zero, false
40+
}
41+
item := q.items[0]
42+
q.items = q.items[1:]
43+
return item, true
44+
}
45+
46+
// close signals that no more items will be pushed.
47+
// The consumer will drain remaining items before pop returns false.
48+
func (q *unboundedQueue[T]) close() {
49+
q.mu.Lock()
50+
q.closed = true
51+
q.mu.Unlock()
52+
q.cond.Broadcast()
53+
}
54+
55+
// len returns the current number of items in the queue.
56+
func (q *unboundedQueue[T]) len() int {
57+
q.mu.Lock()
58+
defer q.mu.Unlock()
59+
return len(q.items)
60+
}

0 commit comments

Comments
 (0)