Skip to content

Commit 38d6de7

Browse files
krulsaidme0w and e.s.prilutskiy authored
fix: serialize notification processing to maintain message order (#8)
Co-authored-by: e.s.prilutskiy <e.s.prilutskiy@tbank.ru>
1 parent 2b2c16c commit 38d6de7

File tree

2 files changed

+269
-20
lines changed

2 files changed

+269
-20
lines changed

acp_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,166 @@ func TestConnectionHandlesNotifications(t *testing.T) {
484484
}
485485
}
486486

487+
func TestConnectionDoesNotCancelInboundContextBeforeDrainingNotificationsOnDisconnect(t *testing.T) {
488+
const n = 25
489+
490+
incomingR, incomingW := io.Pipe()
491+
492+
var (
493+
wg sync.WaitGroup
494+
canceledCount atomic.Int64
495+
)
496+
wg.Add(n)
497+
498+
c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
499+
defer wg.Done()
500+
// Slow down processing so some notifications are handled after the receive
501+
// loop observes EOF and signals disconnect.
502+
time.Sleep(10 * time.Millisecond)
503+
if ctx.Err() != nil {
504+
canceledCount.Add(1)
505+
}
506+
return nil, nil
507+
}, io.Discard, incomingR)
508+
509+
// Write notifications quickly and then close the stream to simulate a peer disconnect.
510+
for i := 0; i < n; i++ {
511+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
512+
t.Fatalf("write notification: %v", err)
513+
}
514+
}
515+
_ = incomingW.Close()
516+
517+
select {
518+
case <-c.Done():
519+
// Expected: peer disconnect observed promptly.
520+
case <-time.After(2 * time.Second):
521+
t.Fatalf("timeout waiting for connection Done()")
522+
}
523+
524+
done := make(chan struct{})
525+
go func() {
526+
wg.Wait()
527+
close(done)
528+
}()
529+
select {
530+
case <-done:
531+
case <-time.After(3 * time.Second):
532+
t.Fatalf("timeout waiting for notification handlers")
533+
}
534+
535+
if got := canceledCount.Load(); got != 0 {
536+
t.Fatalf("inbound handler context was canceled for %d/%d notifications", got, n)
537+
}
538+
}
539+
540+
func TestConnectionCancelsRequestHandlersOnDisconnectEvenWithNotificationBacklog(t *testing.T) {
541+
const numNotifications = 200
542+
543+
incomingR, incomingW := io.Pipe()
544+
545+
reqDone := make(chan struct{})
546+
547+
c := NewConnection(func(ctx context.Context, method string, _ json.RawMessage) (any, *RequestError) {
548+
switch method {
549+
case "test/notify":
550+
// Slow down to create a backlog of queued notifications.
551+
time.Sleep(5 * time.Millisecond)
552+
return nil, nil
553+
case "test/request":
554+
// Requests should be canceled promptly on disconnect (uses c.ctx).
555+
<-ctx.Done()
556+
close(reqDone)
557+
return nil, NewInternalError(map[string]any{"error": "canceled"})
558+
default:
559+
return nil, nil
560+
}
561+
}, io.Discard, incomingR)
562+
563+
for i := 0; i < numNotifications; i++ {
564+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
565+
t.Fatalf("write notification: %v", err)
566+
}
567+
}
568+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","id":1,"method":"test/request","params":{}}`+"\n"); err != nil {
569+
t.Fatalf("write request: %v", err)
570+
}
571+
_ = incomingW.Close()
572+
573+
// Disconnect should be observed quickly.
574+
select {
575+
case <-c.Done():
576+
case <-time.After(2 * time.Second):
577+
t.Fatalf("timeout waiting for connection Done()")
578+
}
579+
580+
// Even with a big notification backlog, the request handler should be canceled promptly.
581+
select {
582+
case <-reqDone:
583+
case <-time.After(1 * time.Second):
584+
t.Fatalf("timeout waiting for request handler cancellation")
585+
}
586+
}
587+
588+
func TestConnectionFailsFastOnNotificationQueueOverflow(t *testing.T) {
589+
incomingR, incomingW := io.Pipe()
590+
591+
// Block the first notification handler so the queue can fill deterministically.
592+
firstStarted := make(chan struct{})
593+
releaseFirst := make(chan struct{})
594+
var handled atomic.Int64
595+
596+
c := NewConnection(func(context.Context, string, json.RawMessage) (any, *RequestError) {
597+
if handled.Add(1) == 1 {
598+
close(firstStarted)
599+
<-releaseFirst
600+
}
601+
return nil, nil
602+
}, io.Discard, incomingR)
603+
604+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
605+
t.Fatalf("write first notification: %v", err)
606+
}
607+
select {
608+
case <-firstStarted:
609+
case <-time.After(1 * time.Second):
610+
t.Fatalf("timeout waiting for first notification handler to start")
611+
}
612+
613+
// Fill the buffered queue, then send one extra notification to force overflow.
614+
for i := 0; i < defaultMaxQueuedNotifications+1; i++ {
615+
if _, err := io.WriteString(incomingW, `{"jsonrpc":"2.0","method":"test/notify","params":{}}`+"\n"); err != nil {
616+
t.Fatalf("write overflow notification %d: %v", i, err)
617+
}
618+
}
619+
620+
select {
621+
case <-c.Done():
622+
case <-time.After(1 * time.Second):
623+
t.Fatalf("timeout waiting for connection cancellation on queue overflow")
624+
}
625+
626+
cause := context.Cause(c.ctx)
627+
if !errors.Is(cause, errNotificationQueueOverflow) {
628+
t.Fatalf("expected overflow cancellation cause, got %v", cause)
629+
}
630+
631+
// Let queued work drain and ensure waitgroup accounting remains balanced.
632+
close(releaseFirst)
633+
634+
drained := make(chan struct{})
635+
go func() {
636+
c.notificationWg.Wait()
637+
close(drained)
638+
}()
639+
640+
select {
641+
case <-drained:
642+
case <-time.After(1 * time.Second):
643+
t.Fatalf("notification waitgroup did not drain after overflow")
644+
}
645+
}
646+
487647
// Test initialize method behavior
488648
func TestConnectionHandlesInitialize(t *testing.T) {
489649
c2aR, c2aW := io.Pipe()

connection.go

Lines changed: 109 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,16 @@ import (
1111
"strings"
1212
"sync"
1313
"sync/atomic"
14+
"time"
1415
)
1516

17+
// notificationQueueDrainTimeout is the longest shutdown waits for in-flight
// notification handlers before canceling the inbound handler context anyway.
const notificationQueueDrainTimeout = 5 * time.Second

// defaultMaxQueuedNotifications is the buffer capacity of the notification
// queue created for each connection.
const defaultMaxQueuedNotifications = 1024

// errNotificationQueueOverflow is the cancellation cause used when a
// notification arrives while the queue is already full.
var errNotificationQueueOverflow = errors.New("notification queue overflow")
23+
1624
type anyMessage struct {
1725
JSONRPC string `json:"jsonrpc"`
1826
ID *json.RawMessage `json:"id,omitempty"`
@@ -38,27 +46,46 @@ type Connection struct {
3846
nextID atomic.Uint64
3947
pending map[string]*pendingResponse
4048

49+
// ctx/cancel govern connection lifetime and are used for Done() and for canceling
50+
// callers waiting on responses when the peer disconnects.
4151
ctx context.Context
4252
cancel context.CancelCauseFunc
4353

54+
// inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
55+
// This ctx is intentionally kept alive long enough to process notifications
56+
// that were successfully received and queued just before a peer disconnect.
57+
// Otherwise, handlers that respect context cancellation may drop end-of-connection
58+
// messages that we already read off the wire.
59+
inboundCtx context.Context
60+
inboundCancel context.CancelCauseFunc
61+
4462
logger *slog.Logger
4563

4664
// notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
4765
// for all notifications received before the response to complete processing.
4866
notificationWg sync.WaitGroup
67+
68+
// notificationQueue serializes notification processing to maintain order.
69+
// It is bounded to keep memory usage predictable.
70+
notificationQueue chan *anyMessage
4971
}
5072

5173
func NewConnection(handler MethodHandler, peerInput io.Writer, peerOutput io.Reader) *Connection {
5274
ctx, cancel := context.WithCancelCause(context.Background())
75+
inboundCtx, inboundCancel := context.WithCancelCause(context.Background())
5376
c := &Connection{
54-
w: peerInput,
55-
r: peerOutput,
56-
handler: handler,
57-
pending: make(map[string]*pendingResponse),
58-
ctx: ctx,
59-
cancel: cancel,
77+
w: peerInput,
78+
r: peerOutput,
79+
handler: handler,
80+
pending: make(map[string]*pendingResponse),
81+
ctx: ctx,
82+
cancel: cancel,
83+
inboundCtx: inboundCtx,
84+
inboundCancel: inboundCancel,
85+
notificationQueue: make(chan *anyMessage, defaultMaxQueuedNotifications),
6086
}
6187
go c.receive()
88+
go c.processNotifications()
6289
return c
6390
}
6491

@@ -99,25 +126,78 @@ func (c *Connection) receive() {
99126
case msg.ID != nil && msg.Method == "":
100127
c.handleResponse(&msg)
101128
case msg.Method != "":
102-
// Only track notifications (no ID) in the WaitGroup, not requests (with ID).
103-
// This prevents deadlock when a request handler makes another request.
104-
isNotification := msg.ID == nil
105-
if isNotification {
106-
c.notificationWg.Add(1)
129+
// Requests (method+id) must not be serialized behind notifications, otherwise
130+
// a long-running request (e.g. session/prompt) can deadlock cancellation
131+
// notifications (session/cancel) that are required to stop it.
132+
if msg.ID != nil {
133+
m := msg
134+
go c.handleInbound(&m)
135+
continue
136+
}
137+
138+
c.notificationWg.Add(1)
139+
140+
// Queue the notification for sequential processing.
141+
m := msg
142+
select {
143+
case c.notificationQueue <- &m:
144+
default:
145+
// Balance Add above when the message was not accepted.
146+
c.notificationWg.Done()
147+
c.loggerOrDefault().Error("failed to queue notification; closing connection", "err", errNotificationQueueOverflow, "capacity", cap(c.notificationQueue), "queued", len(c.notificationQueue))
148+
c.shutdownReceive(errNotificationQueueOverflow)
149+
return
107150
}
108-
go func(m *anyMessage, isNotif bool) {
109-
if isNotif {
110-
defer c.notificationWg.Done()
111-
}
112-
c.handleInbound(m)
113-
}(&msg, isNotification)
114151
default:
115152
c.loggerOrDefault().Error("received message with neither id nor method", "raw", string(line))
116153
}
117154
}
118155

119-
c.cancel(errors.New("peer connection closed"))
120-
c.loggerOrDefault().Info("peer connection closed")
156+
cause := errors.New("peer connection closed")
157+
if err := scanner.Err(); err != nil {
158+
cause = err
159+
}
160+
c.shutdownReceive(cause)
161+
}
162+
163+
func (c *Connection) shutdownReceive(cause error) {
164+
if cause == nil {
165+
cause = errors.New("connection closed")
166+
}
167+
168+
// First, signal disconnect to callers waiting on responses.
169+
c.cancel(cause)
170+
171+
// Then close the notification queue so already-received messages can drain.
172+
// IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
173+
// notification handlers may legitimately block until their context is canceled.
174+
close(c.notificationQueue)
175+
176+
// Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
177+
// handler blocks waiting for cancellation.
178+
go func() {
179+
done := make(chan struct{})
180+
go func() {
181+
c.notificationWg.Wait()
182+
close(done)
183+
}()
184+
select {
185+
case <-done:
186+
case <-time.After(notificationQueueDrainTimeout):
187+
}
188+
c.inboundCancel(cause)
189+
}()
190+
191+
c.loggerOrDefault().Info("connection closed", "cause", cause.Error())
192+
}
193+
194+
// processNotifications processes notifications sequentially to maintain order.
195+
// It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
196+
func (c *Connection) processNotifications() {
197+
for msg := range c.notificationQueue {
198+
c.handleInbound(msg)
199+
c.notificationWg.Done()
200+
}
121201
}
122202

123203
func (c *Connection) handleResponse(msg *anyMessage) {
@@ -137,6 +217,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {
137217

138218
func (c *Connection) handleInbound(req *anyMessage) {
139219
res := anyMessage{JSONRPC: "2.0"}
220+
221+
// Notifications are allowed a slightly longer-lived context during disconnect so we can
222+
// process already-received end-of-connection messages. Requests, however, should be
223+
// canceled promptly when the peer disconnects to avoid doing unnecessary work after
224+
// the caller is gone.
225+
ctx := c.ctx
226+
if req.ID == nil {
227+
ctx = c.inboundCtx
228+
}
140229
// copy ID if present
141230
if req.ID != nil {
142231
res.ID = req.ID
@@ -149,7 +238,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
149238
return
150239
}
151240

152-
result, err := c.handler(c.ctx, req.Method, req.Params)
241+
result, err := c.handler(ctx, req.Method, req.Params)
153242
if req.ID == nil {
154243
// Notification: no response is sent; log handler errors to surface decode failures.
155244
if err != nil {

0 commit comments

Comments
 (0)