@@ -11,8 +11,16 @@ import (
1111 "strings"
1212 "sync"
1313 "sync/atomic"
14+ "time"
1415)
1516
17+ const (
18+ notificationQueueDrainTimeout = 5 * time .Second
19+ defaultMaxQueuedNotifications = 1024
20+ )
21+
22+ var errNotificationQueueOverflow = errors .New ("notification queue overflow" )
23+
1624type anyMessage struct {
1725 JSONRPC string `json:"jsonrpc"`
1826 ID * json.RawMessage `json:"id,omitempty"`
@@ -38,27 +46,46 @@ type Connection struct {
3846 nextID atomic.Uint64
3947 pending map [string ]* pendingResponse
4048
49+ // ctx/cancel govern connection lifetime and are used for Done() and for canceling
50+ // callers waiting on responses when the peer disconnects.
4151 ctx context.Context
4252 cancel context.CancelCauseFunc
4353
54+ // inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
55+ // This ctx is intentionally kept alive long enough to process notifications
56+ // that were successfully received and queued just before a peer disconnect.
57+ // Otherwise, handlers that respect context cancellation may drop end-of-connection
58+ // messages that we already read off the wire.
59+ inboundCtx context.Context
60+ inboundCancel context.CancelCauseFunc
61+
4462 logger * slog.Logger
4563
4664 // notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
4765 // for all notifications received before the response to complete processing.
4866 notificationWg sync.WaitGroup
67+
68+ // notificationQueue serializes notification processing to maintain order.
69+ // It is bounded to keep memory usage predictable.
70+ notificationQueue chan * anyMessage
4971}
5072
5173func NewConnection (handler MethodHandler , peerInput io.Writer , peerOutput io.Reader ) * Connection {
5274 ctx , cancel := context .WithCancelCause (context .Background ())
75+ inboundCtx , inboundCancel := context .WithCancelCause (context .Background ())
5376 c := & Connection {
54- w : peerInput ,
55- r : peerOutput ,
56- handler : handler ,
57- pending : make (map [string ]* pendingResponse ),
58- ctx : ctx ,
59- cancel : cancel ,
77+ w : peerInput ,
78+ r : peerOutput ,
79+ handler : handler ,
80+ pending : make (map [string ]* pendingResponse ),
81+ ctx : ctx ,
82+ cancel : cancel ,
83+ inboundCtx : inboundCtx ,
84+ inboundCancel : inboundCancel ,
85+ notificationQueue : make (chan * anyMessage , defaultMaxQueuedNotifications ),
6086 }
6187 go c .receive ()
88+ go c .processNotifications ()
6289 return c
6390}
6491
@@ -99,25 +126,78 @@ func (c *Connection) receive() {
99126 case msg .ID != nil && msg .Method == "" :
100127 c .handleResponse (& msg )
101128 case msg .Method != "" :
102- // Only track notifications (no ID) in the WaitGroup, not requests (with ID).
103- // This prevents deadlock when a request handler makes another request.
104- isNotification := msg .ID == nil
105- if isNotification {
106- c .notificationWg .Add (1 )
129+ // Requests (method+id) must not be serialized behind notifications, otherwise
130+ // a long-running request (e.g. session/prompt) can deadlock cancellation
131+ // notifications (session/cancel) that are required to stop it.
132+ if msg .ID != nil {
133+ m := msg
134+ go c .handleInbound (& m )
135+ continue
136+ }
137+
138+ c .notificationWg .Add (1 )
139+
140+ // Queue the notification for sequential processing.
141+ m := msg
142+ select {
143+ case c .notificationQueue <- & m :
144+ default :
145+ // Balance Add above when the message was not accepted.
146+ c .notificationWg .Done ()
147+ c .loggerOrDefault ().Error ("failed to queue notification; closing connection" , "err" , errNotificationQueueOverflow , "capacity" , cap (c .notificationQueue ), "queued" , len (c .notificationQueue ))
148+ c .shutdownReceive (errNotificationQueueOverflow )
149+ return
107150 }
108- go func (m * anyMessage , isNotif bool ) {
109- if isNotif {
110- defer c .notificationWg .Done ()
111- }
112- c .handleInbound (m )
113- }(& msg , isNotification )
114151 default :
115152 c .loggerOrDefault ().Error ("received message with neither id nor method" , "raw" , string (line ))
116153 }
117154 }
118155
119- c .cancel (errors .New ("peer connection closed" ))
120- c .loggerOrDefault ().Info ("peer connection closed" )
156+ cause := errors .New ("peer connection closed" )
157+ if err := scanner .Err (); err != nil {
158+ cause = err
159+ }
160+ c .shutdownReceive (cause )
161+ }
162+
163+ func (c * Connection ) shutdownReceive (cause error ) {
164+ if cause == nil {
165+ cause = errors .New ("connection closed" )
166+ }
167+
168+ // First, signal disconnect to callers waiting on responses.
169+ c .cancel (cause )
170+
171+ // Then close the notification queue so already-received messages can drain.
172+ // IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
173+ // notification handlers may legitimately block until their context is canceled.
174+ close (c .notificationQueue )
175+
176+ // Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
177+ // handler blocks waiting for cancellation.
178+ go func () {
179+ done := make (chan struct {})
180+ go func () {
181+ c .notificationWg .Wait ()
182+ close (done )
183+ }()
184+ select {
185+ case <- done :
186+ case <- time .After (notificationQueueDrainTimeout ):
187+ }
188+ c .inboundCancel (cause )
189+ }()
190+
191+ c .loggerOrDefault ().Info ("connection closed" , "cause" , cause .Error ())
192+ }
193+
194+ // processNotifications processes notifications sequentially to maintain order.
195+ // It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
196+ func (c * Connection ) processNotifications () {
197+ for msg := range c .notificationQueue {
198+ c .handleInbound (msg )
199+ c .notificationWg .Done ()
200+ }
121201}
122202
123203func (c * Connection ) handleResponse (msg * anyMessage ) {
@@ -137,6 +217,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {
137217
138218func (c * Connection ) handleInbound (req * anyMessage ) {
139219 res := anyMessage {JSONRPC : "2.0" }
220+
221+ // Notifications are allowed a slightly longer-lived context during disconnect so we can
222+ // process already-received end-of-connection messages. Requests, however, should be
223+ // canceled promptly when the peer disconnects to avoid doing unnecessary work after
224+ // the caller is gone.
225+ ctx := c .ctx
226+ if req .ID == nil {
227+ ctx = c .inboundCtx
228+ }
140229 // copy ID if present
141230 if req .ID != nil {
142231 res .ID = req .ID
@@ -149,7 +238,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
149238 return
150239 }
151240
152- result , err := c .handler (c . ctx , req .Method , req .Params )
241+ result , err := c .handler (ctx , req .Method , req .Params )
153242 if req .ID == nil {
154243 // Notification: no response is sent; log handler errors to surface decode failures.
155244 if err != nil {
0 commit comments