@@ -568,7 +568,13 @@ func (bsg *BYOCGatewayServer) startControlPublish(ctx context.Context, control *
568568
569569const clearStreamDelay = 1 * time .Minute
570570
571- func (bsg * BYOCGatewayServer ) startEventsSubscribe (ctx context.Context , url * url.URL , params byocAIRequestParams , orchAddr string , orchUrl string ) {
571+ func (bsg * BYOCGatewayServer ) startEventsSubscribe (
572+ ctx context.Context ,
573+ url * url.URL ,
574+ params byocAIRequestParams ,
575+ orchAddr string ,
576+ orchUrl string ,
577+ ) {
572578 subscriber , err := trickle .NewTrickleSubscriber (trickle.TrickleSubscriberConfig {
573579 URL : url .String (),
574580 Ctx : ctx ,
@@ -577,67 +583,90 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
577583 stopProcessing (ctx , params , fmt .Errorf ("event sub init failed: %w" , err ))
578584 return
579585 }
586+
580587 streamId := params .liveParams .streamID
581588
582- // vars to check events periodically to ensure liveness
583- var (
589+ const (
584590 eventCheckInterval = 10 * time .Second
585591 maxEventGap = 30 * time .Second
586- eventTicker = time .NewTicker (eventCheckInterval )
587- eventsDone = make (chan bool )
588- // remaining vars in this block must be protected by mutex
589- lastEventMu = & sync.Mutex {}
592+ maxRetries = 5
593+ retryPause = 300 * time .Millisecond
594+ )
595+
596+ eventTicker := time .NewTicker (eventCheckInterval )
597+ eventsDone := make (chan struct {}, 1 )
598+
599+ var (
600+ lastEventMu sync.Mutex
590601 lastEvent = time .Now ()
591602 )
592603
593604 clog .Infof (ctx , "Starting event subscription for URL: %s" , url .String ())
594605
606+ // Clear stream state after delay unless canceled
595607 go func () {
596- defer time .AfterFunc (clearStreamDelay , func () {
608+ select {
609+ case <- time .After (clearStreamDelay ):
597610 bsg .statusStore .Clear (streamId )
598- })
611+ case <- ctx .Done ():
612+ }
613+ }()
614+
615+ // Event reader goroutine
616+ go func () {
599617 defer func () {
600618 eventTicker .Stop ()
601- eventsDone <- true
619+ select {
620+ case eventsDone <- struct {}{}:
621+ default :
622+ }
602623 }()
603- const maxRetries = 5
604- const retryPause = 300 * time .Millisecond
624+
605625 retries := 0
626+
606627 for {
607628 select {
608629 case <- ctx .Done ():
609630 clog .Info (ctx , "event subscription done" )
610631 return
611632 default :
612633 }
634+
613635 clog .Infof (ctx , "Reading from event subscription for URL: %s" , url .String ())
614636 segment , err := subscriber .Read ()
615- if err == nil {
616- retries = 0
617- } else {
618- // handle errors from event read
637+ if err != nil {
619638 if errors .Is (err , trickle .EOS ) || errors .Is (err , trickle .StreamNotFoundErr ) {
620639 clog .Infof (ctx , "Stopping subscription due to %s" , err )
621640 return
622641 }
642+
623643 var seqErr * trickle.SequenceNonexistent
624644 if errors .As (err , & seqErr ) {
625- // stream exists but segment doesn't, so skip to leading edge
626645 subscriber .SetSeq (seqErr .Latest )
627646 }
628- if retries > maxRetries {
629- stopProcessing (ctx , params , fmt .Errorf ("too many errors reading events; stopping subscription, err=%w" , err ))
647+
648+ if retries >= maxRetries {
649+ stopProcessing (ctx , params , fmt .Errorf (
650+ "too many errors reading events; stopping subscription: %w" , err ,
651+ ))
630652 return
631653 }
654+
632655 clog .Infof (ctx , "Error reading events subscription: err=%v retry=%d" , err , retries )
633656 retries ++
634- time .Sleep (retryPause )
657+
658+ select {
659+ case <- time .After (retryPause ):
660+ case <- ctx .Done ():
661+ return
662+ }
635663 continue
636664 }
637665
666+ retries = 0
667+
638668 body , err := io .ReadAll (segment .Body )
639669 segment .Body .Close ()
640-
641670 if err != nil {
642671 clog .Infof (ctx , "Error reading events subscription body: %s" , err )
643672 continue
@@ -655,7 +684,6 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
655684 event := eventWrapper .Event
656685 queueEventType := eventWrapper .QueueEventType
657686 if event == nil {
658- // revert this once push to prod -- If no "event" field found, treat the entire body as the event
659687 event = make (map [string ]interface {})
660688 if err := json .Unmarshal (body , & event ); err != nil {
661689 clog .Infof (ctx , "Failed to parse JSON as direct event: %s" , err )
@@ -672,38 +700,40 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
672700 "url" : orchUrl ,
673701 }
674702
675- clog .V (8 ).Infof (ctx , "Received event for seq=%d event=%+v" , trickle .GetSeq (segment ), event )
703+ clog .V (8 ).Infof (ctx , "Received event for seq=%d event=%+v" ,
704+ trickle .GetSeq (segment ), event ,
705+ )
676706
677- // record the event time
678707 lastEventMu .Lock ()
679708 lastEvent = time .Now ()
680709 lastEventMu .Unlock ()
681710
682711 eventType , ok := event ["type" ].(string )
683712 if ! ok {
684713 eventType = "unknown"
685- clog .Warningf (ctx , "Received event without a type stream=%s event=%+v" , streamId , event )
714+ clog .Warningf (ctx , "Received event without a type stream=%s event=%+v" ,
715+ streamId , event ,
716+ )
686717 }
687718
688719 if eventType == "status" {
689720 queueEventType = "ai_stream_status"
690- // The large logs and params fields are only sent once and then cleared to save bandwidth. So coalesce the
691- // incoming status with the last non-null value that we received on such fields for the status API.
692- lastStreamStatus , _ := bsg .statusStore .Get (streamId )
693721
694- // Check if inference_status exists in both current and last status
722+ lastStreamStatus , _ := bsg . statusStore . Get ( streamId )
695723 inferenceStatus , hasInference := event ["inference_status" ].(map [string ]interface {})
696724 lastInferenceStatus , hasLastInference := lastStreamStatus ["inference_status" ].(map [string ]interface {})
697725
698726 if hasInference {
699727 if logs , ok := inferenceStatus ["last_restart_logs" ]; ! ok || logs == nil {
700728 if hasLastInference {
701- inferenceStatus ["last_restart_logs" ] = lastInferenceStatus ["last_restart_logs" ]
729+ inferenceStatus ["last_restart_logs" ] =
730+ lastInferenceStatus ["last_restart_logs" ]
702731 }
703732 }
704733 if params , ok := inferenceStatus ["last_params" ]; ! ok || params == nil {
705734 if hasLastInference {
706- inferenceStatus ["last_params" ] = lastInferenceStatus ["last_params" ]
735+ inferenceStatus ["last_params" ] =
736+ lastInferenceStatus ["last_params" ]
707737 }
708738 }
709739 }
@@ -715,22 +745,23 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
715745 }
716746 }()
717747
718- // Use events as a heartbeat of sorts:
719- // if no events arrive for too long, abort the job
748+ // Heartbeat watchdog
720749 go func () {
721750 for {
722751 select {
752+ case <- ctx .Done ():
753+ return
754+ case <- eventsDone :
755+ return
723756 case <- eventTicker .C :
724757 lastEventMu .Lock ()
725758 eventTime := lastEvent
726759 lastEventMu .Unlock ()
727- if time .Now ().Sub (eventTime ) > maxEventGap {
760+
761+ if time .Since (eventTime ) > maxEventGap {
728762 stopProcessing (ctx , params , fmt .Errorf ("timeout waiting for events" ))
729- eventTicker .Stop ()
730763 return
731764 }
732- case <- eventsDone :
733- return
734765 }
735766 }
736767 }()
0 commit comments