Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sdks/python/apache_beam/internal/cloudpickle_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,16 @@ def _pickle_enum_descriptor(obj):
return _reconstruct_enum_descriptor, (full_name, )


def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:
"""For internal use only; no backwards-compatibility guarantees."""
if enable_best_effort_determinism:
# TODO: Add support once https://github.com/cloudpipe/cloudpickle/pull/563
# is merged in.
raise NotImplementedError('This option has only been implemented for dill')
with _pickle_lock:
with io.BytesIO() as file:
pickler = cloudpickle.CloudPickler(file)
Expand Down
4 changes: 4 additions & 0 deletions sdks/python/apache_beam/internal/cloudpickle_pickler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,10 @@ def test_dataclass(self):
self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
''')

def test_best_effort_determinism_not_implemented(self):
  """The cloudpickle backend rejects enable_best_effort_determinism.

  The option raises NotImplementedError until upstream cloudpickle support
  lands (see the TODO in cloudpickle_pickler.dumps).
  """
  with self.assertRaises(NotImplementedError):
    dumps(123, enable_best_effort_determinism=True)


if __name__ == '__main__':
unittest.main()
10 changes: 9 additions & 1 deletion sdks/python/apache_beam/internal/dill_pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from typing import Any
from typing import Dict
from typing import Tuple
from apache_beam.internal.set_pickler import save_frozenset, save_set

import dill

Expand Down Expand Up @@ -376,9 +377,16 @@ def new_log_info(msg, *args, **kwargs):
logging.getLogger('dill').setLevel(logging.WARN)


def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:
"""For internal use only; no backwards-compatibility guarantees."""
with _pickle_lock:
if enable_best_effort_determinism:
dill.dill.pickle(set, save_set)
dill.dill.pickle(frozenset, save_frozenset)
try:
s = dill.dumps(o, byref=settings['dill_byref'])
except Exception: # pylint: disable=broad-except
Expand Down
11 changes: 9 additions & 2 deletions sdks/python/apache_beam/internal/pickler.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,17 @@
desired_pickle_lib = dill_pickler


def dumps(o, enable_trace=True, use_zlib=False) -> bytes:
def dumps(
o,
enable_trace=True,
use_zlib=False,
enable_best_effort_determinism=False) -> bytes:

return desired_pickle_lib.dumps(
o, enable_trace=enable_trace, use_zlib=use_zlib)
o,
enable_trace=enable_trace,
use_zlib=use_zlib,
enable_best_effort_determinism=enable_best_effort_determinism)


def loads(encoded, enable_trace=True, use_zlib=False):
Expand Down
5 changes: 5 additions & 0 deletions sdks/python/apache_beam/internal/pickler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,11 @@ def test_dataclass(self):
self.assertEqual(DataClass(datum='abc'), loads(dumps(DataClass(datum='abc'))))
''')

def test_best_effort_determinism(self):
  """Equal sets built in different orders serialize to identical bytes."""
  self.assertEqual(
      dumps({'a', 'b', 'c'}, enable_best_effort_determinism=True),
      dumps({'c', 'b', 'a'}, enable_best_effort_determinism=True))


if __name__ == '__main__':
unittest.main()
164 changes: 164 additions & 0 deletions sdks/python/apache_beam/internal/set_pickler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Custom pickling logic for sets to make the serialization semi-deterministic.

