@@ -173,11 +173,10 @@ func NewStreamableServerTransport(sessionID string, opts *StreamableServerTransp
173173 id : sessionID ,
174174 incoming : make (chan jsonrpc.Message , 10 ),
175175 done : make (chan struct {}),
176- outgoing : make (map [StreamID ][][]byte ),
177- signals : make (map [StreamID ]chan struct {}),
176+ streams : make (map [StreamID ]* stream ),
178177 requestStreams : make (map [jsonrpc.ID ]StreamID ),
179- streamRequests : make (map [StreamID ]map [jsonrpc.ID ]struct {}),
180178 }
179+ t .streams [0 ] = newStream (0 )
181180 if opts != nil {
182181 t .opts = * opts
183182 }
@@ -213,59 +212,66 @@ type StreamableServerTransport struct {
213212 // perform the accounting described below when incoming HTTP requests are
214213 // handled.
215214 //
216- // The accounting is complicated. It is tempting to merge some of the maps
217- // below, but they each have different lifecycles, as indicated by Lifecycle:
218- // comments.
219- //
220215 // TODO(rfindley): simplify.
221216
222- // outgoing is the collection of outgoing messages, keyed by the logical
223- // stream ID where they should be delivered.
217+ // streams holds the logical streams for this session, keyed by their ID.
218+ streams map [StreamID ]* stream
219+
220+ // requestStreams maps incoming requests to their logical stream ID.
224221 //
225- // streamID 0 is used for messages that don't correlate with an incoming
226- // request.
222+ // Lifecycle: requestStreams persists for the duration of the session.
227223 //
228- // Lifecycle: persists for the duration of the session.
229- outgoing map [StreamID ][][]byte
224+ // TODO(rfindley): clean up once requests are handled.
225+ requestStreams map [jsonrpc.ID ]StreamID
226+ }
227+
228+ // A stream is a single logical stream of SSE events within a server session.
229+ // A stream begins with a client request, or with a client GET that has
230+ // no Last-Event-ID header.
231+ // A stream ends only when its session ends; we cannot determine its end otherwise,
232+ // since a client may send a GET with a Last-Event-ID that references the stream
233+ // at any time.
234+ type stream struct {
235+ // id is the logical ID for the stream, unique within a session.
236+ // ID 0 is used for messages that don't correlate with an incoming request.
237+ id StreamID
230238
231- // signals maps a logical stream ID to a 1-buffered channel, owned by an
239+ // These mutable fields are protected by the mutex of the corresponding StreamableServerTransport.
240+
241+ // outgoing is the list of outgoing messages, enqueued by server methods that
242+ // write notifications and responses, and dequeued by streamResponse.
243+ outgoing [][]byte
244+
245+ // signal is a 1-buffered channel, owned by an
232246 // incoming HTTP request, that signals that there are messages available to
233- // write into the HTTP response. Signals guarantees that at most one HTTP
247+ // write into the HTTP response. This guarantees that at most one HTTP
234248 // response can receive messages for a logical stream. After claiming
235249 // the stream, incoming requests should read from outgoing, to ensure
236250 // that no new messages are missed.
237251 //
238- // Lifecycle: signals persists for the duration of an HTTP POST or GET
252+ // Lifecycle: persists for the duration of an HTTP POST or GET
239253 // request for the given streamID.
240- signals map [ StreamID ] chan struct {}
254+ signal chan struct {}
241255
242- // requestStreams maps incoming requests to their logical stream ID.
243- //
244- // Lifecycle: requestStreams persists for the duration of the session.
256+ // streamRequests is the set of unanswered incoming RPCs for the stream.
245257 //
246- // TODO(rfindley): clean up once requests are handled.
247- requestStreams map [jsonrpc.ID ]StreamID
248-
249- // streamRequests tracks the set of unanswered incoming RPCs for each logical
250- // stream.
251- //
252- // When the server has responded to each request, the stream should be
253- // closed.
254- //
255- // Lifecycle: streamRequests values persist as until the requests have been
258+ // Lifecycle: requests values persist as until the requests have been
256259 // replied to by the server. Notably, NOT until they are sent to an HTTP
257260 // response, as delivery is not guaranteed.
258- streamRequests map [ StreamID ] map [jsonrpc.ID ]struct {}
261+ requests map [jsonrpc.ID ]struct {}
259262}
260263
261- type StreamID int64
262-
263- // a streamableMsg is an SSE event with an index into its logical stream.
264- type streamableMsg struct {
265- idx int
266- event Event
264+ func newStream (id StreamID ) * stream {
265+ return & stream {
266+ id : id ,
267+ requests : make (map [jsonrpc.ID ]struct {}),
268+ }
267269}
268270
271+ // A StreamID identifies a stream of SSE events. It is unique within the stream's
272+ // [ServerSession].
273+ type StreamID int64
274+
269275// Connect implements the [Transport] interface.
270276//
271277// TODO(rfindley): Connect should return a new object.
@@ -328,16 +334,21 @@ func (t *StreamableServerTransport) serveGET(w http.ResponseWriter, req *http.Re
328334 }
329335
330336 t .mu .Lock ()
331- if _ , ok := t .signals [id ]; ok {
337+ stream , ok := t .streams [id ]
338+ if ! ok {
339+ http .Error (w , "unknown stream" , http .StatusBadRequest )
340+ t .mu .Unlock ()
341+ return
342+ }
343+ if stream .signal != nil {
332344 http .Error (w , "stream ID conflicts with ongoing stream" , http .StatusBadRequest )
333345 t .mu .Unlock ()
334346 return
335347 }
336- signal := make (chan struct {}, 1 )
337- t .signals [id ] = signal
348+ stream .signal = make (chan struct {}, 1 )
338349 t .mu .Unlock ()
339350
340- t .streamResponse (w , req , id , lastIdx , signal )
351+ t .streamResponse (stream , w , req , lastIdx )
341352}
342353
343354func (t * StreamableServerTransport ) servePOST (w http.ResponseWriter , req * http.Request ) {
@@ -369,17 +380,17 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
369380 }
370381
371382 // Update accounting for this request.
372- id := StreamID (t .nextStreamID .Add (1 ))
373- signal := make (chan struct {}, 1 )
383+ stream := newStream (StreamID (t .nextStreamID .Add (1 )))
374384 t .mu .Lock ()
385+ t .streams [stream .id ] = stream
375386 if len (requests ) > 0 {
376- t . streamRequests [ id ] = make (map [jsonrpc.ID ]struct {})
387+ stream . requests = make (map [jsonrpc.ID ]struct {})
377388 }
378389 for reqID := range requests {
379- t .requestStreams [reqID ] = id
380- t. streamRequests [ id ] [reqID ] = struct {}{}
390+ t .requestStreams [reqID ] = stream . id
391+ stream . requests [reqID ] = struct {}{}
381392 }
382- t . signals [ id ] = signal
393+ stream . signal = make ( chan struct {}, 1 )
383394 t .mu .Unlock ()
384395
385396 // Publish incoming messages.
@@ -390,29 +401,37 @@ func (t *StreamableServerTransport) servePOST(w http.ResponseWriter, req *http.R
390401 // TODO(rfindley): consider optimizing for a single incoming request, by
391402 // responding with application/json when there is only a single message in
392403 // the response.
393- t .streamResponse (w , req , id , - 1 , signal )
404+ t .streamResponse (stream , w , req , - 1 )
394405}
395406
396407// lastIndex is the index of the last seen event if resuming, else -1.
397- func (t * StreamableServerTransport ) streamResponse (w http.ResponseWriter , req * http.Request , id StreamID , lastIndex int , signal chan struct {} ) {
408+ func (t * StreamableServerTransport ) streamResponse (stream * stream , w http.ResponseWriter , req * http.Request , lastIndex int ) {
398409 defer func () {
399410 t .mu .Lock ()
400- delete ( t . signals , id )
411+ stream . signal = nil
401412 t .mu .Unlock ()
402413 }()
403414
415+ t .mu .Lock ()
416+ // Although there is a gap in locking between when stream.signal is set and here,
417+ // it cannot change, because it is changed only when non-nil, and it is only
418+ // set to nil in the defer above.
419+ signal := stream .signal
420+ t .mu .Unlock ()
421+
404422 writes := 0
405423
406424 // write one event containing data.
407425 write := func (data []byte ) bool {
408426 lastIndex ++
409427 e := Event {
410428 Name : "message" ,
411- ID : formatEventID (id , lastIndex ),
429+ ID : formatEventID (stream . id , lastIndex ),
412430 Data : data ,
413431 }
414432 if _ , err := writeEvent (w , e ); err != nil {
415433 // Connection closed or broken.
434+ // TODO: log when we add server-side logging.
416435 return false
417436 }
418437 writes ++
@@ -426,7 +445,7 @@ func (t *StreamableServerTransport) streamResponse(w http.ResponseWriter, req *h
426445
427446 if lastIndex >= 0 {
428447 // Resume.
429- for data , err := range t .opts .EventStore .After (req .Context (), t .SessionID (), id , lastIndex ) {
448+ for data , err := range t .opts .EventStore .After (req .Context (), t .SessionID (), stream . id , lastIndex ) {
430449 if err != nil {
431450 // TODO: reevaluate these status codes.
432451 // Maybe distinguish between storage errors, which are 500s, and missing
@@ -450,12 +469,12 @@ stream:
450469 // Repeatedly collect pending outgoing events and send them.
451470 for {
452471 t .mu .Lock ()
453- outgoing := t .outgoing [ id ]
454- t .outgoing [ id ] = nil
472+ outgoing := stream .outgoing
473+ stream .outgoing = nil
455474 t .mu .Unlock ()
456475
457476 for _ , data := range outgoing {
458- if err := t .opts .EventStore .Append (req .Context (), t .id , id , data ); err != nil {
477+ if err := t .opts .EventStore .Append (req .Context (), t .SessionID (), stream . id , data ); err != nil {
459478 http .Error (w , err .Error (), http .StatusInternalServerError )
460479 return
461480 }
@@ -465,7 +484,7 @@ stream:
465484 }
466485
467486 t .mu .Lock ()
468- nOutstanding := len (t . streamRequests [ id ] )
487+ nOutstanding := len (stream . requests )
469488 t .mu .Unlock ()
470489 // If all requests have been handled and replied to, we should terminate this connection.
471490 // "After the JSON-RPC response has been sent, the server SHOULD close the SSE stream."
@@ -579,30 +598,31 @@ func (t *StreamableServerTransport) Write(ctx context.Context, msg jsonrpc.Messa
579598 t .mu .Lock ()
580599 defer t .mu .Unlock ()
581600 if t .isDone {
582- return fmt . Errorf ("session is closed" ) // TODO: should this be EOF?
601+ return errors . New ("session is closed" ) // TODO: should this be EOF?
583602 }
584603
585- if _ , ok := t .streamRequests [forConn ]; ! ok && forConn != 0 {
604+ stream := t .streams [forConn ]
605+ if stream == nil {
606+ return fmt .Errorf ("no stream with ID %d" , forConn )
607+ }
608+ if len (stream .requests ) == 0 && forConn != 0 {
586609 // No outstanding requests for this connection, which means it is logically
587610 // done. This is a sequencing violation from the server, so we should report
588611 // a side-channel error here. Put the message on the general queue to avoid
589612 // dropping messages.
590- forConn = 0
613+ stream = t . streams [ 0 ]
591614 }
592615
593- t .outgoing [ forConn ] = append (t .outgoing [ forConn ] , data )
616+ stream .outgoing = append (stream .outgoing , data )
594617 if replyTo .IsValid () {
595618 // Once we've put the reply on the queue, it's no longer outstanding.
596- delete (t .streamRequests [forConn ], replyTo )
597- if len (t .streamRequests [forConn ]) == 0 {
598- delete (t .streamRequests , forConn )
599- }
619+ delete (stream .requests , replyTo )
600620 }
601621
602622 // Signal work.
603- if c , ok := t . signals [ forConn ]; ok {
623+ if stream . signal != nil {
604624 select {
605- case c <- struct {}{}:
625+ case stream . signal <- struct {}{}:
606626 default :
607627 }
608628 }
0 commit comments