Skip to content

Commit 0c433e3

Browse files
authored
[Prism] Disable combiner lifting for TriggerAlways (#36146)
1 parent 06dd9b0 commit 0c433e3

File tree

3 files changed

+93
-24
lines changed

3 files changed

+93
-24
lines changed

sdks/go/pkg/beam/runners/prism/internal/handlecombine.go

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,43 +64,52 @@ func (h *combine) PrepareTransform(tid string, t *pipepb.PTransform, comps *pipe
6464
combineInput := comps.GetPcollections()[onlyInput]
6565
ws := comps.GetWindowingStrategies()[combineInput.GetWindowingStrategyId()]
6666

67-
var hasElementCount func(tpb *pipepb.Trigger) bool
67+
var hasTriggerType func(tpb *pipepb.Trigger, targetTriggerType reflect.Type) bool
6868

69-
hasElementCount = func(tpb *pipepb.Trigger) bool {
70-
elCount := false
69+
hasTriggerType = func(tpb *pipepb.Trigger, targetTriggerType reflect.Type) bool {
70+
if tpb == nil {
71+
return false
72+
}
7173
switch at := tpb.GetTrigger().(type) {
72-
case *pipepb.Trigger_ElementCount_:
73-
return true
7474
case *pipepb.Trigger_AfterAll_:
7575
for _, st := range at.AfterAll.GetSubtriggers() {
76-
elCount = elCount || hasElementCount(st)
76+
if hasTriggerType(st, targetTriggerType) {
77+
return true
78+
}
7779
}
78-
return elCount
80+
return false
7981
case *pipepb.Trigger_AfterAny_:
8082
for _, st := range at.AfterAny.GetSubtriggers() {
81-
elCount = elCount || hasElementCount(st)
83+
if hasTriggerType(st, targetTriggerType) {
84+
return true
85+
}
8286
}
83-
return elCount
87+
return false
8488
case *pipepb.Trigger_AfterEach_:
8589
for _, st := range at.AfterEach.GetSubtriggers() {
86-
elCount = elCount || hasElementCount(st)
90+
if hasTriggerType(st, targetTriggerType) {
91+
return true
92+
}
8793
}
88-
return elCount
94+
return false
8995
case *pipepb.Trigger_AfterEndOfWindow_:
90-
return hasElementCount(at.AfterEndOfWindow.GetEarlyFirings()) ||
91-
hasElementCount(at.AfterEndOfWindow.GetLateFirings())
96+
return hasTriggerType(at.AfterEndOfWindow.GetEarlyFirings(), targetTriggerType) ||
97+
hasTriggerType(at.AfterEndOfWindow.GetLateFirings(), targetTriggerType)
9298
case *pipepb.Trigger_OrFinally_:
93-
return hasElementCount(at.OrFinally.GetMain()) ||
94-
hasElementCount(at.OrFinally.GetFinally())
99+
return hasTriggerType(at.OrFinally.GetMain(), targetTriggerType) ||
100+
hasTriggerType(at.OrFinally.GetFinally(), targetTriggerType)
95101
case *pipepb.Trigger_Repeat_:
96-
return hasElementCount(at.Repeat.GetSubtrigger())
102+
return hasTriggerType(at.Repeat.GetSubtrigger(), targetTriggerType)
97103
default:
98-
return false
104+
return reflect.TypeOf(at) == targetTriggerType
99105
}
100106
}
101107

102108
// If we aren't lifting, the "default impl" for combines should be sufficient.
103-
if !h.config.EnableLifting || hasElementCount(ws.GetTrigger()) {
109+
// Disable lifting if there is any TriggerElementCount or TriggerAlways.
110+
if (!h.config.EnableLifting ||
111+
hasTriggerType(ws.GetTrigger(), reflect.TypeOf(&pipepb.Trigger_ElementCount_{})) ||
112+
hasTriggerType(ws.GetTrigger(), reflect.TypeOf(&pipepb.Trigger_Always_{}))) {
104113
return prepareResult{} // Strip the composite layer when lifting is disabled.
105114
}
106115

sdks/go/pkg/beam/runners/prism/internal/handlecombine_test.go

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,14 @@ import (
2525
"google.golang.org/protobuf/testing/protocmp"
2626
)
2727

28-
func TestHandleCombine(t *testing.T) {
29-
undertest := "UnderTest"
28+
func makeWindowingStrategy(trigger *pipepb.Trigger) *pipepb.WindowingStrategy {
29+
return &pipepb.WindowingStrategy{
30+
Trigger: trigger,
31+
}
32+
}
3033

31-
combineTransform := &pipepb.PTransform{
34+
func makeCombineTransform(inputPCollectionID string) *pipepb.PTransform {
35+
return &pipepb.PTransform{
3236
UniqueName: "COMBINE",
3337
Spec: &pipepb.FunctionSpec{
3438
Urn: urns.TransformCombinePerKey,
@@ -41,7 +45,7 @@ func TestHandleCombine(t *testing.T) {
4145
}),
4246
},
4347
Inputs: map[string]string{
44-
"input": "combineIn",
48+
"input": inputPCollectionID,
4549
},
4650
Outputs: map[string]string{
4751
"input": "combineOut",
@@ -51,6 +55,15 @@ func TestHandleCombine(t *testing.T) {
5155
"combine_values",
5256
},
5357
}
58+
}
59+
60+
func TestHandleCombine(t *testing.T) {
61+
undertest := "UnderTest"
62+
63+
combineTransform := makeCombineTransform("combineIn")
64+
combineTransformWithTriggerElementCount := makeCombineTransform("combineInWithTriggerElementCount")
65+
combineTransformWithTriggerAlways := makeCombineTransform("combineInWithTriggerAlways")
66+
5467
combineValuesTransform := &pipepb.PTransform{
5568
UniqueName: "combine_values",
5669
Subtransforms: []string{
@@ -64,6 +77,14 @@ func TestHandleCombine(t *testing.T) {
6477
"combineOut": {
6578
CoderId: "outputCoder",
6679
},
80+
"combineInWithTriggerElementCount": {
81+
CoderId: "inputCoder",
82+
WindowingStrategyId: "wsElementCount",
83+
},
84+
"combineInWithTriggerAlways": {
85+
CoderId: "inputCoder",
86+
WindowingStrategyId: "wsAlways",
87+
},
6788
}
6889
baseCoderMap := map[string]*pipepb.Coder{
6990
"int": {
@@ -84,7 +105,20 @@ func TestHandleCombine(t *testing.T) {
84105
ComponentCoderIds: []string{"int", "string"},
85106
},
86107
}
87-
108+
baseWindowingStrategyMap := map[string]*pipepb.WindowingStrategy{
109+
"wsElementCount": makeWindowingStrategy(&pipepb.Trigger{
110+
Trigger: &pipepb.Trigger_ElementCount_{
111+
ElementCount: &pipepb.Trigger_ElementCount{
112+
ElementCount: 10,
113+
},
114+
},
115+
}),
116+
"wsAlways": makeWindowingStrategy(&pipepb.Trigger{
117+
Trigger: &pipepb.Trigger_Always_{
118+
Always: &pipepb.Trigger_Always{},
119+
},
120+
}),
121+
}
88122
tests := []struct {
89123
name string
90124
lifted bool
@@ -188,6 +222,32 @@ func TestHandleCombine(t *testing.T) {
188222
},
189223
},
190224
},
225+
}, {
226+
name: "noLift_triggerElementCount",
227+
lifted: true, // Lifting is enabled, but should be disabled in the present of the trigger
228+
comps: &pipepb.Components{
229+
Transforms: map[string]*pipepb.PTransform{
230+
undertest: combineTransformWithTriggerElementCount,
231+
"combine_values": combineValuesTransform,
232+
},
233+
Pcollections: basePCollectionMap,
234+
Coders: baseCoderMap,
235+
WindowingStrategies: baseWindowingStrategyMap,
236+
},
237+
want: prepareResult{},
238+
}, {
239+
name: "noLift_triggerAlways",
240+
lifted: true, // Lifting is enabled, but should be disabled in the present of the trigger
241+
comps: &pipepb.Components{
242+
Transforms: map[string]*pipepb.PTransform{
243+
undertest: combineTransformWithTriggerAlways,
244+
"combine_values": combineValuesTransform,
245+
},
246+
Pcollections: basePCollectionMap,
247+
Coders: baseCoderMap,
248+
WindowingStrategies: baseWindowingStrategyMap,
249+
},
250+
want: prepareResult{},
191251
},
192252
}
193253
for _, test := range tests {

sdks/go/pkg/beam/runners/prism/internal/unimplemented_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ func TestUnimplemented(t *testing.T) {
4949
// See https://github.com/apache/beam/issues/31153.
5050
{pipeline: primitives.TriggerElementCount},
5151
{pipeline: primitives.TriggerOrFinally},
52-
{pipeline: primitives.TriggerAlways},
5352

5453
// Currently unimplemented triggers.
5554
// https://github.com/apache/beam/issues/31438
@@ -87,6 +86,7 @@ func TestImplemented(t *testing.T) {
8786
{pipeline: primitives.ParDoProcessElementBundleFinalizer},
8887

8988
{pipeline: primitives.TriggerNever},
89+
{pipeline: primitives.TriggerAlways},
9090
{pipeline: primitives.Panes},
9191
{pipeline: primitives.TriggerAfterAll},
9292
{pipeline: primitives.TriggerAfterAny},

0 commit comments

Comments
 (0)