Skip to content

Commit f8009d8

Browse files
committed
Clear cached args only when finishing bundle
1 parent cc490e7 commit f8009d8

File tree

3 files changed

+37
-30
lines changed

3 files changed

+37
-30
lines changed

sdks/python/apache_beam/runners/common.pxd

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,11 @@ cdef class PerWindowInvoker(DoFnInvoker):
100100
cdef dict kwargs_for_process_batch
101101
cdef list placeholders_for_process_batch
102102
cdef bint has_windowed_inputs
103-
cdef bint recalculate_window_args
104-
cdef bint has_cached_window_args
105-
cdef bint has_cached_window_batch_args
103+
cdef bint should_cache_args
104+
cdef list cached_args_for_process
105+
cdef dict cached_kwargs_for_process
106+
cdef list cached_args_for_process_batch
107+
cdef dict cached_kwargs_for_process_batch
106108
cdef object process_method
107109
cdef object process_batch_method
108110
cdef bint is_splittable

sdks/python/apache_beam/runners/common.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -761,16 +761,15 @@ def __init__(self,
761761
self.current_window_index = None
762762
self.stop_window_index = None
763763

764-
# TODO(https://github.com/apache/beam/issues/28776): Remove caching after
765-
# fully rolling out.
766-
# If true, always recalculate window args. If false, has_cached_window_args
767-
# and has_cached_window_batch_args will be set to true if the corresponding
768-
# self.args_for_process,have been updated and should be reused directly.
769-
self.recalculate_window_args = (
770-
self.has_windowed_inputs or 'disable_global_windowed_args_caching' in
771-
RuntimeValueProvider.experiments)
772-
self.has_cached_window_args = False
773-
self.has_cached_window_batch_args = False
764+
# If true, after the first process invocation the args for process will
765+
# be cached in cached_args_for_process and cached_kwargs_for_process and
766+
# reused on subsequent invocations in the same bundle.
767+
self.should_cache_args = (not self.has_windowed_inputs)
768+
self.cached_args_for_process = None
769+
self.cached_kwargs_for_process = None
770+
# See above, similar cached args for process_batch invocations.
771+
self.cached_args_for_process_batch = None
772+
self.cached_kwargs_for_process_batch = None
774773

775774
# Try to prepare all the arguments that can just be filled in
776775
# without any additional work in the process function.
@@ -932,9 +931,9 @@ def _invoke_process_per_window(self,
932931
additional_kwargs,
933932
):
934933
# type: (...) -> Optional[SplitResultResidual]
935-
if self.has_cached_window_args:
934+
if self.cached_args_for_process:
936935
args_for_process, kwargs_for_process = (
937-
self.args_for_process, self.kwargs_for_process)
936+
self.cached_args_for_process, self.cached_kwargs_for_process)
938937
else:
939938
if self.has_windowed_inputs:
940939
assert len(windowed_value.windows) <= 1
@@ -945,10 +944,9 @@ def _invoke_process_per_window(self,
945944
side_inputs.extend(additional_args)
946945
args_for_process, kwargs_for_process = util.insert_values_in_args(
947946
self.args_for_process, self.kwargs_for_process, side_inputs)
948-
if not self.recalculate_window_args:
949-
self.args_for_process, self.kwargs_for_process = (
947+
if self.should_cache_args:
948+
self.cached_args_for_process, self.cached_kwargs_for_process = (
950949
args_for_process, kwargs_for_process)
951-
self.has_cached_window_args = True
952950

953951
# Extract key in the case of a stateful DoFn. Note that in the case of a
954952
# stateful DoFn, we set during __init__ self.has_windowed_inputs to be
@@ -1030,9 +1028,10 @@ def _invoke_process_batch_per_window(
10301028
):
10311029
# type: (...) -> Optional[SplitResultResidual]
10321030

1033-
if self.has_cached_window_batch_args:
1031+
if self.cached_args_for_process_batch:
10341032
args_for_process_batch, kwargs_for_process_batch = (
1035-
self.args_for_process_batch, self.kwargs_for_process_batch)
1033+
self.cached_args_for_process_batch,
1034+
self.cached_kwargs_for_process_batch)
10361035
else:
10371036
if self.has_windowed_inputs:
10381037
assert isinstance(windowed_batch, HomogeneousWindowedBatch)
@@ -1049,10 +1048,9 @@ def _invoke_process_batch_per_window(
10491048
side_inputs,
10501049
)
10511050
)
1052-
if not self.recalculate_window_args:
1053-
self.args_for_process_batch, self.kwargs_for_process_batch = (
1054-
args_for_process_batch, kwargs_for_process_batch)
1055-
self.has_cached_window_batch_args = True
1051+
if self.should_cache_args:
1052+
self.cached_args_for_process_batch = args_for_process_batch
1053+
self.cached_kwargs_for_process_batch = kwargs_for_process_batch
10561054

10571055
for i, p in self.placeholders_for_process_batch:
10581056
if core.DoFn.ElementParam == p:
@@ -1088,6 +1086,18 @@ def _invoke_process_batch_per_window(
10881086
*args_for_process_batch, **kwargs_for_process_batch),
10891087
self.threadsafe_watermark_estimator)
10901088

1089+
def invoke_finish_bundle(self):
1090+
# type: () -> None
1091+
# Clear the cached args to allow for refreshing of side inputs
1092+
# across bundles.
1093+
self.cached_args_for_process = None
1094+
self.cached_kwargs_for_process = None
1095+
self.cached_args_for_process_batch = None
1096+
self.cached_kwargs_for_process_batch = None
1097+
# super() doesn't appear to work with cython
1098+
# https://github.com/cython/cython/issues/3726
1099+
DoFnInvoker.invoke_finish_bundle(self)
1100+
10911101
@staticmethod
10921102
def _try_split(fraction,
10931103
window_index, # type: Optional[int]

sdks/python/apache_beam/tools/map_fn_microbenchmark.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,6 @@ def run_benchmark(
8888
profile_filename_base=None,
8989
):
9090
suite = [
91-
utils.BenchmarkConfig(
92-
map_with_fixed_window_side_input_pipeline,
93-
starting_point * 1000,
94-
num_runs,
95-
),
9691
utils.LinearRegressionBenchmarkConfig(
9792
map_pipeline, starting_point, num_elements_step, num_runs),
9893
utils.BenchmarkConfig(
@@ -101,7 +96,7 @@ def run_benchmark(
10196
num_runs,
10297
),
10398
utils.BenchmarkConfig(
104-
map_with_global_side_input_pipeline_uncached,
99+
map_with_fixed_window_side_input_pipeline,
105100
starting_point * 1000,
106101
num_runs,
107102
),

0 commit comments

Comments
 (0)