To make set serialization semi-deterministic, we must pick an order for the set
elements. Sets may contain elements of types not defining a comparison "<"
operator. To provide an order, we define our own custom comparison function
which supports elements of near-arbitrary types and use that to sort the
contents of each set during serialization. Attempts at determinism are made on a
best-effort basis to improve hit rates for cached workflows and the ordering
does not define a total order for all values.
"""

import enum
import functools


def compare(lhs, rhs):
  """Three-way comparison: returns -1 if lhs < rhs, 1 if lhs > rhs, else 0."""
  # Same trick as the classic replacement for Python 2's cmp(): each bool
  # coerces to 0/1, so the subtraction yields exactly -1, 0, or 1.
  return (lhs > rhs) - (lhs < rhs)


def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth):
  """Identifies which object goes first in an (almost) total order of objects.

  Args:
    lhs: An arbitrary Python object or built-in type.
    rhs: An arbitrary Python object or built-in type.
    lhs_path: Traversal path from the root lhs object up to, but not including,
      lhs. The original contents of lhs_path are restored before the function
      returns.
    rhs_path: Same as lhs_path except for the rhs.
    max_depth: Maximum recursion depth.

  Returns:
    -1, 0, or 1 depending on whether lhs or rhs goes first in the total order.
    0 if max_depth is exhausted.
    0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle).
  """
  if lhs is rhs:
    # Fast path: an object is always equal to itself.
    return 0
  if type(lhs) != type(rhs):
    # Differently-typed values are ordered by their type's string name.
    return compare(str(type(lhs)), str(type(rhs)))
  # Primitive types order naturally. (The original list repeated `bool`;
  # the duplicate was redundant and has been removed.)
  if type(lhs) in (int, float, bool, str, bytes, bytearray):
    return compare(lhs, rhs)
  if isinstance(lhs, enum.Enum):
    # Enums can have values with arbitrary types. The names are strings.
    return compare(lhs.name, rhs.name)

  # To avoid exceeding the recursion depth limit, set a limit on recursion.
  max_depth -= 1
  if max_depth < 0:
    return 0

  # Check for cycles in the traversal path to avoid getting stuck in a loop.
  if id(lhs) in lhs_path or id(rhs) in rhs_path:
    return 0
  lhs_path.append(id(lhs))
  rhs_path.append(id(rhs))
  # The comparison logic is split across two functions to simplify updating
  # and restoring the traversal paths.
  result = _generic_object_comparison_recursive_path(
      lhs, rhs, lhs_path, rhs_path, max_depth)
  lhs_path.pop()
  rhs_path.pop()
  return result


def _generic_object_comparison_recursive_path(
    lhs, rhs, lhs_path, rhs_path, max_depth):
  """Structure-wise comparison of two same-typed objects.

  Helper for generic_object_comparison, which owns updating and restoring
  lhs_path/rhs_path around this call. Assumes type(lhs) == type(rhs).
  """
  if type(lhs) == tuple or type(lhs) == list:
    # Sequences: shorter first, then element-wise.
    result = compare(len(lhs), len(rhs))
    if result != 0:
      return result
    for i in range(len(lhs)):
      result = generic_object_comparison(
          lhs[i], rhs[i], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    return 0
  if type(lhs) == frozenset or type(lhs) == set:
    # Sets are unordered; sort both sides first so equal sets compare equal.
    return generic_object_comparison(
        tuple(sort_if_possible(lhs, lhs_path, rhs_path, max_depth)),
        tuple(sort_if_possible(rhs, lhs_path, rhs_path, max_depth)),
        lhs_path,
        rhs_path,
        max_depth)
  if type(lhs) == dict:
    # Dicts: fewer keys first, then compare sorted keys and their values
    # pairwise.
    lhs_keys = list(lhs.keys())
    rhs_keys = list(rhs.keys())
    result = compare(len(lhs_keys), len(rhs_keys))
    if result != 0:
      return result
    lhs_keys = sort_if_possible(lhs_keys, lhs_path, rhs_path, max_depth)
    rhs_keys = sort_if_possible(rhs_keys, lhs_path, rhs_path, max_depth)
    for lhs_key, rhs_key in zip(lhs_keys, rhs_keys):
      result = generic_object_comparison(
          lhs_key, rhs_key, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
      result = generic_object_comparison(
          lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    # Fix: the original fell through to the attribute comparison below when
    # the dicts compared equal. Equal dicts are equal; stop here, mirroring
    # the explicit `return 0` of the tuple/list branch.
    return 0

  # Fallback for arbitrary objects: order by attribute names (dir() returns
  # them sorted), then by the attribute values themselves.
  lhs_fields = dir(lhs)
  rhs_fields = dir(rhs)
  result = compare(len(lhs_fields), len(rhs_fields))
  if result != 0:
    return result
  for i in range(len(lhs_fields)):
    result = compare(lhs_fields[i], rhs_fields[i])
    if result != 0:
      return result
    result = generic_object_comparison(
        getattr(lhs, lhs_fields[i], None),
        getattr(rhs, rhs_fields[i], None),
        lhs_path,
        rhs_path,
        max_depth)
    if result != 0:
      return result
  return 0


def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4):
  """Returns the elements of obj sorted by the best-effort generic order.

  When lhs_path/rhs_path are None, each comparison starts a fresh traversal;
  otherwise the supplied paths are threaded through so cycle detection keeps
  working inside recursive sorts.
  """
  if lhs_path is None:

    def _order(a, b):
      # Root call: every comparison begins with empty traversal paths.
      return generic_object_comparison(a, b, [], [], max_depth)
  else:

    def _order(a, b):
      # Recursive call: continue the caller's existing traversal paths.
      return generic_object_comparison(a, b, lhs_path, rhs_path, max_depth)

  return sorted(obj, key=functools.cmp_to_key(_order))


def save_set(pickler, obj):
  """Pickles a set with its elements in a semi-deterministic order."""
  ordered_elements = sort_if_possible(obj)
  pickler.save_set(ordered_elements)


def save_frozenset(pickler, obj):
  """Pickles a frozenset with its elements in a semi-deterministic order."""
  ordered_elements = sort_if_possible(obj)
  pickler.save_frozenset(ordered_elements)
Loading
Loading