diff --git a/sdks/python/apache_beam/runners/common.pxd b/sdks/python/apache_beam/runners/common.pxd
index 683bf8fcac18..4209206005c9 100644
--- a/sdks/python/apache_beam/runners/common.pxd
+++ b/sdks/python/apache_beam/runners/common.pxd
@@ -100,9 +100,11 @@ cdef class PerWindowInvoker(DoFnInvoker):
   cdef dict kwargs_for_process_batch
   cdef list placeholders_for_process_batch
   cdef bint has_windowed_inputs
-  cdef bint recalculate_window_args
-  cdef bint has_cached_window_args
-  cdef bint has_cached_window_batch_args
+  cdef bint should_cache_args
+  cdef list cached_args_for_process
+  cdef dict cached_kwargs_for_process
+  cdef list cached_args_for_process_batch
+  cdef dict cached_kwargs_for_process_batch
   cdef object process_method
   cdef object process_batch_method
   cdef bint is_splittable
diff --git a/sdks/python/apache_beam/runners/common.py b/sdks/python/apache_beam/runners/common.py
index 7a1cef4005e4..4411c0aa4d28 100644
--- a/sdks/python/apache_beam/runners/common.py
+++ b/sdks/python/apache_beam/runners/common.py
@@ -761,16 +761,15 @@ def __init__(self,
     self.current_window_index = None
     self.stop_window_index = None
 
-    # TODO(https://github.com/apache/beam/issues/28776): Remove caching after
-    # fully rolling out.
-    # If true, always recalculate window args. If false, has_cached_window_args
-    # and has_cached_window_batch_args will be set to true if the corresponding
-    # self.args_for_process,have been updated and should be reused directly.
-    self.recalculate_window_args = (
-        self.has_windowed_inputs or 'disable_global_windowed_args_caching' in
-        RuntimeValueProvider.experiments)
-    self.has_cached_window_args = False
-    self.has_cached_window_batch_args = False
+    # If true, after the first process invocation the args for process will
+    # be cached in cached_args_for_process and cached_kwargs_for_process and
+    # reused on subsequent invocations in the same bundle.
+    self.should_cache_args = (not self.has_windowed_inputs)
+    self.cached_args_for_process = None
+    self.cached_kwargs_for_process = None
+    # See above, similar cached args for process_batch invocations.
+    self.cached_args_for_process_batch = None
+    self.cached_kwargs_for_process_batch = None
 
     # Try to prepare all the arguments that can just be filled in
     # without any additional work. in the process function.
@@ -932,9 +931,9 @@ def _invoke_process_per_window(self,
                                  additional_kwargs,
                                 ):
     # type: (...) -> Optional[SplitResultResidual]
-    if self.has_cached_window_args:
+    if self.cached_args_for_process:
       args_for_process, kwargs_for_process = (
-          self.args_for_process, self.kwargs_for_process)
+          self.cached_args_for_process, self.cached_kwargs_for_process)
     else:
       if self.has_windowed_inputs:
         assert len(windowed_value.windows) <= 1
@@ -945,10 +944,9 @@ def _invoke_process_per_window(self,
         side_inputs.extend(additional_args)
       args_for_process, kwargs_for_process = util.insert_values_in_args(
           self.args_for_process, self.kwargs_for_process, side_inputs)
-      if not self.recalculate_window_args:
-        self.args_for_process, self.kwargs_for_process = (
+      if self.should_cache_args:
+        self.cached_args_for_process, self.cached_kwargs_for_process = (
             args_for_process, kwargs_for_process)
-        self.has_cached_window_args = True
 
     # Extract key in the case of a stateful DoFn. Note that in the case of a
     # stateful DoFn, we set during __init__ self.has_windowed_inputs to be
@@ -1030,9 +1028,10 @@ def _invoke_process_batch_per_window(
   ):
     # type: (...) -> Optional[SplitResultResidual]
 
-    if self.has_cached_window_batch_args:
+    if self.cached_args_for_process_batch:
       args_for_process_batch, kwargs_for_process_batch = (
-          self.args_for_process_batch, self.kwargs_for_process_batch)
+          self.cached_args_for_process_batch,
+          self.cached_kwargs_for_process_batch)
     else:
       if self.has_windowed_inputs:
         assert isinstance(windowed_batch, HomogeneousWindowedBatch)
@@ -1049,10 +1048,9 @@ def _invoke_process_batch_per_window(
               side_inputs,
           )
       )
-      if not self.recalculate_window_args:
-        self.args_for_process_batch, self.kwargs_for_process_batch = (
-            args_for_process_batch, kwargs_for_process_batch)
-        self.has_cached_window_batch_args = True
+      if self.should_cache_args:
+        self.cached_args_for_process_batch = args_for_process_batch
+        self.cached_kwargs_for_process_batch = kwargs_for_process_batch
 
     for i, p in self.placeholders_for_process_batch:
       if core.DoFn.ElementParam == p:
@@ -1088,6 +1086,18 @@ def _invoke_process_batch_per_window(
             *args_for_process_batch, **kwargs_for_process_batch),
         self.threadsafe_watermark_estimator)
 
+  def invoke_finish_bundle(self):
+    # type: () -> None
+    # Clear the cached args to allow for refreshing of side inputs
+    # across bundles.
+    self.cached_args_for_process = None
+    self.cached_kwargs_for_process = None
+    self.cached_args_for_process_batch = None
+    self.cached_kwargs_for_process_batch = None
+    # super() doesn't appear to work with cython
+    # https://github.com/cython/cython/issues/3726
+    DoFnInvoker.invoke_finish_bundle(self)
+
   @staticmethod
   def _try_split(fraction,
                  window_index,  # type: Optional[int]
diff --git a/sdks/python/apache_beam/tools/map_fn_microbenchmark.py b/sdks/python/apache_beam/tools/map_fn_microbenchmark.py
index cdbc5c4e6cb4..d5bee1b8fc70 100644
--- a/sdks/python/apache_beam/tools/map_fn_microbenchmark.py
+++ b/sdks/python/apache_beam/tools/map_fn_microbenchmark.py
@@ -23,7 +23,7 @@
 
 This executes the same codepaths that are run on the Fn API (and Dataflow)
 workers, but is generally easier to run (locally) and more stable. It does
-not, on the other hand, excercise any non-trivial amount of IO (e.g. shuffle).
+not, on the other hand, exercise any non-trivial amount of IO (e.g. shuffle).
 
 Run as
 
@@ -32,41 +32,95 @@
 
 # pytype: skip-file
 
+import argparse
 import logging
-import time
-
-from scipy import stats
 
 import apache_beam as beam
 from apache_beam.tools import utils
+from apache_beam.transforms.window import FixedWindows
 
 
-def run_benchmark(num_maps=100, num_runs=10, num_elements_step=1000):
-  timings = {}
-  for run in range(num_runs):
-    num_elements = num_elements_step * run + 1
-    start = time.time()
+def map_pipeline(num_elements, num_maps=100):
+  def _pipeline_runner():
     with beam.Pipeline() as p:
       pc = p | beam.Create(list(range(num_elements)))
       for ix in range(num_maps):
         pc = pc | 'Map%d' % ix >> beam.FlatMap(lambda x: (None, ))
-    timings[num_elements] = time.time() - start
-    print(
-        "%6d element%s %g sec" % (
-            num_elements,
-            " " if num_elements == 1 else "s",
-            timings[num_elements]))
-
-  print()
-  # pylint: disable=unused-variable
-  gradient, intercept, r_value, p_value, std_err = stats.linregress(
-      *list(zip(*list(timings.items()))))
-  print("Fixed cost  ", intercept)
-  print("Per-element ", gradient / num_maps)
-  print("R^2         ", r_value**2)
+
+  return _pipeline_runner
+
+
+def map_with_global_side_input_pipeline(num_elements, num_maps=100):
+  def add(element, side_input):
+    return element + side_input
+
+  def _pipeline_runner():
+    with beam.Pipeline() as p:
+      side = p | 'CreateSide' >> beam.Create([1])
+      pc = p | 'CreateMain' >> beam.Create(list(range(num_elements)))
+      for ix in range(num_maps):
+        pc = pc | 'Map%d' % ix >> beam.Map(add, beam.pvalue.AsSingleton(side))
+
+  return _pipeline_runner
+
+
+def map_with_fixed_window_side_input_pipeline(num_elements, num_maps=100):
+  def add(element, side_input):
+    return element + side_input
+
+  def _pipeline_runner():
+    with beam.Pipeline() as p:
+      side = p | 'CreateSide' >> beam.Create(
+          [1]) | 'WindowSide' >> beam.WindowInto(FixedWindows(1000))
+      pc = p | 'CreateMain' >> beam.Create(list(range(
+          num_elements))) | 'WindowMain' >> beam.WindowInto(FixedWindows(1000))
+      for ix in range(num_maps):
+        pc = pc | 'Map%d' % ix >> beam.Map(add, beam.pvalue.AsSingleton(side))
+
+  return _pipeline_runner
+
+
+def run_benchmark(
+    starting_point=1,
+    num_runs=10,
+    num_elements_step=100,
+    verbose=True,
+    profile_filename_base=None,
+):
+  suite = [
+      utils.LinearRegressionBenchmarkConfig(
+          map_pipeline, starting_point, num_elements_step, num_runs),
+      utils.BenchmarkConfig(
+          map_with_global_side_input_pipeline,
+          starting_point * 1000,
+          num_runs,
+      ),
+      utils.BenchmarkConfig(
+          map_with_fixed_window_side_input_pipeline,
+          starting_point * 1000,
+          num_runs,
+      ),
+  ]
+  return utils.run_benchmarks(
+      suite, verbose=verbose, profile_filename_base=profile_filename_base)
 
 
 if __name__ == '__main__':
   logging.basicConfig()
   utils.check_compiled('apache_beam.runners.common')
-  run_benchmark()
+
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--num_runs', default=10, type=int)
+  parser.add_argument('--starting_point', default=1, type=int)
+  parser.add_argument('--increment', default=100, type=int)
+  parser.add_argument('--verbose', default=True, type=bool)
+  parser.add_argument('--profile_filename_base', default=None, type=str)
+  options = parser.parse_args()
+
+  run_benchmark(
+      options.starting_point,
+      options.num_runs,
+      options.increment,
+      options.verbose,
+      options.profile_filename_base,
+  )
diff --git a/sdks/python/apache_beam/tools/utils.py b/sdks/python/apache_beam/tools/utils.py
index e3df3f2c1c6f..adf82b82582f 100644
--- a/sdks/python/apache_beam/tools/utils.py
+++ b/sdks/python/apache_beam/tools/utils.py
@@ -20,6 +20,7 @@
 # pytype: skip-file
 
 import collections
+import cProfile
 import gc
 import importlib
 import os
@@ -70,10 +71,11 @@ def __call__(self):
   size: int
   num_runs: int
 
+  def name(self):
+    return getattr(self.benchmark, '__name__', str(self.benchmark))
+
   def __str__(self):
-    return "%s, %s element(s)" % (
-        getattr(self.benchmark, '__name__', str(self.benchmark)),
-        str(self.size))
+    return "%s, %s element(s)" % (self.name(), str(self.size))
 
 
 class LinearRegressionBenchmarkConfig(NamedTuple):
@@ -102,14 +104,15 @@ def __call__(self):
   increment: int
   num_runs: int
 
+  def name(self):
+    return getattr(self.benchmark, '__name__', str(self.benchmark))
+
   def __str__(self):
     return "%s, %s element(s) at start, %s growth per run" % (
-        getattr(self.benchmark, '__name__', str(self.benchmark)),
-        str(self.starting_point),
-        str(self.increment))
+        self.name(), str(self.starting_point), str(self.increment))
 
 
-def run_benchmarks(benchmark_suite, verbose=True):
+def run_benchmarks(benchmark_suite, verbose=True, profile_filename_base=None):
   """Runs benchmarks, and collects execution times.
 
   A simple instrumentation to run a callable several times, collect and print
@@ -118,17 +121,25 @@ def run_benchmarks(benchmark_suite, verbose=True):
   Args:
     benchmark_suite: A list of BenchmarkConfig.
     verbose: bool, whether to print benchmark results to stdout.
+    profile_filename_base: str, if present each benchmark will be profiled and
+      have stats dumped to per-benchmark files beginning with this prefix.
 
   Returns:
     A dictionary of the form string -> list of floats. Keys of the dictionary
    are benchmark names, values are execution times in seconds for each run.
   """
+  profiler = cProfile.Profile() if profile_filename_base else None
+
   def run(benchmark: BenchmarkFactoryFn, size: int):
     # Contain each run of a benchmark inside a function so that any temporary
     # objects can be garbage-collected after the run.
     benchmark_instance_callable = benchmark(size)
     start = time.time()
+    if profiler:
+      profiler.enable()
     _ = benchmark_instance_callable()
+    if profiler:
+      profiler.disable()
     return time.time() - start
 
   cost_series = collections.defaultdict(list)
@@ -161,6 +172,13 @@ def run(benchmark: BenchmarkFactoryFn, size: int):
 
       # Incrementing the size of the benchmark run by the step size
       size += step
+
+    if profiler:
+      filename = profile_filename_base + benchmark_config.name() + '.prof'
+      if verbose:
+        print("Dumping profile to " + filename)
+      profiler.dump_stats(filename)
+
     if verbose:
       print("")
 
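For reference, a minimal sketch of exercising the new profiling hook end to end; the run count and the '/tmp/map_fn_' prefix are arbitrary example values, not part of the patch. Via the new flags:

  python -m apache_beam.tools.map_fn_microbenchmark \
      --num_runs 5 --profile_filename_base /tmp/map_fn_

or programmatically, through the updated run_benchmark signature:

  from apache_beam.tools import map_fn_microbenchmark

  # Runs the three benchmark pipelines; because profile_filename_base is set,
  # utils.run_benchmarks dumps one cProfile stats file per benchmark, named
  # <prefix><benchmark name>.prof, e.g. /tmp/map_fn_map_pipeline.prof.
  map_fn_microbenchmark.run_benchmark(
      starting_point=1,
      num_runs=5,
      num_elements_step=100,
      verbose=True,
      profile_filename_base='/tmp/map_fn_',
  )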