@@ -761,16 +761,15 @@ def __init__(self,
761761 self .current_window_index = None
762762 self .stop_window_index = None
763763
764- # TODO(https://github.com/apache/beam/issues/28776): Remove caching after
765- # fully rolling out.
766- # If true, always recalculate window args. If false, has_cached_window_args
767- # and has_cached_window_batch_args will be set to true if the corresponding
768- # self.args_for_process,have been updated and should be reused directly.
769- self .recalculate_window_args = (
770- self .has_windowed_inputs or 'disable_global_windowed_args_caching' in
771- RuntimeValueProvider .experiments )
772- self .has_cached_window_args = False
773- self .has_cached_window_batch_args = False
764+ # If true, after the first process invocation the args for process will
765+ # be cached in cached_args_for_process and cached_kwargs_for_process and
766+ # reused on subsequent invocations in the same bundle..
767+ self .should_cache_args = (not self .has_windowed_inputs )
768+ self .cached_args_for_process = None
769+ self .cached_kwargs_for_process = None
770+ # See above, similar cached args for process_batch invocations.
771+ self .cached_args_for_process_batch = None
772+ self .cached_kwargs_for_process_batch = None
774773
775774 # Try to prepare all the arguments that can just be filled in
776775 # without any additional work. in the process function.
@@ -932,9 +931,9 @@ def _invoke_process_per_window(self,
932931 additional_kwargs ,
933932 ):
934933 # type: (...) -> Optional[SplitResultResidual]
935- if self .has_cached_window_args :
934+ if self .cached_args_for_process :
936935 args_for_process , kwargs_for_process = (
937- self .args_for_process , self .kwargs_for_process )
936+ self .cached_args_for_process , self .cached_kwargs_for_process )
938937 else :
939938 if self .has_windowed_inputs :
940939 assert len (windowed_value .windows ) <= 1
@@ -945,10 +944,9 @@ def _invoke_process_per_window(self,
945944 side_inputs .extend (additional_args )
946945 args_for_process , kwargs_for_process = util .insert_values_in_args (
947946 self .args_for_process , self .kwargs_for_process , side_inputs )
948- if not self .recalculate_window_args :
949- self .args_for_process , self .kwargs_for_process = (
947+ if self .should_cache_args :
948+ self .cached_args_for_process , self .cached_kwargs_for_process = (
950949 args_for_process , kwargs_for_process )
951- self .has_cached_window_args = True
952950
953951 # Extract key in the case of a stateful DoFn. Note that in the case of a
954952 # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
@@ -1030,9 +1028,10 @@ def _invoke_process_batch_per_window(
10301028 ):
10311029 # type: (...) -> Optional[SplitResultResidual]
10321030
1033- if self .has_cached_window_batch_args :
1031+ if self .cached_args_for_process_batch :
10341032 args_for_process_batch , kwargs_for_process_batch = (
1035- self .args_for_process_batch , self .kwargs_for_process_batch )
1033+ self .cached_args_for_process_batch ,
1034+ self .cached_kwargs_for_process_batch )
10361035 else :
10371036 if self .has_windowed_inputs :
10381037 assert isinstance (windowed_batch , HomogeneousWindowedBatch )
@@ -1049,10 +1048,9 @@ def _invoke_process_batch_per_window(
10491048 side_inputs ,
10501049 )
10511050 )
1052- if not self .recalculate_window_args :
1053- self .args_for_process_batch , self .kwargs_for_process_batch = (
1054- args_for_process_batch , kwargs_for_process_batch )
1055- self .has_cached_window_batch_args = True
1051+ if self .should_cache_args :
1052+ self .cached_args_for_process_batch = args_for_process_batch
1053+ self .cached_kwargs_for_process_batch = kwargs_for_process_batch
10561054
10571055 for i , p in self .placeholders_for_process_batch :
10581056 if core .DoFn .ElementParam == p :
@@ -1088,6 +1086,18 @@ def _invoke_process_batch_per_window(
10881086 * args_for_process_batch , ** kwargs_for_process_batch ),
10891087 self .threadsafe_watermark_estimator )
10901088
1089+ def invoke_finish_bundle (self ):
1090+ # type: () -> None
1091+ # Clear the cached args to allow for refreshing of side inputs
1092+ # across bundles.
1093+ self .cached_args_for_process = None
1094+ self .cached_kwargs_for_process = None
1095+ self .cached_args_for_process_batch = None
1096+ self .cached_kwargs_for_process_batch = None
1097+ # super() doesn't appear to work with cython
1098+ # https://github.com/cython/cython/issues/3726
1099+ DoFnInvoker .invoke_finish_bundle (self )
1100+
10911101 @staticmethod
10921102 def _try_split (fraction ,
10931103 window_index , # type: Optional[int]
0 commit comments