Skip to content

Commit 82540a8

Browse files
authored
Merge pull request #34740 Only declare state backed iterables on code channel.
2 parents ae7bf20 + a2fcbd9 commit 82540a8

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

sdks/python/apache_beam/runners/portability/fn_api_runner/translations.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -459,7 +459,7 @@ def add_or_get_coder_id(
459459
self.components.coders[new_coder_id].CopyFrom(coder_proto)
460460
return new_coder_id
461461

462-
def add_data_channel_coder(self, pcoll_id):
462+
def add_data_channel_coder(self, pcoll_id, is_input=False):
463463
pcoll = self.components.pcollections[pcoll_id]
464464
proto = beam_runner_api_pb2.Coder(
465465
spec=beam_runner_api_pb2.FunctionSpec(
@@ -469,8 +469,12 @@ def add_data_channel_coder(self, pcoll_id):
469469
self.components.windowing_strategies[
470470
pcoll.windowing_strategy_id].window_coder_id
471471
])
472+
windowed_coder_id = self.add_or_get_coder_id(
473+
proto, pcoll.coder_id + '_windowed')
474+
if is_input and self.use_state_iterables:
475+
windowed_coder_id = self.with_state_iterables(windowed_coder_id)
472476
self.data_channel_coders[pcoll_id] = self.maybe_length_prefixed_coder(
473-
self.add_or_get_coder_id(proto, pcoll.coder_id + '_windowed'))
477+
windowed_coder_id)
474478

475479
@memoize_on_instance
476480
def with_state_iterables(self, coder_id):
@@ -1692,10 +1696,6 @@ def expand_gbk(stages, pipeline_context):
16921696
for pcoll_id in transform.inputs.values():
16931697
pipeline_context.length_prefix_pcoll_coders(pcoll_id)
16941698
for pcoll_id in transform.outputs.values():
1695-
if pipeline_context.use_state_iterables:
1696-
pipeline_context.components.pcollections[
1697-
pcoll_id].coder_id = pipeline_context.with_state_iterables(
1698-
pipeline_context.components.pcollections[pcoll_id].coder_id)
16991699
pipeline_context.length_prefix_pcoll_coders(pcoll_id)
17001700

17011701
# This is used later to correlate the read and write.
@@ -2078,7 +2078,8 @@ def populate_data_channel_coders(stages, pipeline_context):
20782078
sdk_pcoll_id = only_element(transform.outputs.values())
20792079
else:
20802080
sdk_pcoll_id = only_element(transform.inputs.values())
2081-
pipeline_context.add_data_channel_coder(sdk_pcoll_id)
2081+
pipeline_context.add_data_channel_coder(
2082+
sdk_pcoll_id, transform.spec.urn == bundle_processor.DATA_INPUT_URN)
20822083

20832084
return stages
20842085

0 commit comments

Comments
 (0)