diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler.py b/sdks/python/apache_beam/internal/cloudpickle_pickler.py index da069954754b..ccfe798c5c08 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler.py @@ -107,8 +107,16 @@ def _pickle_enum_descriptor(obj): return _reconstruct_enum_descriptor, (full_name, ) -def dumps(o, enable_trace=True, use_zlib=False) -> bytes: +def dumps( + o, + enable_trace=True, + use_zlib=False, + enable_best_effort_determinism=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" + if enable_best_effort_determinism: + # TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563 + # is merged in. + raise NotImplementedError('This option has only been implemented for dill') with _pickle_lock: with io.BytesIO() as file: pickler = cloudpickle.CloudPickler(file) diff --git a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py index 550745ac9ddf..17cf3c2994bd 100644 --- a/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py +++ b/sdks/python/apache_beam/internal/cloudpickle_pickler_test.py @@ -135,6 +135,10 @@ def test_dataclass(self): self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc')))) ''') + def test_best_effort_determinism_not_implemented(self): + with self.assertRaises(NotImplementedError): + dumps(123, enable_best_effort_determinism=True) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/internal/dill_pickler.py b/sdks/python/apache_beam/internal/dill_pickler.py index e1d6b7e74e49..35953f438576 100644 --- a/sdks/python/apache_beam/internal/dill_pickler.py +++ b/sdks/python/apache_beam/internal/dill_pickler.py @@ -44,6 +44,9 @@ import dill +from apache_beam.internal.set_pickler import save_frozenset +from apache_beam.internal.set_pickler import save_set + settings = 
{'dill_byref': None} patch_save_code = sys.version_info >= (3, 10) and dill.__version__ == "0.3.1.1" @@ -376,9 +379,18 @@ def new_log_info(msg, *args, **kwargs): logging.getLogger('dill').setLevel(logging.WARN) -def dumps(o, enable_trace=True, use_zlib=False) -> bytes: +def dumps( + o, + enable_trace=True, + use_zlib=False, + enable_best_effort_determinism=False) -> bytes: """For internal use only; no backwards-compatibility guarantees.""" with _pickle_lock: + if enable_best_effort_determinism: + old_save_set = dill.dill.Pickler.dispatch[set] + old_save_frozenset = dill.dill.Pickler.dispatch[frozenset] + dill.dill.pickle(set, save_set) + dill.dill.pickle(frozenset, save_frozenset) try: s = dill.dumps(o, byref=settings['dill_byref']) except Exception: # pylint: disable=broad-except @@ -389,6 +401,9 @@ def dumps(o, enable_trace=True, use_zlib=False) -> bytes: raise finally: dill.dill._trace(False) # pylint: disable=protected-access + if enable_best_effort_determinism: + dill.dill.pickle(set, old_save_set) + dill.dill.pickle(frozenset, old_save_frozenset) # Compress as compactly as possible (compresslevel=9) to decrease peak memory # usage (of multiple in-memory copies) and to avoid hitting protocol buffer diff --git a/sdks/python/apache_beam/internal/pickler.py b/sdks/python/apache_beam/internal/pickler.py index 79ebd16314bf..c577bd3d4a25 100644 --- a/sdks/python/apache_beam/internal/pickler.py +++ b/sdks/python/apache_beam/internal/pickler.py @@ -38,10 +38,17 @@ desired_pickle_lib = dill_pickler -def dumps(o, enable_trace=True, use_zlib=False) -> bytes: +def dumps( + o, + enable_trace=True, + use_zlib=False, + enable_best_effort_determinism=False) -> bytes: return desired_pickle_lib.dumps( - o, enable_trace=enable_trace, use_zlib=use_zlib) + o, + enable_trace=enable_trace, + use_zlib=use_zlib, + enable_best_effort_determinism=enable_best_effort_determinism) def loads(encoded, enable_trace=True, use_zlib=False): diff --git 
a/sdks/python/apache_beam/internal/pickler_test.py b/sdks/python/apache_beam/internal/pickler_test.py index c26a8ee3e653..60fa1e075522 100644 --- a/sdks/python/apache_beam/internal/pickler_test.py +++ b/sdks/python/apache_beam/internal/pickler_test.py @@ -19,6 +19,7 @@ # pytype: skip-file +import random import sys import threading import types @@ -115,6 +116,52 @@ def test_dataclass(self): self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc')))) ''') + def maybe_get_sets_with_different_iteration_orders(self): + # Use a mix of types in an attempt to create sets with the same elements + # whose iteration order is different. + elements = [ + 100, + 'hello', + 3.14159, + True, + None, + -50, + 'world', + False, (1, 2), (4, 3), ('hello', 'world') + ] + set1 = set(elements) + # Try random addition orders until finding an order that works. + for _ in range(100): + set2 = set() + random.shuffle(elements) + for e in elements: + set2.add(e) + if list(set1) != list(set2): + break + return set1, set2 + + def test_best_effort_determinism(self): + set1, set2 = self.maybe_get_sets_with_different_iteration_orders() + self.assertEqual( + dumps(set1, enable_best_effort_determinism=True), + dumps(set2, enable_best_effort_determinism=True)) + # The test relies on the sets having different iteration orders for the + # elements. Iteration order is implementation dependent and undefined, + # meaning the test won't always be able to setup these conditions. + if list(set1) == list(set2): + self.skipTest('Set iteration orders matched. Test results inconclusive.') + + def test_disable_best_effort_determinism(self): + set1, set2 = self.maybe_get_sets_with_different_iteration_orders() + # The test relies on the sets having different iteration orders for the + # elements. Iteration order is implementation dependent and undefined, + # meaning the test won't always be able to setup these conditions. 
+ if list(set1) == list(set2): + self.skipTest('Set iteration orders matched. Unable to complete test.') + self.assertNotEqual( + dumps(set1, enable_best_effort_determinism=False), + dumps(set2, enable_best_effort_determinism=False)) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/internal/set_pickler.py b/sdks/python/apache_beam/internal/set_pickler.py new file mode 100644 index 000000000000..90b7c646b488 --- /dev/null +++ b/sdks/python/apache_beam/internal/set_pickler.py @@ -0,0 +1,164 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Custom pickling logic for sets to make the serialization semi-deterministic. + +To make set serialization semi-deterministic, we must pick an order for the set +elements. Sets may contain elements of types not defining a comparison "<" +operator. To provide an order, we define our own custom comparison function +which supports elements of near-arbitrary types and use that to sort the +contents of each set during serialization. Attempts at determinism are made on a +best-effort basis to improve hit rates for cached workflows and the ordering +does not define a total order for all values. 
+""" + +import enum +import functools + + +def compare(lhs, rhs): + """Returns -1, 0, or 1 depending on whether lhs <, =, or > rhs.""" + if lhs < rhs: + return -1 + elif lhs > rhs: + return 1 + else: + return 0 + + +def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth): + """Identifies which object goes first in an (almost) total order of objects. + + Args: + lhs: An arbitrary Python object or built-in type. + rhs: An arbitrary Python object or built-in type. + lhs_path: Traversal path from the root lhs object up to, but not including, + lhs. The original contents of lhs_path are restored before the function + returns. + rhs_path: Same as lhs_path except for the rhs. + max_depth: Maximum recursion depth. + + Returns: + -1, 0, or 1 depending on whether lhs or rhs goes first in the total order. + 0 if max_depth is exhausted. + 0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle). + """ + if id(lhs) == id(rhs): + # Fast path + return 0 + if type(lhs) != type(rhs): + return compare(str(type(lhs)), str(type(rhs))) + if type(lhs) in [int, float, bool, str, bool, bytes, bytearray]: + return compare(lhs, rhs) + if isinstance(lhs, enum.Enum): + # Enums can have values with arbitrary types. The names are strings. + return compare(lhs.name, rhs.name) + + # To avoid exceeding the recursion depth limit, set a limit on recursion. + max_depth -= 1 + if max_depth < 0: + return 0 + + # Check for cycles in the traversal path to avoid getting stuck in a loop. + if id(lhs) in lhs_path or id(rhs) in rhs_path: + return 0 + lhs_path.append(id(lhs)) + rhs_path.append(id(rhs)) + # The comparison logic is split across two functions to simplifying updating + # and restoring the traversal paths. 
+ result = _generic_object_comparison_recursive_path( + lhs, rhs, lhs_path, rhs_path, max_depth) + lhs_path.pop() + rhs_path.pop() + return result + + +def _generic_object_comparison_recursive_path( + lhs, rhs, lhs_path, rhs_path, max_depth): + if type(lhs) == tuple or type(lhs) == list: + result = compare(len(lhs), len(rhs)) + if result != 0: + return result + for i in range(len(lhs)): + result = generic_object_comparison( + lhs[i], rhs[i], lhs_path, rhs_path, max_depth) + if result != 0: + return result + return 0 + if type(lhs) == frozenset or type(lhs) == set: + return generic_object_comparison( + tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)), + tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)), + lhs_path, + rhs_path, + max_depth) + if type(lhs) == dict: + lhs_keys = list(lhs.keys()) + rhs_keys = list(rhs.keys()) + result = compare(len(lhs_keys), len(rhs_keys)) + if result != 0: + return result + lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth) + rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth) + for lhs_key, rhs_key in zip(lhs_keys, rhs_keys): + result = generic_object_comparison( + lhs_key, rhs_key, lhs_path, rhs_path, max_depth) + if result != 0: + return result + result = generic_object_comparison( + lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth) + if result != 0: + return result + + lhs_fields = dir(lhs) + rhs_fields = dir(rhs) + result = compare(len(lhs_fields), len(rhs_fields)) + if result != 0: + return result + for i in range(len(lhs_fields)): + result = compare(lhs_fields[i], rhs_fields[i]) + if result != 0: + return result + result = generic_object_comparison( + getattr(lhs, lhs_fields[i], None), + getattr(rhs, rhs_fields[i], None), + lhs_path, + rhs_path, + max_depth) + if result != 0: + return result + return 0 + + +def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4): + def cmp(lhs, rhs): + if lhs_path is None: + # Start the traversal at the root call to 
cmp. + return generic_object_comparison(lhs, rhs, [], [], max_depth) + else: + # Continue the existing traversal path for recursive calls to cmp. + return generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth) + + return sorted(obj, key=functools.cmp_to_key(cmp)) + + +def save_set(pickler, obj): + pickler.save_set(sort_if_possible(obj)) + + +def save_frozenset(pickler, obj): + pickler.save_frozenset(sort_if_possible(obj)) diff --git a/sdks/python/apache_beam/internal/set_pickler_test.py b/sdks/python/apache_beam/internal/set_pickler_test.py new file mode 100644 index 000000000000..9507d056eb95 --- /dev/null +++ b/sdks/python/apache_beam/internal/set_pickler_test.py @@ -0,0 +1,278 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +"""Unit tests for best-effort deterministic pickling of sets.""" + +import enum +import unittest + +from apache_beam.internal.set_pickler import sort_if_possible + + +class A: + def __init__(self, x, y): + self.x = x + self.y = y + + +class B: + pass + + +class C: + pass + + +class SortIfPossibleTest(unittest.TestCase): + def test_order_by_type(self): + a = A(1, 2) + self.assertEqual( + sort_if_possible([123, "abc", True, a]), + # A, bool, int, str + [a, True, 123, "abc"], + ) + + def test_sorts_ints(self): + self.assertEqual(sort_if_possible({5, 2, 3, 1, 4}), [1, 2, 3, 4, 5]) + + def test_sorts_booleans(self): + self.assertEqual(sort_if_possible({False, True}), [False, True]) + + def test_sorts_floats(self): + self.assertEqual(sort_if_possible({-18.0, 3.14, 2.71}), [-18.0, 2.71, 3.14]) + + def test_sorts_strings(self): + self.assertEqual( + sort_if_possible({"when", "in", "the", "course", "of"}), + ["course", "in", "of", "the", "when"], + ) + + def test_sorts_bytes(self): + self.assertEqual( + sort_if_possible({b"when", b"in", b"the", b"course", b"of"}), + [b"course", b"in", b"of", b"the", b"when"], + ) + + def test_sorts_bytearrays(self): + f = bytearray + self.assertEqual( + sort_if_possible( + [f(b"when"), f(b"in"), f(b"the"), f(b"course"), f(b"of")]), + [f(b"course"), f(b"in"), f(b"of"), f(b"the"), f(b"when")], + ) + + def test_sort_tuples_by_length(self): + self.assertEqual( + sort_if_possible({(1, 1, 1), (1, 1), (1, )}), [(1, ), (1, 1), + (1, 1, 1)]) + + def test_sort_tuples_by_element_values(self): + self.assertEqual( + sort_if_possible({(0, 0), (1, 1), (0, 1), (1, 0)}), + [(0, 0), (0, 1), (1, 0), (1, 1)], + ) + + def test_sort_nested_tuples(self): + self.assertEqual( + sort_if_possible({(1, (4, )), (1, (1, )), (1, (3, )), (1, (2, ))}), + [(1, (1, )), (1, (2, )), (1, (3, )), (1, (4, ))], + ) + + def test_sort_lists_by_length(self): + self.assertEqual( + sort_if_possible([[1, 1, 1], [1, 1], [ + 1, + ]]), [[ + 1, + ], [1, 1], [1, 1, 1]]) + + def 
test_sort_lists_by_element_values(self): + self.assertEqual( + sort_if_possible([[0, 0], [1, 1], [0, 1], [1, 0]]), + [[0, 0], [0, 1], [1, 0], [1, 1]], + ) + + def test_sort_frozenset_like_sorted_tuple(self): + self.assertEqual( + sort_if_possible( + {frozenset([1, 2, 3]), frozenset([1]), frozenset([1, 2, 4])}), + [frozenset([1]), frozenset([1, 2, 3]), frozenset([1, 2, 4])], + ) + + def test_sort_set_like_sorted_tuple(self): + self.assertEqual( + sort_if_possible([set([1, 2, 3]), set([1]), set([1, 2, 4])]), + [set([1]), set([1, 2, 3]), set([1, 2, 4])], + ) + + def test_order_objects_by_class_name(self): + a = A(1, 2) + b = B() + c = C() + self.assertEqual(sort_if_possible({b, c, a}), [a, b, c]) + + def test_order_objects_by_number_of_fields(self): + o1 = C() + o2 = C() + setattr(o2, "f1", 1) + o3 = C() + setattr(o3, "f1", 1) + setattr(o3, "f2", 2) + + self.assertEqual(sort_if_possible({o3, o2, o1}), [o1, o2, o3]) + + def test_order_objects_by_field_name(self): + o1 = C() + setattr(o1, "aaa", 1) + o2 = C() + setattr(o2, "bbb", 1) + o3 = C() + setattr(o3, "ccc", 1) + + self.assertEqual(sort_if_possible({o3, o2, o1}), [o1, o2, o3]) + + def test_order_objects_by_field_value(self): + a1_1 = A(1, 1) + a1_2 = A(1, 2) + a2_1 = A(2, 1) + a2_2 = A(2, 2) + + self.assertEqual( + sort_if_possible({a2_1, a1_1, a2_2, a1_2}), [a1_1, a1_2, a2_1, a2_2]) + + def test_cyclic_data(self): + def create_tuple_with_cycles(): + o = C() + t = (o, ) + setattr(t[0], "t", t) + return t + + t1 = create_tuple_with_cycles() + t2 = create_tuple_with_cycles() + t3 = create_tuple_with_cycles() + + actual = {hash(t) for t in sort_if_possible({t1, t2, t3})} + expected = {hash(t1), hash(t2), hash(t3)} + self.assertEqual(actual, expected) + + def test_order_dict_by_length(self): + self.assertEqual( + sort_if_possible([{ + 'a': 1, 'b': 2 + }, { + 'a': 1 + }, { + 'a': 1, 'b': 2, 'c': 3 + }]), [{ + 'a': 1 + }, { + 'a': 1, 'b': 2 + }, { + 'a': 1, 'b': 2, 'c': 3 + }]) + + def test_order_dict_by_key(self): + 
self.assertEqual( + sort_if_possible([{ + 'b': 1 + }, { + 'a': 1 + }, { + 'c': 1 + }]), [{ + 'a': 1 + }, { + 'b': 1 + }, { + 'c': 1 + }]) + + def test_order_dict_by_value(self): + self.assertEqual( + sort_if_possible([{ + 'a': 2 + }, { + 'a': 1 + }, { + 'a': 3 + }]), [{ + 'a': 1 + }, { + 'a': 2 + }, { + 'a': 3 + }]) + + def test_dict_keys_do_not_have_lt(self): + self.assertEqual( + sort_if_possible([{(1, 1): 1}, {(1, ): 1}, {(1, 1, 1): 1}]), + [{(1, ): 1}, {(1, 1): 1}, {(1, 1, 1): 1}]) + + def test_dict_values_do_not_have_lt(self): + self.assertEqual( + sort_if_possible([{ + 'a': (1, 1) + }, { + 'a': (1, ) + }, { + 'a': (1, 1, 1) + }]), [{ + 'a': (1, ) + }, { + 'a': (1, 1) + }, { + 'a': (1, 1, 1) + }]) + + def test_order_enums_by_name(self): + class CardinalDirection(enum.Enum): + NORTH = 1 + EAST = 2 + SOUTH = 3 + WEST = 4 + + self.assertEqual( + sort_if_possible({ + CardinalDirection.NORTH, + CardinalDirection.SOUTH, + CardinalDirection.EAST, + CardinalDirection.WEST, + }), + [ + CardinalDirection.EAST, + CardinalDirection.NORTH, + CardinalDirection.SOUTH, + CardinalDirection.WEST, + ]) + + def test_enum_with_many_values(self): + MyEnum = enum.Enum('MyEnum', ' '.join(f'N{i}' for i in range(10000))) + + self.assertEqual( + sort_if_possible({ + MyEnum.N789, + MyEnum.N123, + MyEnum.N456, + }), [ + MyEnum.N123, + MyEnum.N456, + MyEnum.N789, + ]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdks/python/apache_beam/pipeline.py b/sdks/python/apache_beam/pipeline.py index 61d6b190d04f..963ad51bbb97 100644 --- a/sdks/python/apache_beam/pipeline.py +++ b/sdks/python/apache_beam/pipeline.py @@ -947,6 +947,11 @@ def to_runner_api( raise ValueError( 'Only one of context or default_environment may be specified.') + # The FlumeRunner is the only runner setting this option. Use getattr + # because other runners do not have this option. 
+ context.enable_best_effort_deterministic_pickling = getattr( + self.runner, 'enable_best_effort_deterministic_pickling', False) + # The RunnerAPI spec requires certain transforms and side-inputs to have KV # inputs (and corresponding outputs). # Currently we only upgrade to KV pairs. If there is a need for more diff --git a/sdks/python/apache_beam/runners/pipeline_context.py b/sdks/python/apache_beam/runners/pipeline_context.py index 13ab665c1eb1..132a1aedca33 100644 --- a/sdks/python/apache_beam/runners/pipeline_context.py +++ b/sdks/python/apache_beam/runners/pipeline_context.py @@ -226,6 +226,7 @@ def __init__( self.iterable_state_read = iterable_state_read self.iterable_state_write = iterable_state_write self._requirements = set(requirements) + self.enable_best_effort_deterministic_pickling = False def add_requirement(self, requirement: str) -> None: self._requirements.add(requirement) diff --git a/sdks/python/apache_beam/transforms/ptransform.py b/sdks/python/apache_beam/transforms/ptransform.py index 7d36abf638f8..aeec91cdfc97 100644 --- a/sdks/python/apache_beam/transforms/ptransform.py +++ b/sdks/python/apache_beam/transforms/ptransform.py @@ -778,9 +778,16 @@ def to_runner_api_parameter( python_urns.GENERIC_COMPOSITE_TRANSFORM, getattr(self, '_fn_api_payload', str(self))) - def to_runner_api_pickled(self, unused_context): + def to_runner_api_pickled(self, context): # type: (PipelineContext) -> tuple[str, bytes] - return (python_urns.PICKLED_TRANSFORM, pickler.dumps(self)) + return ( + python_urns.PICKLED_TRANSFORM, + pickler.dumps( + self, + enable_best_effort_determinism=context. + enable_best_effort_deterministic_pickling, + ), + ) def runner_api_requires_keyed_input(self): return False