Skip to content

Commit f7cc200

Browse files
fix events subscriber to close appropriately to fix flaky synctest
1 parent 5aee21a commit f7cc200

File tree

1 file changed

+68
-37
lines changed

1 file changed

+68
-37
lines changed

byoc/trickle.go

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,13 @@ func (bsg *BYOCGatewayServer) startControlPublish(ctx context.Context, control *
568568

569569
const clearStreamDelay = 1 * time.Minute
570570

571-
func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url.URL, params byocAIRequestParams, orchAddr string, orchUrl string) {
571+
func (bsg *BYOCGatewayServer) startEventsSubscribe(
572+
ctx context.Context,
573+
url *url.URL,
574+
params byocAIRequestParams,
575+
orchAddr string,
576+
orchUrl string,
577+
) {
572578
subscriber, err := trickle.NewTrickleSubscriber(trickle.TrickleSubscriberConfig{
573579
URL: url.String(),
574580
Ctx: ctx,
@@ -577,67 +583,90 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
577583
stopProcessing(ctx, params, fmt.Errorf("event sub init failed: %w", err))
578584
return
579585
}
586+
580587
streamId := params.liveParams.streamID
581588

582-
// vars to check events periodically to ensure liveness
583-
var (
589+
const (
584590
eventCheckInterval = 10 * time.Second
585591
maxEventGap = 30 * time.Second
586-
eventTicker = time.NewTicker(eventCheckInterval)
587-
eventsDone = make(chan bool)
588-
// remaining vars in this block must be protected by mutex
589-
lastEventMu = &sync.Mutex{}
592+
maxRetries = 5
593+
retryPause = 300 * time.Millisecond
594+
)
595+
596+
eventTicker := time.NewTicker(eventCheckInterval)
597+
eventsDone := make(chan struct{}, 1)
598+
599+
var (
600+
lastEventMu sync.Mutex
590601
lastEvent = time.Now()
591602
)
592603

593604
clog.Infof(ctx, "Starting event subscription for URL: %s", url.String())
594605

606+
// Clear stream state after delay unless canceled
595607
go func() {
596-
defer time.AfterFunc(clearStreamDelay, func() {
608+
select {
609+
case <-time.After(clearStreamDelay):
597610
bsg.statusStore.Clear(streamId)
598-
})
611+
case <-ctx.Done():
612+
}
613+
}()
614+
615+
// Event reader goroutine
616+
go func() {
599617
defer func() {
600618
eventTicker.Stop()
601-
eventsDone <- true
619+
select {
620+
case eventsDone <- struct{}{}:
621+
default:
622+
}
602623
}()
603-
const maxRetries = 5
604-
const retryPause = 300 * time.Millisecond
624+
605625
retries := 0
626+
606627
for {
607628
select {
608629
case <-ctx.Done():
609630
clog.Info(ctx, "event subscription done")
610631
return
611632
default:
612633
}
634+
613635
clog.Infof(ctx, "Reading from event subscription for URL: %s", url.String())
614636
segment, err := subscriber.Read()
615-
if err == nil {
616-
retries = 0
617-
} else {
618-
// handle errors from event read
637+
if err != nil {
619638
if errors.Is(err, trickle.EOS) || errors.Is(err, trickle.StreamNotFoundErr) {
620639
clog.Infof(ctx, "Stopping subscription due to %s", err)
621640
return
622641
}
642+
623643
var seqErr *trickle.SequenceNonexistent
624644
if errors.As(err, &seqErr) {
625-
// stream exists but segment doesn't, so skip to leading edge
626645
subscriber.SetSeq(seqErr.Latest)
627646
}
628-
if retries > maxRetries {
629-
stopProcessing(ctx, params, fmt.Errorf("too many errors reading events; stopping subscription, err=%w", err))
647+
648+
if retries >= maxRetries {
649+
stopProcessing(ctx, params, fmt.Errorf(
650+
"too many errors reading events; stopping subscription: %w", err,
651+
))
630652
return
631653
}
654+
632655
clog.Infof(ctx, "Error reading events subscription: err=%v retry=%d", err, retries)
633656
retries++
634-
time.Sleep(retryPause)
657+
658+
select {
659+
case <-time.After(retryPause):
660+
case <-ctx.Done():
661+
return
662+
}
635663
continue
636664
}
637665

666+
retries = 0
667+
638668
body, err := io.ReadAll(segment.Body)
639669
segment.Body.Close()
640-
641670
if err != nil {
642671
clog.Infof(ctx, "Error reading events subscription body: %s", err)
643672
continue
@@ -655,7 +684,6 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
655684
event := eventWrapper.Event
656685
queueEventType := eventWrapper.QueueEventType
657686
if event == nil {
658-
// revert this once push to prod -- If no "event" field found, treat the entire body as the event
659687
event = make(map[string]interface{})
660688
if err := json.Unmarshal(body, &event); err != nil {
661689
clog.Infof(ctx, "Failed to parse JSON as direct event: %s", err)
@@ -672,38 +700,40 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
672700
"url": orchUrl,
673701
}
674702

675-
clog.V(8).Infof(ctx, "Received event for seq=%d event=%+v", trickle.GetSeq(segment), event)
703+
clog.V(8).Infof(ctx, "Received event for seq=%d event=%+v",
704+
trickle.GetSeq(segment), event,
705+
)
676706

677-
// record the event time
678707
lastEventMu.Lock()
679708
lastEvent = time.Now()
680709
lastEventMu.Unlock()
681710

682711
eventType, ok := event["type"].(string)
683712
if !ok {
684713
eventType = "unknown"
685-
clog.Warningf(ctx, "Received event without a type stream=%s event=%+v", streamId, event)
714+
clog.Warningf(ctx, "Received event without a type stream=%s event=%+v",
715+
streamId, event,
716+
)
686717
}
687718

688719
if eventType == "status" {
689720
queueEventType = "ai_stream_status"
690-
// The large logs and params fields are only sent once and then cleared to save bandwidth. So coalesce the
691-
// incoming status with the last non-null value that we received on such fields for the status API.
692-
lastStreamStatus, _ := bsg.statusStore.Get(streamId)
693721

694-
// Check if inference_status exists in both current and last status
722+
lastStreamStatus, _ := bsg.statusStore.Get(streamId)
695723
inferenceStatus, hasInference := event["inference_status"].(map[string]interface{})
696724
lastInferenceStatus, hasLastInference := lastStreamStatus["inference_status"].(map[string]interface{})
697725

698726
if hasInference {
699727
if logs, ok := inferenceStatus["last_restart_logs"]; !ok || logs == nil {
700728
if hasLastInference {
701-
inferenceStatus["last_restart_logs"] = lastInferenceStatus["last_restart_logs"]
729+
inferenceStatus["last_restart_logs"] =
730+
lastInferenceStatus["last_restart_logs"]
702731
}
703732
}
704733
if params, ok := inferenceStatus["last_params"]; !ok || params == nil {
705734
if hasLastInference {
706-
inferenceStatus["last_params"] = lastInferenceStatus["last_params"]
735+
inferenceStatus["last_params"] =
736+
lastInferenceStatus["last_params"]
707737
}
708738
}
709739
}
@@ -715,22 +745,23 @@ func (bsg *BYOCGatewayServer) startEventsSubscribe(ctx context.Context, url *url
715745
}
716746
}()
717747

718-
// Use events as a heartbeat of sorts:
719-
// if no events arrive for too long, abort the job
748+
// Heartbeat watchdog
720749
go func() {
721750
for {
722751
select {
752+
case <-ctx.Done():
753+
return
754+
case <-eventsDone:
755+
return
723756
case <-eventTicker.C:
724757
lastEventMu.Lock()
725758
eventTime := lastEvent
726759
lastEventMu.Unlock()
727-
if time.Now().Sub(eventTime) > maxEventGap {
760+
761+
if time.Since(eventTime) > maxEventGap {
728762
stopProcessing(ctx, params, fmt.Errorf("timeout waiting for events"))
729-
eventTicker.Stop()
730763
return
731764
}
732-
case <-eventsDone:
733-
return
734765
}
735766
}
736767
}()

0 commit comments

Comments
 (0)