@@ -179,11 +179,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
179179 id : sessionID ,
180180 incoming : make (chan jsonrpc.Message , 10 ),
181181 done : make (chan struct {}),
182- outgoing : make (map [StreamID ][][]byte ),
183- signals : make (map [StreamID ]chan struct {}),
182+ streams : make (map [StreamID ]* stream ),
184183 requestStreams : make (map [jsonrpc.ID ]StreamID ),
185- streamRequests : make (map [StreamID ]map [jsonrpc.ID ]struct {}),
186184 }
185+ t .streams [0 ] = newStream (0 )
187186 if opts != nil {
188187 t .opts = * opts
189188 }
@@ -219,59 +218,66 @@ type StreamableServerTransport struct {
219218 // perform the accounting described below when incoming HTTP requests are
220219 // handled.
221220 //
222- // The accounting is complicated. It is tempting to merge some of the maps
223- // below, but they each have different lifecycles, as indicated by Lifecycle:
224- // comments.
225- //
226221 // TODO(rfindley): simplify.
227222
228- // outgoing is the collection of outgoing messages, keyed by the logical
229- // stream ID where they should be delivered.
223+ // streams holds the logical streams for this session, keyed by their ID.
224+ streams map [StreamID ]* stream
225+
226+ // requestStreams maps incoming requests to their logical stream ID.
230227 //
231- // streamID 0 is used for messages that don't correlate with an incoming
232- // request.
228+ // Lifecycle: requestStreams persists for the duration of the session.
233229 //
234- // Lifecycle: persists for the duration of the session.
235- outgoing map [StreamID ][][]byte
230+ // TODO(rfindley): clean up once requests are handled.
231+ requestStreams map [jsonrpc.ID ]StreamID
232+ }
233+
234+ // A stream is a single logical stream of SSE events within a server session.
235+ // A stream begins with a client request, or with a client GET that has
236+ // no Last-Event-ID header.
237+ // A stream ends only when its session ends; we cannot determine its end otherwise,
238+ // since a client may send a GET with a Last-Event-ID that references the stream
239+ // at any time.
240+ type stream struct {
241+ // id is the logical ID for the stream, unique within a session.
242+ // ID 0 is used for messages that don't correlate with an incoming request.
243+ id StreamID
236244
237- // signals maps a logical stream ID to a 1-buffered channel, owned by an
245+ // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
246+
247+ // outgoing is the list of outgoing messages, enqueued by server methods that
248+ // write notifications and responses, and dequeued by streamResponse.
249+ outgoing [][]byte
250+
251+ // signal is a 1-buffered channel, owned by an
238252 // incoming HTTP request, that signals that there are messages available to
239- // write into the HTTP response. Signals guarantees that at most one HTTP
253+ // write into the HTTP response. This guarantees that at most one HTTP
240254 // response can receive messages for a logical stream. After claiming
241255 // the stream, incoming requests should read from outgoing, to ensure
242256 // that no new messages are missed.
243257 //
244- // Lifecycle: signals persists for the duration of an HTTP POST or GET
258+ // Lifecycle: persists for the duration of an HTTP POST or GET
245259 // request for the given streamID.
246- signals map [ StreamID ] chan struct {}
260+ signal chan struct {}
247261
248- // requestStreams maps incoming requests to their logical stream ID.
249- //
250- // Lifecycle: requestStreams persists for the duration of the session.
262+ // streamRequests is the set of unanswered incoming RPCs for the stream.
251263 //
252- // TODO(rfindley): clean up once requests are handled.
253- requestStreams map [jsonrpc.ID ]StreamID
254-
255- // streamRequests tracks the set of unanswered incoming RPCs for each logical
256- // stream.
257- //
258- // When the server has responded to each request, the stream should be
259- // closed.
260- //
261- // Lifecycle: streamRequests values persist as until the requests have been
264+ // Lifecycle: requests values persist until the requests have been
262265 // replied to by the server. Notably, NOT until they are sent to an HTTP
263266 // response, as delivery is not guaranteed.
264- streamRequests map [ StreamID ] map [jsonrpc.ID ]struct {}
267+ requests map [jsonrpc.ID ]struct {}
265268}
266269
267- type StreamID int64
268-
269- // a streamableMsg is an SSE event with an index into its logical stream.
270- type streamableMsg struct {
271- idx int
272- event Event
270+ func newStream (id StreamID ) * stream {
271+ return & stream {
272+ id : id ,
273+ requests : make (map [jsonrpc.ID ]struct {}),
274+ }
273275}
274276
277+ // A StreamID identifies a stream of SSE events. It is unique within the stream's
278+ // [ServerSession].
279+ type StreamID int64
280+
275281// Connect implements the [Transport] interface.
276282//
277283// TODO(rfindley): Connect should return a new object.
@@ -334,16 +340,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
334340 }
335341
336342 t .mu .Lock ()
337- if _ , ok := t .signals [id ]; ok {
343+ stream , ok := t .streams [id ]
344+ if ! ok {
345+ http .Error (w , "unknown stream" , http .StatusBadRequest )
346+ t .mu .Unlock ()
347+ return
348+ }
349+ if stream .signal != nil {
338350 http .Error (w , "stream ID conflicts with ongoing stream" , http .StatusBadRequest )
339351 t .mu .Unlock ()
340352 return
341353 }
342- signal := make (chan struct {}, 1 )
343- t .signals [id ] = signal
354+ stream .signal = make (chan struct {}, 1 )
344355 t .mu .Unlock ()
345356
346- t .streamResponse (w , req , id , lastIdx , signal )
357+ t .streamResponse (stream , w , req , lastIdx )
347358}
348359
349360func (t * StreamableServerTransport ) servePOST (w http.ResponseWriter , req * http.Request ) {
@@ -375,17 +386,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
375386 }
376387
377388 // Update accounting for this request.
378- id := StreamID (t .nextStreamID .Add (1 ))
379- signal := make (chan struct {}, 1 )
389+ stream := newStream (StreamID (t .nextStreamID .Add (1 )))
380390 t .mu .Lock ()
391+ t .streams [stream .id ] = stream
381392 if len (requests ) > 0 {
382- t . streamRequests [ id ] = make (map [jsonrpc.ID ]struct {})
393+ stream . requests = make (map [jsonrpc.ID ]struct {})
383394 }
384395 for reqID := range requests {
385- t .requestStreams [reqID ] = id
386- t. streamRequests [ id ] [reqID ] = struct {}{}
396+ t .requestStreams [reqID ] = stream . id
397+ stream . requests [reqID ] = struct {}{}
387398 }
388- t . signals [ id ] = signal
399+ stream . signal = make ( chan struct {}, 1 )
389400 t .mu .Unlock ()
390401
391402 // Publish incoming messages.
@@ -396,29 +407,37 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
396407 // TODO(rfindley): consider optimizing for a single incoming request, by
397408 // responding with application/json when there is only a single message in
398409 // the response.
399- t .streamResponse (w , req , id , - 1 , signal )
410+ t .streamResponse (stream , w , req , - 1 )
400411}
401412
402413// lastIndex is the index of the last seen event if resuming, else -1.
403- func (t * StreamableServerTransport ) streamResponse (w http.ResponseWriter , req * http.Request , id StreamID , lastIndex int , signal chan struct {} ) {
414+ func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) {
404415 defer func () {
405416 t .mu .Lock ()
406- delete ( t . signals , id )
417+ stream . signal = nil
407418 t .mu .Unlock ()
408419 }()
409420
421+ t .mu .Lock ()
422+ // Although there is a gap in locking between when stream.signal is set and here,
423+ // it cannot change, because it is changed only when non-nil, and it is only
424+ // set to nil in the defer above.
425+ signal := stream .signal
426+ t .mu .Unlock ()
427+
410428 writes := 0
411429
412430 // write one event containing data.
413431 write := func (data []byte ) bool {
414432 lastIndex ++
415433 e := Event {
416434 Name : "message" ,
417- ID : formatEventID (id , lastIndex ),
435+ ID : formatEventID (stream . id , lastIndex ),
418436 Data : data ,
419437 }
420438 if _ , err := writeEvent (w , e ); err != nil {
421439 // Connection closed or broken.
440+ // TODO: log when we add server-side logging.
422441 return false
423442 }
424443 writes ++
@@ -432,7 +451,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
432451
433452 if lastIndex >= 0 {
434453 // Resume.
435- for data , err := range t .opts .EventStore .After (req .Context (), t .SessionID (), id , lastIndex ) {
454+ for data , err := range t .opts .EventStore .After (req .Context (), t .SessionID (), stream . id , lastIndex ) {
436455 if err != nil {
437456 // TODO: reevaluate these status codes.
438457 // Maybe distinguish between storage errors, which are 500s, and missing
@@ -456,12 +475,12 @@ stream:
456475 // Repeatedly collect pending outgoing events and send them.
457476 for {
458477 t .mu .Lock ()
459- outgoing := t .outgoing [ id ]
460- t .outgoing [ id ] = nil
478+ outgoing := stream .outgoing
479+ stream .outgoing = nil
461480 t .mu .Unlock ()
462481
463482 for _ , data := range outgoing {
464- if err := t .opts .EventStore .Append (req .Context (), t .id , id , data ); err != nil {
483+ if err := t .opts .EventStore .Append (req .Context (), t .SessionID (), stream . id , data ); err != nil {
465484 http .Error (w , err .Error (), http .StatusInternalServerError )
466485 return
467486 }
@@ -471,7 +490,7 @@ stream:
471490 }
472491
473492 t .mu .Lock ()
474- nOutstanding := len (t . streamRequests [ id ] )
493+ nOutstanding := len (stream . requests )
475494 t .mu .Unlock ()
476495 // If all requests have been handled and replied to, we should terminate this connection.
477496 // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
@@ -585,30 +604,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
585604 t .mu .Lock ()
586605 defer t .mu .Unlock ()
587606 if t .isDone {
588- return fmt . Errorf ("session is closed" ) // TODO: should this be EOF?
607+ return errors . New ("session is closed" )
589608 }
590609
591- if _ , ok := t .streamRequests [forConn ]; ! ok && forConn != 0 {
610+ stream := t .streams [forConn ]
611+ if stream == nil {
612+ return fmt .Errorf ("no stream with ID %d" , forConn )
613+ }
614+ if len (stream .requests ) == 0 && forConn != 0 {
592615 // No outstanding requests for this connection, which means it is logically
593616 // done. This is a sequencing violation from the server, so we should report
594617 // a side-channel error here. Put the message on the general queue to avoid
595618 // dropping messages.
596- forConn = 0
619+ stream = t . streams [ 0 ]
597620 }
598621
599- t .outgoing [ forConn ] = append (t .outgoing [ forConn ] , data )
622+ stream .outgoing = append (stream .outgoing , data )
600623 if replyTo .IsValid () {
601624 // Once we've put the reply on the queue, it's no longer outstanding.
602- delete (t .streamRequests [forConn ], replyTo )
603- if len (t .streamRequests [forConn ]) == 0 {
604- delete (t .streamRequests , forConn )
605- }
625+ delete (stream .requests , replyTo )
606626 }
607627
608628 // Signal work.
609- if c , ok := t . signals [ forConn ]; ok {
629+ if stream . signal != nil {
610630 select {
611- case c <- struct {}{}:
631+ case stream . signal <- struct {}{}:
612632 default :
613633 }
614634 }
0 commit comments