@@ -10,6 +10,11 @@ import (
1010 "log/slog"
1111 "sync"
1212 "sync/atomic"
13+ "time"
14+ )
15+
// notificationQueueDrainTimeout caps how long a closed connection waits for
// already-queued notifications to finish before the context handed to
// notification handlers is canceled.
const notificationQueueDrainTimeout = 5 * time.Second
1419
1520type anyMessage struct {
@@ -37,27 +42,45 @@ type Connection struct {
3742 nextID atomic.Uint64
3843 pending map [string ]* pendingResponse
3944
45+ // ctx/cancel govern connection lifetime and are used for Done() and for canceling
46+ // callers waiting on responses when the peer disconnects.
4047 ctx context.Context
4148 cancel context.CancelCauseFunc
4249
50+ // inboundCtx/inboundCancel are used when invoking the inbound MethodHandler.
51+ // This ctx is intentionally kept alive long enough to process notifications
52+ // that were successfully received and queued just before a peer disconnect.
53+ // Otherwise, handlers that respect context cancellation may drop end-of-connection
54+ // messages that we already read off the wire.
55+ inboundCtx context.Context
56+ inboundCancel context.CancelCauseFunc
57+
4358 logger * slog.Logger
4459
4560 // notificationWg tracks in-flight notification handlers. This ensures SendRequest waits
4661 // for all notifications received before the response to complete processing.
4762 notificationWg sync.WaitGroup
63+
64+ // notificationQueue serializes notification processing to maintain order
65+ notificationQueue * unboundedQueue [* anyMessage ]
4866}
4967
5068func NewConnection (handler MethodHandler , peerInput io.Writer , peerOutput io.Reader ) * Connection {
5169 ctx , cancel := context .WithCancelCause (context .Background ())
70+ inboundCtx , inboundCancel := context .WithCancelCause (context .Background ())
5271 c := & Connection {
53- w : peerInput ,
54- r : peerOutput ,
55- handler : handler ,
56- pending : make (map [string ]* pendingResponse ),
57- ctx : ctx ,
58- cancel : cancel ,
72+ w : peerInput ,
73+ r : peerOutput ,
74+ handler : handler ,
75+ pending : make (map [string ]* pendingResponse ),
76+ ctx : ctx ,
77+ cancel : cancel ,
78+ inboundCtx : inboundCtx ,
79+ inboundCancel : inboundCancel ,
80+ notificationQueue : newUnboundedQueue [* anyMessage ](),
5981 }
6082 go c .receive ()
83+ go c .processNotifications ()
6184 return c
6285}
6386
@@ -98,27 +121,68 @@ func (c *Connection) receive() {
98121 case msg .ID != nil && msg .Method == "" :
99122 c .handleResponse (& msg )
100123 case msg .Method != "" :
101- // Only track notifications (no ID) in the WaitGroup, not requests (with ID).
102- // This prevents deadlock when a request handler makes another request.
103- isNotification := msg .ID == nil
104- if isNotification {
105- c .notificationWg .Add (1 )
124+ // Requests (method+id) must not be serialized behind notifications, otherwise
125+ // a long-running request (e.g. session/prompt) can deadlock cancellation
126+ // notifications (session/cancel) that are required to stop it.
127+ if msg .ID != nil {
128+ m := msg
129+ go c .handleInbound (& m )
130+ continue
106131 }
107- go func (m * anyMessage , isNotif bool ) {
108- if isNotif {
109- defer c .notificationWg .Done ()
110- }
111- c .handleInbound (m )
112- }(& msg , isNotification )
132+
133+ c .notificationWg .Add (1 )
134+
135+ // Queue the notification for sequential processing.
136+ // The unbounded queue never blocks, preserving ordering while
137+ // ensuring the receive loop can always read responses promptly.
138+ m := msg
139+ c .notificationQueue .push (& m )
113140 default :
114141 c .loggerOrDefault ().Error ("received message with neither id nor method" , "raw" , string (line ))
115142 }
116143 }
117144
118- c .cancel (errors .New ("peer connection closed" ))
145+ cause := errors .New ("peer connection closed" )
146+
147+ // First, signal disconnect to callers waiting on responses.
148+ c .cancel (cause )
149+
150+ // Then close the notification queue so already-received messages can drain.
151+ // IMPORTANT: Do not block this receive goroutine waiting for the drain to complete;
152+ // notification handlers may legitimately block until their context is canceled.
153+ c .notificationQueue .close ()
154+
155+ // Cancel inboundCtx after notifications finish, but ensure we don't leak forever if a
156+ // handler blocks waiting for cancellation.
157+ go func () {
158+ done := make (chan struct {})
159+ go func () {
160+ c .notificationWg .Wait ()
161+ close (done )
162+ }()
163+ select {
164+ case <- done :
165+ case <- time .After (notificationQueueDrainTimeout ):
166+ }
167+ c .inboundCancel (cause )
168+ }()
169+
119170 c .loggerOrDefault ().Info ("peer connection closed" )
120171}
121172
173+ // processNotifications processes notifications sequentially to maintain order.
174+ // It terminates when notificationQueue is closed (e.g. on disconnect in receive()).
175+ func (c * Connection ) processNotifications () {
176+ for {
177+ msg , ok := c .notificationQueue .pop ()
178+ if ! ok {
179+ return
180+ }
181+ c .handleInbound (msg )
182+ c .notificationWg .Done ()
183+ }
184+ }
185+
122186func (c * Connection ) handleResponse (msg * anyMessage ) {
123187 idStr := string (* msg .ID )
124188
@@ -136,6 +200,15 @@ func (c *Connection) handleResponse(msg *anyMessage) {
136200
137201func (c * Connection ) handleInbound (req * anyMessage ) {
138202 res := anyMessage {JSONRPC : "2.0" }
203+
204+ // Notifications are allowed a slightly longer-lived context during disconnect so we can
205+ // process already-received end-of-connection messages. Requests, however, should be
206+ // canceled promptly when the peer disconnects to avoid doing unnecessary work after
207+ // the caller is gone.
208+ ctx := c .ctx
209+ if req .ID == nil {
210+ ctx = c .inboundCtx
211+ }
139212 // copy ID if present
140213 if req .ID != nil {
141214 res .ID = req .ID
@@ -148,7 +221,7 @@ func (c *Connection) handleInbound(req *anyMessage) {
148221 return
149222 }
150223
151- result , err := c .handler (c . ctx , req .Method , req .Params )
224+ result , err := c .handler (ctx , req .Method , req .Params )
152225 if req .ID == nil {
153226 // Notification: no response is sent; log handler errors to surface decode failures.
154227 if err != nil {
0 commit comments