Skip to content

Commit 990b5ff

Browse files
authored
[Prism] Support injecting triggered bundle for a batch of elements. (#36219)
* Support injecting triggered bundle for a batch of elements. * Override streaming mode if there is an unbounded pcollection. * Refactor some code. * Enable prism on failed pipelines and rebench. * Add tests for streaming and batch mode on data trigger for prism. * Revert "Enable prism on failed pipelines and rebench." This reverts commit bc648d5. * Fix the newly added tests.
1 parent 6ec4678 commit 990b5ff

File tree

3 files changed

+143
-35
lines changed

3 files changed

+143
-35
lines changed

sdks/go/pkg/beam/runners/prism/internal/engine/elementmanager.go

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,8 @@ type Config struct {
184184
MaxBundleSize int
185185
// Whether to use real-time clock as processing time
186186
EnableRTC bool
187+
// Whether to process the data in a streaming mode
188+
StreamingMode bool
187189
}
188190

189191
// ElementManager handles elements, watermarks, and related errata to determine
@@ -1296,6 +1298,43 @@ func (ss *stageState) AddPending(em *ElementManager, newPending []element) int {
12961298
return ss.kind.addPending(ss, em, newPending)
12971299
}
12981300

1301+
func (ss *stageState) injectTriggeredBundlesIfReady(em *ElementManager, window typex.Window, key string) int {
1302+
// Check on triggers for this key.
1303+
// We use an empty linkID as the key into state for aggregations.
1304+
count := 0
1305+
if ss.state == nil {
1306+
ss.state = make(map[LinkID]map[typex.Window]map[string]StateData)
1307+
}
1308+
lv, ok := ss.state[LinkID{}]
1309+
if !ok {
1310+
lv = make(map[typex.Window]map[string]StateData)
1311+
ss.state[LinkID{}] = lv
1312+
}
1313+
wv, ok := lv[window]
1314+
if !ok {
1315+
wv = make(map[string]StateData)
1316+
lv[window] = wv
1317+
}
1318+
state := wv[key]
1319+
endOfWindowReached := window.MaxTimestamp() < ss.input
1320+
ready := ss.strat.IsTriggerReady(triggerInput{
1321+
newElementCount: 1,
1322+
endOfWindowReached: endOfWindowReached,
1323+
}, &state)
1324+
1325+
if ready {
1326+
state.Pane = computeNextTriggeredPane(state.Pane, endOfWindowReached)
1327+
}
1328+
// Store the state as triggers may have changed it.
1329+
ss.state[LinkID{}][window][key] = state
1330+
1331+
// If we're ready, it's time to fire!
1332+
if ready {
1333+
count += ss.buildTriggeredBundle(em, key, window)
1334+
}
1335+
return count
1336+
}
1337+
12991338
// addPending for aggregate stages behaves likes stateful stages, but don't need to handle timers or a separate window
13001339
// expiration condition.
13011340
func (*aggregateStageKind) addPending(ss *stageState, em *ElementManager, newPending []element) int {
@@ -1315,6 +1354,13 @@ func (*aggregateStageKind) addPending(ss *stageState, em *ElementManager, newPen
13151354
if ss.pendingByKeys == nil {
13161355
ss.pendingByKeys = map[string]*dataAndTimers{}
13171356
}
1357+
1358+
type windowKey struct {
1359+
window typex.Window
1360+
key string
1361+
}
1362+
pendingWindowKeys := set[windowKey]{}
1363+
13181364
count := 0
13191365
for _, e := range newPending {
13201366
count++
@@ -1327,37 +1373,18 @@ func (*aggregateStageKind) addPending(ss *stageState, em *ElementManager, newPen
13271373
ss.pendingByKeys[string(e.keyBytes)] = dnt
13281374
}
13291375
heap.Push(&dnt.elements, e)
1330-
// Check on triggers for this key.
1331-
// We use an empty linkID as the key into state for aggregations.
1332-
if ss.state == nil {
1333-
ss.state = make(map[LinkID]map[typex.Window]map[string]StateData)
1334-
}
1335-
lv, ok := ss.state[LinkID{}]
1336-
if !ok {
1337-
lv = make(map[typex.Window]map[string]StateData)
1338-
ss.state[LinkID{}] = lv
1339-
}
1340-
wv, ok := lv[e.window]
1341-
if !ok {
1342-
wv = make(map[string]StateData)
1343-
lv[e.window] = wv
1344-
}
1345-
state := wv[string(e.keyBytes)]
1346-
endOfWindowReached := e.window.MaxTimestamp() < ss.input
1347-
ready := ss.strat.IsTriggerReady(triggerInput{
1348-
newElementCount: 1,
1349-
endOfWindowReached: endOfWindowReached,
1350-
}, &state)
13511376

1352-
if ready {
1353-
state.Pane = computeNextTriggeredPane(state.Pane, endOfWindowReached)
1377+
if em.config.StreamingMode {
1378+
// In streaming mode, we check trigger readiness on each element
1379+
count += ss.injectTriggeredBundlesIfReady(em, e.window, string(e.keyBytes))
1380+
} else {
1381+
// In batch mode, we store key + window pairs here and check trigger readiness for each of them later.
1382+
pendingWindowKeys.insert(windowKey{window: e.window, key: string(e.keyBytes)})
13541383
}
1355-
// Store the state as triggers may have changed it.
1356-
ss.state[LinkID{}][e.window][string(e.keyBytes)] = state
1357-
1358-
// If we're ready, it's time to fire!
1359-
if ready {
1360-
count += ss.buildTriggeredBundle(em, e.keyBytes, e.window)
1384+
}
1385+
if !em.config.StreamingMode {
1386+
for wk := range pendingWindowKeys {
1387+
count += ss.injectTriggeredBundlesIfReady(em, wk.window, wk.key)
13611388
}
13621389
}
13631390
return count
@@ -1493,9 +1520,9 @@ func (ss *stageState) savePanes(bundID string, panesInBundle []bundlePane) {
14931520
// buildTriggeredBundle must be called with the stage.mu lock held.
14941521
// When in discarding mode, returns 0.
14951522
// When in accumulating mode, returns the number of fired elements to maintain a correct pending count.
1496-
func (ss *stageState) buildTriggeredBundle(em *ElementManager, key []byte, win typex.Window) int {
1523+
func (ss *stageState) buildTriggeredBundle(em *ElementManager, key string, win typex.Window) int {
14971524
var toProcess []element
1498-
dnt := ss.pendingByKeys[string(key)]
1525+
dnt := ss.pendingByKeys[key]
14991526
var notYet []element
15001527

15011528
rb := RunBundle{StageID: ss.ID, BundleID: "agg-" + em.nextBundID(), Watermark: ss.input}
@@ -1524,7 +1551,7 @@ func (ss *stageState) buildTriggeredBundle(em *ElementManager, key []byte, win t
15241551
}
15251552
dnt.elements = append(dnt.elements, notYet...)
15261553
if dnt.elements.Len() == 0 {
1527-
delete(ss.pendingByKeys, string(key))
1554+
delete(ss.pendingByKeys, key)
15281555
} else {
15291556
// Ensure the heap invariants are maintained.
15301557
heap.Init(&dnt.elements)
@@ -1537,15 +1564,15 @@ func (ss *stageState) buildTriggeredBundle(em *ElementManager, key []byte, win t
15371564
{
15381565
win: win,
15391566
key: string(key),
1540-
pane: ss.state[LinkID{}][win][string(key)].Pane,
1567+
pane: ss.state[LinkID{}][win][key].Pane,
15411568
},
15421569
}
15431570

15441571
ss.makeInProgressBundle(
15451572
func() string { return rb.BundleID },
15461573
toProcess,
15471574
ss.input,
1548-
singleSet(string(key)),
1575+
singleSet(key),
15491576
nil,
15501577
panesInBundle,
15511578
)

sdks/go/pkg/beam/runners/prism/internal/execute.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,7 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic
153153

154154
topo := prepro.preProcessGraph(comps, j)
155155
ts := comps.GetTransforms()
156+
pcols := comps.GetPcollections()
156157

157158
config := engine.Config{}
158159
m := j.PipelineOptions().AsMap()
@@ -167,6 +168,18 @@ func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservic
167168
}
168169
}
169170

171+
if streaming, ok := m["beam:option:streaming:v1"].(bool); ok {
172+
config.StreamingMode = streaming
173+
}
174+
175+
// Set StreamingMode to true if there is any unbounded PCollection.
176+
for _, pcoll := range pcols {
177+
if pcoll.GetIsBounded() == pipepb.IsBounded_UNBOUNDED {
178+
config.StreamingMode = true
179+
break
180+
}
181+
}
182+
170183
em := engine.NewElementManager(config)
171184

172185
// TODO move this loop and code into the preprocessor instead.

sdks/python/apache_beam/runners/portability/prism_runner_test.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,14 @@
3535
import apache_beam as beam
3636
from apache_beam.options.pipeline_options import DebugOptions
3737
from apache_beam.options.pipeline_options import PortableOptions
38+
from apache_beam.options.pipeline_options import StandardOptions
39+
from apache_beam.options.pipeline_options import TypeOptions
3840
from apache_beam.runners.portability import portable_runner_test
3941
from apache_beam.runners.portability import prism_runner
4042
from apache_beam.testing.util import assert_that
4143
from apache_beam.testing.util import equal_to
44+
from apache_beam.transforms import trigger
45+
from apache_beam.transforms import window
4246
from apache_beam.utils import shared
4347

4448
# Run as
@@ -64,6 +68,8 @@ def __init__(self, *args, **kwargs):
6468
self.environment_type = None
6569
self.environment_config = None
6670
self.enable_commit = False
71+
self.streaming = False
72+
self.allow_unsafe_triggers = False
6773

6874
def setUp(self):
6975
self.enable_commit = False
@@ -175,6 +181,9 @@ def create_options(self):
175181
options.view_as(
176182
PortableOptions).environment_options = self.environment_options
177183

184+
options.view_as(StandardOptions).streaming = self.streaming
185+
options.view_as(
186+
TypeOptions).allow_unsafe_triggers = self.allow_unsafe_triggers
178187
return options
179188

180189
# Can't read host files from within docker, read a "local" file there.
@@ -225,7 +234,66 @@ def test_custom_window_type(self):
225234
def test_metrics(self):
226235
super().test_metrics(check_bounded_trie=False)
227236

228-
# Inherits all other tests.
237+
def construct_timestamped(k, t):
238+
return window.TimestampedValue((k, t), t)
239+
240+
def format_result(k, vs):
241+
return ('%s-%s' % (k, len(list(vs))), set(vs))
242+
243+
def test_after_count_trigger_batch(self):
244+
self.allow_unsafe_triggers = True
245+
with self.create_pipeline() as p:
246+
result = (
247+
p
248+
| beam.Create([1, 2, 3, 4, 5, 10, 11])
249+
| beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
250+
#A1, A2, A3, A4, A5, A10, A11, B6, B7, B8, B9, B10, B15, B16
251+
| beam.MapTuple(PrismRunnerTest.construct_timestamped)
252+
| beam.WindowInto(
253+
window.FixedWindows(10),
254+
trigger=trigger.AfterCount(3),
255+
accumulation_mode=trigger.AccumulationMode.DISCARDING,
256+
)
257+
| beam.GroupByKey()
258+
| beam.MapTuple(PrismRunnerTest.format_result))
259+
assert_that(
260+
result,
261+
equal_to(
262+
list([
263+
('A-5', {1, 2, 3, 4, 5}),
264+
('A-2', {10, 11}),
265+
('B-4', {6, 7, 8, 9}),
266+
('B-3', {10, 15, 16}),
267+
])))
268+
269+
def test_after_count_trigger_streaming(self):
270+
self.allow_unsafe_triggers = True
271+
self.streaming = True
272+
with self.create_pipeline() as p:
273+
result = (
274+
p
275+
| beam.Create([1, 2, 3, 4, 5, 10, 11])
276+
| beam.FlatMap(lambda t: [('A', t), ('B', t + 5)])
277+
#A1, A2, A3, A4, A5, A10, A11, B6, B7, B8, B9, B10, B15, B16
278+
| beam.MapTuple(PrismRunnerTest.construct_timestamped)
279+
| beam.WindowInto(
280+
window.FixedWindows(10),
281+
trigger=trigger.AfterCount(3),
282+
accumulation_mode=trigger.AccumulationMode.DISCARDING,
283+
)
284+
| beam.GroupByKey()
285+
| beam.MapTuple(PrismRunnerTest.format_result))
286+
assert_that(
287+
result,
288+
equal_to(
289+
list([
290+
('A-3', {1, 2, 3}),
291+
('A-2', {4, 5}),
292+
('A-2', {10, 11}),
293+
('B-3', {6, 7, 8}),
294+
('B-1', {9}),
295+
('B-3', {10, 15, 16}),
296+
])))
229297

230298

231299
class PrismJobServerTest(unittest.TestCase):

0 commit comments

Comments (0)