Skip to content

Commit 8b2cac4

Browse files
committed
Acquire lock before add pending. Unexport ProcessTimeNow in em.
1 parent 8d860da commit 8b2cac4

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ type ElementManager struct {
235235
pendingElements sync.WaitGroup // pendingElements counts all unprocessed elements in a job. Jobs with no pending elements terminate successfully.
236236

237237
processTimeEvents *stageRefreshQueue // Manages sequence of stage updates when interfacing with processing time. Callers must hold refreshCond.L lock.
238+
238239
testStreamHandler *testStreamHandler // Optional test stream handler when a test stream is in the pipeline.
239240
}
240241

@@ -398,7 +399,7 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context.
398399
for {
399400
em.refreshCond.L.Lock()
400401
// Check if processing time has advanced before the wait loop.
401-
emNow := em.ProcessingTimeNow()
402+
emNow := em.processingTimeNow()
402403
changedByProcessingTime := em.processTimeEvents.AdvanceTo(emNow)
403404
em.changedStages.merge(changedByProcessingTime)
404405

@@ -415,7 +416,7 @@ func (em *ElementManager) Bundles(ctx context.Context, upstreamCancelFn context.
415416
em.refreshCond.Wait() // until watermarks may have changed.
416417

417418
// Update if the processing time has advanced while we waited, and add refreshes here. (TODO waking on real time here for prod mode)
418-
emNow = em.ProcessingTimeNow()
419+
emNow = em.processingTimeNow()
419420
changedByProcessingTime = em.processTimeEvents.AdvanceTo(emNow)
420421
em.changedStages.merge(changedByProcessingTime)
421422
}
@@ -521,7 +522,7 @@ func (em *ElementManager) DumpStages() string {
521522
stageState = append(stageState, fmt.Sprintf("TestStreamHandler: completed %v, curIndex %v of %v events: %+v, processingTime %v, %v, ptEvents %v \n",
522523
em.testStreamHandler.completed, em.testStreamHandler.nextEventIndex, len(em.testStreamHandler.events), em.testStreamHandler.events, em.testStreamHandler.processingTime, mtime.FromTime(em.testStreamHandler.processingTime), em.processTimeEvents))
523524
} else {
524-
stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v\n", em.ProcessingTimeNow(), em.processTimeEvents.events, em.injectedBundles))
525+
stageState = append(stageState, fmt.Sprintf("ElementManager Now: %v processingTimeEvents: %v injectedBundles: %v\n", em.processingTimeNow(), em.processTimeEvents.events, em.injectedBundles))
525526
}
526527
sort.Strings(ids)
527528
for _, id := range ids {
@@ -880,8 +881,23 @@ func (em *ElementManager) PersistBundle(rb RunBundle, col2Coders map[string]PCol
880881
slog.Int("newPending", len(newPending)), "consumers", consumers, "sideConsumers", sideConsumers,
881882
"pendingDelta", len(newPending)*len(consumers))
882883
for _, sID := range consumers {
884+
883885
consumer := em.stages[sID]
884-
count := consumer.AddPending(em, newPending)
886+
var count int
887+
_, isAggregateStage := consumer.kind.(*aggregateStageKind)
888+
if isAggregateStage {
889+
// While adding pending elements in aggregate stage, we may need to
890+
// access em.processTimeEvents to determine triggered bundles.
891+
// To avoid deadlocks, we acquire the em.refreshCond.L lock here before
892+
// AddPending is called.
893+
func() {
894+
em.refreshCond.L.Lock()
895+
defer em.refreshCond.L.Unlock()
896+
count = consumer.AddPending(em, newPending)
897+
}()
898+
} else {
899+
count = consumer.AddPending(em, newPending)
900+
}
885901
em.addPending(count)
886902
}
887903
for _, link := range sideConsumers {
@@ -993,7 +1009,7 @@ func (em *ElementManager) triageTimers(d TentativeData, inputInfo PColInfo, stag
9931009
win typex.Window
9941010
}
9951011
em.refreshCond.L.Lock()
996-
emNow := em.ProcessingTimeNow()
1012+
emNow := em.processingTimeNow()
9971013
em.refreshCond.L.Unlock()
9981014

9991015
var pendingEventTimers []element
@@ -1334,13 +1350,10 @@ func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window t
13341350
}
13351351
state := wv[key]
13361352
endOfWindowReached := window.MaxTimestamp() < ss.input
1337-
em.refreshCond.L.Lock()
1338-
emNow := em.ProcessingTimeNow()
1339-
em.refreshCond.L.Unlock()
13401353
ready := ss.strat.IsTriggerReady(triggerInput{
13411354
newElementCount: 1,
13421355
endOfWindowReached: endOfWindowReached,
1343-
emNow: emNow,
1356+
emNow: em.processingTimeNow(),
13441357
}, &state)
13451358

13461359
if ready {
@@ -1377,9 +1390,7 @@ func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window t
13771390
// TODO: how to deal with watermark holds for this implicit processing time timer
13781391
// ss.watermarkHolds.Add(timer.holdTimestamp, 1)
13791392
ss.processingTimeTimers.Persist(firingTime, timer, notYetHolds)
1380-
em.refreshCond.L.Lock()
13811393
em.processTimeEvents.Schedule(firingTime, ss.ID)
1382-
em.refreshCond.L.Unlock()
13831394
em.wakeUpAt(firingTime)
13841395
}
13851396
}
@@ -2444,8 +2455,8 @@ func (ss *stageState) bundleReady(em *ElementManager, emNow mtime.Time) (mtime.T
24442455
return upstreamW, ready, ptimeEventsReady, injectedReady
24452456
}
24462457

2447-
// ProcessingTimeNow gives the current processing time for the runner.
2448-
func (em *ElementManager) ProcessingTimeNow() (ret mtime.Time) {
2458+
// processingTimeNow gives the current processing time for the runner.
2459+
func (em *ElementManager) processingTimeNow() (ret mtime.Time) {
24492460
if em.testStreamHandler != nil && !em.testStreamHandler.completed {
24502461
return em.testStreamHandler.Now()
24512462
}

sdks/go/pkg/beam/runners/prism/internal/engine/teststream.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ func (ev tsProcessingTimeEvent) Execute(em *ElementManager) {
238238
}
239239

240240
// Add the refreshes now so our block prevention logic works.
241-
emNow := em.ProcessingTimeNow()
241+
emNow := em.processingTimeNow()
242242
toRefresh := em.processTimeEvents.AdvanceTo(emNow)
243243
em.changedStages.merge(toRefresh)
244244
}

0 commit comments

Comments
 (0)