forked from apache/beam
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathset_pickler.py
More file actions
164 lines (143 loc) · 5.62 KB
/
set_pickler.py
File metadata and controls
164 lines (143 loc) · 5.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Custom pickling logic for sets to make the serialization semi-deterministic.
To make set serialization semi-deterministic, we must pick an order for the set
elements. Sets may contain elements of types not defining a comparison "<"
operator. To provide an order, we define our own custom comparison function
which supports elements of near-arbitrary types and use that to sort the
contents of each set during serialization. Attempts at determinism are made on a
best-effort basis to improve hit rates for cached workflows and the ordering
does not define a total order for all values.
"""
import enum
import functools
def compare(lhs, rhs):
  """Three-way comparison built on the `<` and `>` operators.

  Returns:
    -1 if lhs < rhs, 1 if lhs > rhs, and 0 when neither holds (equal values,
    or unordered values such as NaN).
  """
  if lhs < rhs:
    return -1
  return 1 if lhs > rhs else 0
def generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth):
  """Identifies which object goes first in an (almost) total order of objects.

  Objects of different types are ordered by the string form of their types.
  Built-in scalars (int, float, bool, str, bytes, bytearray) use `<` / `>`.
  Enum members are ordered by name. Anything else is compared structurally
  by _generic_object_comparison_recursive_path.

  Args:
    lhs: An arbitrary Python object or built-in type.
    rhs: An arbitrary Python object or built-in type.
    lhs_path: List of ids of the objects on the traversal path from the root
      lhs object up to, but not including, lhs. The original contents of
      lhs_path are restored before the function returns, including when the
      comparison raises.
    rhs_path: Same as lhs_path except for the rhs.
    max_depth: Maximum recursion depth.

  Returns:
    -1, 0, or 1 depending on whether lhs or rhs goes first in the total order.
    0 if max_depth is exhausted.
    0 if lhs is in lhs_path or rhs is in rhs_path (there is a cycle).
  """
  if lhs is rhs:
    # Fast path: an object always compares equal to itself.
    return 0
  if type(lhs) is not type(rhs):
    return compare(str(type(lhs)), str(type(rhs)))
  if type(lhs) in (int, float, bool, str, bytes, bytearray):
    return compare(lhs, rhs)
  if isinstance(lhs, enum.Enum):
    # Enums can have values with arbitrary types. The names are strings.
    return compare(lhs.name, rhs.name)

  # To avoid exceeding the recursion depth limit, set a limit on recursion.
  max_depth -= 1
  if max_depth < 0:
    return 0

  # Check for cycles in the traversal path to avoid getting stuck in a loop.
  if id(lhs) in lhs_path or id(rhs) in rhs_path:
    return 0

  # The comparison logic is split across two functions to simplify updating
  # and restoring the traversal paths.
  lhs_path.append(id(lhs))
  rhs_path.append(id(rhs))
  try:
    return _generic_object_comparison_recursive_path(
        lhs, rhs, lhs_path, rhs_path, max_depth)
  finally:
    # Pop even when the comparison raises, so a caller that recovers from the
    # error still sees exactly the path it passed in.
    lhs_path.pop()
    rhs_path.pop()
def _generic_object_comparison_recursive_path(
    lhs, rhs, lhs_path, rhs_path, max_depth):
  """Compares the contents of two same-typed, non-primitive objects.

  generic_object_comparison calls this after pushing id(lhs) and id(rhs) onto
  lhs_path and rhs_path, so both paths already end with the objects being
  compared here.

  Args:
    lhs: Object on the left-hand side; has the same type as rhs.
    rhs: Object on the right-hand side.
    lhs_path: Ids of the objects on the traversal path to lhs, including lhs.
    rhs_path: Ids of the objects on the traversal path to rhs, including rhs.
    max_depth: Remaining recursion budget, passed to nested comparisons.

  Returns:
    -1, 0, or 1 depending on whether lhs or rhs goes first.
  """
  if type(lhs) in (tuple, list):
    # Shorter sequences go first; equal-length sequences compare element-wise.
    result = compare(len(lhs), len(rhs))
    if result != 0:
      return result
    for lhs_item, rhs_item in zip(lhs, rhs):
      result = generic_object_comparison(
          lhs_item, rhs_item, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    return 0

  if type(lhs) in (frozenset, set):
    # Sets compare as tuples of their sorted elements.
    return generic_object_comparison(
        tuple(_sort_members(lhs, lhs_path, max_depth)),
        tuple(_sort_members(rhs, rhs_path, max_depth)),
        lhs_path,
        rhs_path,
        max_depth)

  if type(lhs) is dict:
    result = compare(len(lhs), len(rhs))
    if result != 0:
      return result
    lhs_keys = _sort_members(lhs, lhs_path, max_depth)
    rhs_keys = _sort_members(rhs, rhs_path, max_depth)
    for lhs_key, rhs_key in zip(lhs_keys, rhs_keys):
      result = generic_object_comparison(
          lhs_key, rhs_key, lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
      result = generic_object_comparison(
          lhs[lhs_key], rhs[rhs_key], lhs_path, rhs_path, max_depth)
      if result != 0:
        return result
    # All entries matched. Falling through to the attribute walk below would
    # only inspect the (identical) methods of two plain dicts, at high cost.
    return 0

  # Generic objects: compare attribute names first, then attribute values.
  lhs_fields = dir(lhs)
  rhs_fields = dir(rhs)
  result = compare(len(lhs_fields), len(rhs_fields))
  if result != 0:
    return result
  for lhs_field, rhs_field in zip(lhs_fields, rhs_fields):
    result = compare(lhs_field, rhs_field)
    if result != 0:
      return result
    result = generic_object_comparison(
        _safe_getattr(lhs, lhs_field),
        _safe_getattr(rhs, rhs_field),
        lhs_path,
        rhs_path,
        max_depth)
    if result != 0:
      return result
  return 0


def _sort_members(container, path, max_depth):
  """Sorts the members (elements or keys) of the container at the end of path.

  Every pairwise comparison made while sorting involves two members of the
  same container, so each side of the comparison gets its own copy of that
  container's path for cycle detection.
  """
  return sort_if_possible(container, list(path), list(path), max_depth)


def _safe_getattr(obj, name):
  """Returns getattr(obj, name), or None if the attribute lookup raises."""
  try:
    return getattr(obj, name)
  except Exception:  # pylint: disable=broad-except
    # getattr's default only covers AttributeError, but properties and
    # __getattr__ hooks may raise anything. Ordering is best-effort, so treat
    # such attributes as missing rather than failing the whole comparison.
    return None
def sort_if_possible(obj, lhs_path=None, rhs_path=None, max_depth=4):
  """Returns the elements of obj as a list ordered by generic_object_comparison.

  Args:
    obj: An iterable of arbitrary objects, e.g. a set, frozenset or dict.
    lhs_path: List of object ids on the traversal path when this is called
      from inside an in-progress comparison; None for a top-level call.
    rhs_path: Same as lhs_path, for the right-hand side of each comparison.
      Only used when lhs_path is not None.
    max_depth: Maximum recursion depth for each element comparison.

  Returns:
    A new list containing the elements of obj. For a top-level call
    (lhs_path is None), if ordering the elements raises, the elements are
    returned in their original iteration order instead.
  """
  def cmp(lhs, rhs):
    if lhs_path is None:
      # Start the traversal at the root call to cmp.
      return generic_object_comparison(lhs, rhs, [], [], max_depth)
    # Continue the existing traversal path for recursive calls to cmp.
    return generic_object_comparison(lhs, rhs, lhs_path, rhs_path, max_depth)

  # Materialize first so the fallback below still has every element even if
  # obj is a one-shot iterator.
  items = list(obj)
  if lhs_path is not None:
    # Nested call from an in-progress comparison: let errors propagate to the
    # top-level call, which decides how to recover.
    return sorted(items, key=functools.cmp_to_key(cmp))
  try:
    return sorted(items, key=functools.cmp_to_key(cmp))
  except Exception:  # pylint: disable=broad-except
    # Determinism is best-effort: an unorderable element must not make the
    # caller (e.g. set pickling) fail, so keep the original iteration order.
    return items
def save_set(pickler, obj):
  """Passes the elements of set `obj`, ordered by sort_if_possible, to
  `pickler.save_set`."""
  ordered_elements = sort_if_possible(obj)
  pickler.save_set(ordered_elements)
def save_frozenset(pickler, obj):
  """Passes the elements of frozenset `obj`, ordered by sort_if_possible, to
  `pickler.save_frozenset`."""
  ordered_elements = sort_if_possible(obj)
  pickler.save_frozenset(ordered_elements)