Skip to content

Commit 48743e5

Browse files
authored
Support custom coders in Reshuffle (#33932)
* Revert "Revert three commits related to supporting custom coder in reshuffle" This reverts commit 4cbf257. * Use update_compatibility_version flag to determine whether to use new typehint behavior. * Highlight changes in CHANGES.md * Minor refactoring based on feedback. * Fix lints. * Simplify the documented breaking changes. * Fix typo.
1 parent 251c63d commit 48743e5

File tree

8 files changed

+177
-2
lines changed

8 files changed

+177
-2
lines changed

CHANGES.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,12 @@
6666

6767
## New Features / Improvements
6868

69+
* Support custom coders in Reshuffle ([#29908](https://github.com/apache/beam/issues/29908), [#33356](https://github.com/apache/beam/issues/33356)).
70+
6971
* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
7072

7173
## Breaking Changes
74+
* [Python] Reshuffle now correctly respects user-specified type hints, fixing a bug where it could incorrectly fall back to FastPrimitivesCoder. This change can break pipelines that use incorrect type hints with Reshuffle. If you run into issues after upgrading, temporarily set `update_compatibility_version` to a previous Beam version to restore the old behavior. The recommended solution is to fix the type hints in your code. ([#33932](https://github.com/apache/beam/pull/33932))
7275

7376
* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).
7477

sdks/python/apache_beam/coders/coders.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,6 +1445,17 @@ def __hash__(self):
14451445
return hash(
14461446
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))
14471447

1448+
@classmethod
def from_type_hint(cls, typehint, registry):
  # type: (Any, CoderRegistry) -> WindowedValueCoder
  """Build a coder for a windowed-value type hint.

  The value coder is looked up in ``registry`` from ``typehint.inner_type``
  and passed as the only argument to the class constructor.
  """
  # Ideally this'd take two parameters so that one could hint at
  # the window type as well instead of falling back to the
  # pickle coders.
  value_coder = registry.get_coder(typehint.inner_type)
  return cls(value_coder)
1455+
1456+
def to_type_hint(self):
  """Return the WindowedValue type hint wrapping the value coder's hint."""
  value_hint = self.wrapped_value_coder.to_type_hint()
  return typehints.WindowedValue[value_hint]
1458+
14481459

14491460
Coder.register_structured_urn(
14501461
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)

sdks/python/apache_beam/coders/coders_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,12 @@ def test_numpy_int(self):
276276
_ = indata | "CombinePerKey" >> beam.CombinePerKey(sum)
277277

278278

279+
class WindowedValueCoderTest(unittest.TestCase):
  """Checks the type hint reported by WindowedValueCoder."""
  def test_to_type_hint(self):
    # A windowed coder over VarIntCoder should report WindowedValue[int].
    windowed_int_coder = coders.WindowedValueCoder(coders.VarIntCoder())
    actual_hint = windowed_int_coder.to_type_hint()
    expected_hint = typehints.WindowedValue[int]  # type: ignore[misc]
    self.assertEqual(actual_hint, expected_hint)
283+
284+
279285
if __name__ == '__main__':
280286
logging.getLogger().setLevel(logging.INFO)
281287
unittest.main()

sdks/python/apache_beam/coders/typecoders.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder):
9494
self._register_coder_internal(str, coders.StrUtf8Coder)
9595
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
9696
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
97+
self._register_coder_internal(
98+
typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
9799
# Default fallback coders applied in that order until the first matching
98100
# coder found.
99101
default_fallback_coders = [

sdks/python/apache_beam/transforms/util.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from collections.abc import Iterable
3333
from typing import TYPE_CHECKING
3434
from typing import Any
35+
from typing import Optional
3536
from typing import TypeVar
3637
from typing import Union
3738

@@ -40,6 +41,7 @@
4041
from apache_beam import pvalue
4142
from apache_beam import typehints
4243
from apache_beam.metrics import Metrics
44+
from apache_beam.options import pipeline_options
4345
from apache_beam.portability import common_urns
4446
from apache_beam.portability.api import beam_runner_api_pb2
4547
from apache_beam.pvalue import AsSideInput
@@ -71,11 +73,13 @@
7173
from apache_beam.transforms.window import TimestampedValue
7274
from apache_beam.typehints import trivial_inference
7375
from apache_beam.typehints.decorators import get_signature
76+
from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
7477
from apache_beam.typehints.sharded_key_type import ShardedKeyType
7578
from apache_beam.utils import shared
7679
from apache_beam.utils import windowed_value
7780
from apache_beam.utils.annotations import deprecated
7881
from apache_beam.utils.sharded_key import ShardedKey
82+
from apache_beam.utils.timestamp import Timestamp
7983

8084
if TYPE_CHECKING:
8185
from apache_beam.runners.pipeline_context import PipelineContext
@@ -102,6 +106,8 @@
102106
V = TypeVar('V')
103107
T = TypeVar('T')
104108

109+
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION = "2.64.0"
110+
105111

106112
class CoGroupByKey(PTransform):
107113
"""Groups results across several PCollections by key.
@@ -922,6 +928,27 @@ def get_window_coder(self):
922928
return self._window_coder
923929

924930

931+
def is_compat_version_prior_to(options, breaking_change_version):
  """Tell whether the pre-breaking-change behavior should be kept.

  This function is used in a branch statement to determine whether we should
  keep the old behavior prior to a breaking change or use the new behavior.

  - If update_compatibility_version < breaking_change_version, return True
    and keep the old behavior.
  - If update_compatibility_version is None or >= breaking_change_version,
    return False and use the behavior from the breaking change.

  Args:
    options: pipeline options; ``update_compatibility_version`` is read from
      its ``StreamingOptions`` view.
    breaking_change_version: dotted version string (e.g. "2.64.0") in which
      the breaking change was introduced.

  Returns:
    True if the requested compatibility version is older than
    ``breaking_change_version``, False otherwise.

  Raises:
    ValueError: if one of the first three dot-separated components of either
      version string is not an integer.
  """
  update_compatibility_version = options.view_as(
      pipeline_options.StreamingOptions).update_compatibility_version

  if update_compatibility_version is None:
    return False

  compat_version = tuple(map(int, update_compatibility_version.split('.')[0:3]))
  change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))

  # Only the components present in both versions are compared, so "2.64" is
  # not considered prior to "2.64.0".
  common = min(len(compat_version), len(change_version))

  # Tuples compare lexicographically: the first differing component decides.
  # Checking each component independently would report "3.0.0" as prior to
  # "2.64.0" because 0 < 64, even though the major version is already newer.
  return compat_version[:common] < change_version[:common]
950+
951+
925952
@typehints.with_input_types(tuple[K, V])
926953
@typehints.with_output_types(tuple[K, V])
927954
class ReshufflePerKey(PTransform):
@@ -951,6 +978,14 @@ def restore_timestamps(element):
951978
window.GlobalWindows.windowed_value((key, value), timestamp)
952979
for (value, timestamp) in values
953980
]
981+
982+
if is_compat_version_prior_to(pcoll.pipeline.options,
983+
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
984+
pre_gbk_map = Map(reify_timestamps).with_output_types(Any)
985+
else:
986+
pre_gbk_map = Map(reify_timestamps).with_input_types(
987+
tuple[K, V]).with_output_types(
988+
tuple[K, tuple[V, Optional[Timestamp]]])
954989
else:
955990

956991
# typing: All conditional function variants must have identical signatures
@@ -964,7 +999,14 @@ def restore_timestamps(element):
964999
key, windowed_values = element
9651000
return [wv.with_value((key, wv.value)) for wv in windowed_values]
9661001

967-
ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
1002+
if is_compat_version_prior_to(pcoll.pipeline.options,
1003+
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
1004+
pre_gbk_map = Map(reify_timestamps).with_output_types(Any)
1005+
else:
1006+
pre_gbk_map = Map(reify_timestamps).with_input_types(
1007+
tuple[K, V]).with_output_types(tuple[K, TypedWindowedValue[V]])
1008+
1009+
ungrouped = pcoll | pre_gbk_map
9681010

9691011
# TODO(https://github.com/apache/beam/issues/19785) Using global window as
9701012
# one of the standard window. This is to mitigate the Dataflow Java Runner
@@ -1012,11 +1054,17 @@ def __init__(self, num_buckets=None):
10121054

10131055
def expand(self, pcoll):
  # type: (pvalue.PValue) -> pvalue.PCollection
  """Pair each element with a random key, reshuffle per key, drop the key.

  The key is drawn with ``random.randrange(0, self.num_buckets)``. Unless the
  pipeline's compatibility version is prior to
  RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION, the per-key reshuffle is
  annotated with ``tuple[int, T]`` input and output type hints.
  """
  keyed_hint = tuple[int, T]

  keep_legacy_typehints = is_compat_version_prior_to(
      pcoll.pipeline.options, RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION)
  reshuffle_step = ReshufflePerKey()
  if not keep_legacy_typehints:
    reshuffle_step = reshuffle_step.with_input_types(
        keyed_hint).with_output_types(keyed_hint)

  add_random_keys = Map(
      lambda t: (random.randrange(0, self.num_buckets), t)).with_input_types(
          T).with_output_types(keyed_hint)
  remove_random_keys = Map(lambda t: t[1]).with_input_types(
      keyed_hint).with_output_types(T)

  return (
      pcoll
      | 'AddRandomKeys' >> add_random_keys
      | reshuffle_step
      | 'RemoveRandomKeys' >> remove_random_keys)
10221070

sdks/python/apache_beam/transforms/util_test.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,82 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
10101010
equal_to(expected_data),
10111011
label="formatted_after_reshuffle")
10121012

1013+
global _Unpicklable
1014+
global _UnpicklableCoder
1015+
1016+
class _Unpicklable(object):
  """Value holder whose pickling hooks always raise.

  Both ``__getstate__`` and ``__setstate__`` raise NotImplementedError.
  """
  def __init__(self, value):
    self.value = value

  def __getstate__(self):
    # Hook consulted when the object's state is requested for pickling.
    raise NotImplementedError()

  def __setstate__(self, state):
    # Hook consulted when state would be restored on unpickling.
    raise NotImplementedError()
1025+
1026+
class _UnpicklableCoder(beam.coders.Coder):
  """Coder that round-trips _Unpicklable through its value's text form."""
  def encode(self, value):
    # Encode the wrapped value as the bytes of its string form.
    text = str(value.value)
    return text.encode()

  def decode(self, encoded):
    # Parse the text back into an int and wrap it again.
    text = encoded.decode()
    return _Unpicklable(int(text))

  def to_type_hint(self):
    return _Unpicklable

  def is_deterministic(self):
    return True
1038+
1039+
def reshuffle_unpicklable_in_global_window_helper(
    self, update_compatibility_version=None):
  """Run Reshuffle over _Unpicklable elements in the global window.

  The pipeline is built with the given ``update_compatibility_version``
  option and asserts that each element's value times ten comes out.
  """
  options = PipelineOptions(
      update_compatibility_version=update_compatibility_version)
  with TestPipeline(options=options) as pipeline:
    inputs = [_Unpicklable(i) for i in range(5)]
    scaled = (
        pipeline
        | beam.Create(inputs)
        | beam.WindowInto(GlobalWindows())
        | beam.Reshuffle()
        | beam.Map(lambda u: u.value * 10))
    assert_that(scaled, equal_to([0, 10, 20, 30, 40]))
1052+
1053+
def test_reshuffle_unpicklable_in_global_window(self):
  """Reshuffle works with a custom coder by default, not in 2.63.0 mode."""
  beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

  # Default options: the pipeline is expected to run cleanly.
  self.reshuffle_unpicklable_in_global_window_helper()

  # An exception is raised when running reshuffle on unpicklable objects
  # prior to 2.64.0
  with self.assertRaises(RuntimeError):
    self.reshuffle_unpicklable_in_global_window_helper("2.63.0")
1063+
1064+
def reshuffle_unpicklable_in_non_global_window_helper(
    self, update_compatibility_version=None):
  """Run Reshuffle over _Unpicklable elements in sliding windows.

  The pipeline is built with the given ``update_compatibility_version``
  option, windows with ``SlidingWindows(size=3, period=1)`` and asserts that
  each element's value times ten comes out three times.
  """
  options = PipelineOptions(
      update_compatibility_version=update_compatibility_version)
  with TestPipeline(options=options) as pipeline:
    inputs = [_Unpicklable(i) for i in range(5)]
    # Each of the five values is expected three times.
    expected = [i * 10 for i in range(5) for _ in range(3)]
    scaled = (
        pipeline
        | beam.Create(inputs)
        | beam.WindowInto(window.SlidingWindows(size=3, period=1))
        | beam.Reshuffle()
        | beam.Map(lambda u: u.value * 10))
    assert_that(scaled, equal_to(expected))
1077+
1078+
def test_reshuffle_unpicklable_in_non_global_window(self):
  """Reshuffle works with a custom coder by default, not in 2.63.0 mode."""
  beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

  # Default options: the pipeline is expected to run cleanly.
  self.reshuffle_unpicklable_in_non_global_window_helper()

  # An exception is raised when running reshuffle on unpicklable objects
  # prior to 2.64.0
  with self.assertRaises(RuntimeError):
    self.reshuffle_unpicklable_in_non_global_window_helper("2.63.0")
1088+
10131089

10141090
class WithKeysTest(unittest.TestCase):
10151091
def setUp(self):

sdks/python/apache_beam/typehints/native_type_compatibility.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,13 @@
2525
import sys
2626
import types
2727
import typing
28+
from typing import Generic
29+
from typing import TypeVar
2830

2931
from apache_beam.typehints import typehints
3032

33+
T = TypeVar('T')
34+
3135
_LOGGER = logging.getLogger(__name__)
3236

3337
# Describes an entry in the type map in convert_to_beam_type.
@@ -277,6 +281,18 @@ def is_builtin(typ):
277281
return getattr(typ, '__origin__', None) in _BUILTINS
278282

279283

284+
# During type inference of WindowedValue, we need to pass in the inner value
285+
# type. This cannot be achieved immediately with WindowedValue class because it
286+
# is not parameterized. Changing it to a generic class (e.g. WindowedValue[T])
287+
# could work in theory. However, the class is cythonized and it seems that
288+
# cython does not handle generic classes well.
289+
# The workaround here is to create a separate class solely for the type
290+
# inference purpose. This class should never be used for creating instances.
291+
class TypedWindowedValue(Generic[T]):
  """Parameterized placeholder used only in type hints.

  Instances must never be created; the constructor always raises.
  """
  def __init__(self, *args, **kwargs):
    raise NotImplementedError("This class is solely for type inference")
294+
295+
280296
def convert_to_beam_type(typ):
281297
"""Convert a given typing type to a Beam type.
282298
@@ -385,6 +401,10 @@ def convert_to_beam_type(typ):
385401
match=_match_is_exactly_collection,
386402
arity=1,
387403
beam_type=typehints.Collection),
404+
_TypeMapEntry(
405+
match=_match_issubclass(TypedWindowedValue),
406+
arity=1,
407+
beam_type=typehints.WindowedValue),
388408
]
389409

390410
# Find the first matching entry.

sdks/python/apache_beam/typehints/typehints.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,15 @@ def type_check(self, instance):
12131213
repr(self.inner_type),
12141214
instance.value.__class__.__name__))
12151215

1216+
def bind_type_variables(self, bindings):
  """Return this hint with type variables in the inner type bound.

  ``self`` is returned unchanged when binding leaves the inner type equal to
  the current one; otherwise a new ``WindowedValue`` hint is built.
  """
  resolved_inner = bind_type_variables(self.inner_type, bindings)
  return (
      self if resolved_inner == self.inner_type else
      WindowedValue[resolved_inner])
1221+
1222+
def __repr__(self):
  """Render the hint as ``WindowedValue[<repr of inner type>]``."""
  return f'WindowedValue[{self.inner_type!r}]'
1224+
12161225

12171226
class GeneratorHint(IteratorHint):
12181227
"""A Generator type hint.

0 commit comments

Comments
 (0)