Skip to content

Commit 1f2991f

Browse files
committed
Acquire the refreshCond lock before calling AddPending on aggregate stages. Unexport ProcessingTimeNow in ElementManager.
1 parent 8d860da commit 1f2991f

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ type ElementManager struct {
235235
pendingElements sync.WaitGroup // pendingElements counts all unprocessed elements in a job. Jobs with no pending elements terminate successfully.
236236

237237
processTimeEvents *stageRefreshQueue // Manages sequence of stage updates when interfacing with processing time. Callers must hold refreshCond.L lock.
238+
238239
testStreamHandler *testStreamHandler // Optional test stream handler when a test stream is in the pipeline.
239240
}
240241

@@ -398,7 +399,7 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context.
398399
for {
399400
em.refreshCond.L.Lock()
400401
// Check if processing time has advanced before the wait loop.
401-
emNow := em.ProcessingTimeNow()
402+
emNow := em.processingTimeNow()
402403
changedByProcessingTime := em.processTimeEvents.AdvanceTo(emNow)
403404
em.changedStages.merge(changedByProcessingTime)
404405

@@ -415,7 +416,7 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context.
415416
em.refreshCond.Wait() // until watermarks may have changed.
416417

417418
// Update if the processing time has advanced while we waited, and add refreshes here. (TODO waking on real time here for prod mode)
418-
emNow = em.ProcessingTimeNow()
419+
emNow = em.processingTimeNow()
419420
changedByProcessingTime = em.processTimeEvents.AdvanceTo(emNow)
420421
em.changedStages.merge(changedByProcessingTime)
421422
}
@@ -521,7 +522,7 @@ func (em *ElementManager) DumpStages() string {
521522
stageState = append(stageState, fmt.Sprintf("TestStreamHandler: completed %v, curIndex %v of %v events: %+v, processingTime %v, %v, ptEvents %v \n",
522523
em.testStreamHandler.completed, em.testStreamHandler.nextEventIndex, len(em.testStreamHandler.events), em.testStreamHandler.events, em.testStreamHandler.processingTime, mtime.FromTime(em.testStreamHandler.processingTime), em.processTimeEvents))
523524
} else {
524-
stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v\n", em.ProcessingTimeNow(), em.processTimeEvents.events, em.injectedBundles))
525+
stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v\n", em.processingTimeNow(), em.processTimeEvents.events, em.injectedBundles))
525526
}
526527
sort.Strings(ids)
527528
for _, id := range ids {
@@ -880,8 +881,23 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol
880881
slog.Int("newPending", len(newPending)), "consumers", consumers, "sideConsumers", sideConsumers,
881882
"pendingDelta", len(newPending)*len(consumers))
882883
for _, sID := range consumers {
884+
883885
consumer := em.stages[sID]
884-
count := consumer.AddPending(em, newPending)
886+
var count int
887+
_, isAggregateStage := consumer.kind.(*aggregateStageKind)
888+
if isAggregateStage {
889+
// While adding pending elements in aggregate stage, we may need to
890+
// access em.processTimeEvents to determine triggered bundles.
891+
// To avoid deadlocks, we acquire the em.refreshCond.L lock here before
892+
// AddPending is called.
893+
func() {
894+
em.refreshCond.L.Lock()
895+
defer em.refreshCond.L.Unlock()
896+
count = consumer.AddPending(em, newPending)
897+
}()
898+
} else {
899+
count = consumer.AddPending(em, newPending)
900+
}
885901
em.addPending(count)
886902
}
887903
for _, link := range sideConsumers {
@@ -993,7 +1009,7 @@ func (em *ElementManager) triageTimers(d TentativeData, inputInfo PColInfo, stag
9931009
win typex.Window
9941010
}
9951011
em.refreshCond.L.Lock()
996-
emNow := em.ProcessingTimeNow()
1012+
emNow := em.processingTimeNow()
9971013
em.refreshCond.L.Unlock()
9981014

9991015
var pendingEventTimers []element
@@ -1334,9 +1350,7 @@ func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window t
13341350
}
13351351
state := wv[key]
13361352
endOfWindowReached := window.MaxTimestamp() < ss.input
1337-
em.refreshCond.L.Lock()
1338-
emNow := em.ProcessingTimeNow()
1339-
em.refreshCond.L.Unlock()
1353+
emNow := em.processingTimeNow()
13401354
ready := ss.strat.IsTriggerReady(triggerInput{
13411355
newElementCount: 1,
13421356
endOfWindowReached: endOfWindowReached,
@@ -1377,9 +1391,7 @@ func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window t
13771391
// TODO: how to deal with watermark holds for this implicit processing time timer
13781392
// ss.watermarkHolds.Add(timer.holdTimestamp, 1)
13791393
ss.processingTimeTimers.Persist(firingTime, timer, notYetHolds)
1380-
em.refreshCond.L.Lock()
13811394
em.processTimeEvents.Schedule(firingTime, ss.ID)
1382-
em.refreshCond.L.Unlock()
13831395
em.wakeUpAt(firingTime)
13841396
}
13851397
}
@@ -2444,8 +2456,8 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T
24442456
return upstreamW, ready, ptimeEventsReady, injectedReady
24452457
}
24462458

2447-
// ProcessingTimeNow gives the current processing time for the runner.
2448-
func (em *ElementManager) ProcessingTimeNow() (ret mtime.Time) {
2459+
// processingTimeNow gives the current processing time for the runner.
2460+
func (em *ElementManager) processingTimeNow() (ret mtime.Time) {
24492461
if em.testStreamHandler != nil && !em.testStreamHandler.completed {
24502462
return em.testStreamHandler.Now()
24512463
}

sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ func (ev tsProcessingTimeEvent) Execute(em *ElementManager) {
238238
}
239239

240240
// Add the refreshes now so our block prevention logic works.
241-
emNow := em.ProcessingTimeNow()
241+
emNow := em.processingTimeNow()
242242
toRefresh := em.processTimeEvents.AdvanceTo(emNow)
243243
em.changedStages.merge(toRefresh)
244244
}

0 commit comments

Comments
 (0)