@classmethod
def from_type_hint(cls, typehint, registry):
  # type: (Any, CoderRegistry) -> WindowedValueCoder

  """Builds a coder for a windowed-value type hint.

  Only the hint's ``inner_type`` is consulted: its coder is looked up in
  ``registry`` and passed to ``cls`` as the wrapped value coder.

  Ideally this would take two parameters so that the window type could be
  hinted at as well, instead of falling back to the pickle coders.
  """
  value_coder = registry.get_coder(typehint.inner_type)
  return cls(value_coder)


def to_type_hint(self):
  """Returns ``WindowedValue[...]`` around the wrapped value coder's hint."""
  value_hint = self.wrapped_value_coder.to_type_hint()
  return typehints.WindowedValue[value_hint]
def _version_prefix(version):
  """Returns the first three dot-separated components of ``version`` as ints.

  Missing components are padded with zeros, so ``"2.64"`` is read as
  ``(2, 64, 0)``. Raises ``ValueError`` if one of the first three components
  is not an integer.
  """
  parts = (version.split('.') + ['0', '0', '0'])[:3]
  return tuple(map(int, parts))


def is_compat_version_prior_to(options, breaking_change_version):
  """Tells whether the pre-breaking-change behavior should be kept.

  This function is used in a branch statement to determine whether we should
  keep the old behavior prior to a breaking change or use the new behavior.

  Args:
    options: pipeline options; ``update_compatibility_version`` is read from
      its ``StreamingOptions`` view.
    breaking_change_version: version string (e.g. ``"2.64.0"``) in which the
      breaking change was introduced.

  Returns:
    True if ``update_compatibility_version`` is set and is strictly lower than
    ``breaking_change_version`` (keep the old behavior). False if it is None
    or greater than or equal to ``breaking_change_version`` (use the behavior
    from the breaking change).

  Raises:
    ValueError: if one of the first three components of either version string
      is not an integer.
  """
  update_compatibility_version = options.view_as(
      pipeline_options.StreamingOptions).update_compatibility_version

  if update_compatibility_version is None:
    return False

  # Compare the versions as whole tuples (lexicographically). Checking each
  # component independently would report e.g. 3.0.0 as prior to 2.64.0 merely
  # because 0 < 64 in the minor position.
  return (
      _version_prefix(update_compatibility_version) <
      _version_prefix(breaking_change_version))
def expand(self, pcoll):
  # type: (pvalue.PValue) -> pvalue.PCollection

  """Pairs each element with a random int key, reshuffles per key, unkeys.

  When the pipeline's update compatibility version is prior to
  RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION, ReshufflePerKey is applied
  without explicit type hints; otherwise it is hinted as
  ``tuple[int, T] -> tuple[int, T]``.
  """
  keyed_hint = tuple[int, T]

  keep_old_behavior = is_compat_version_prior_to(
      pcoll.pipeline.options, RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION)
  reshuffle_step = ReshufflePerKey()
  if not keep_old_behavior:
    reshuffle_step = reshuffle_step.with_input_types(
        keyed_hint).with_output_types(keyed_hint)

  add_random_keys = Map(
      lambda t: (random.randrange(0, self.num_buckets), t)).with_input_types(
          T).with_output_types(keyed_hint)
  remove_random_keys = Map(lambda t: t[1]).with_input_types(
      keyed_hint).with_output_types(T)

  return (
      pcoll
      | 'AddRandomKeys' >> add_random_keys
      | reshuffle_step
      | 'RemoveRandomKeys' >> remove_random_keys)
def reshuffle_unpicklable_in_non_global_window_helper(
    self, update_compatibility_version=None):
  """Reshuffles _Unpicklable elements under sliding windows and checks output.

  Args:
    update_compatibility_version: forwarded to PipelineOptions; None leaves
      the option unset.
  """
  options = PipelineOptions(
      update_compatibility_version=update_compatibility_version)
  with TestPipeline(options=options) as pipeline:
    # The expected output lists each value * 10 three times.
    expected_data = [i * 10 for i in range(5) for _ in range(3)]
    result = (
        pipeline
        | beam.Create([_Unpicklable(i) for i in range(5)])
        | beam.WindowInto(window.SlidingWindows(size=3, period=1))
        | beam.Reshuffle()
        | beam.Map(lambda u: u.value * 10))
    assert_that(result, equal_to(expected_data))


def test_reshuffle_unpicklable_in_non_global_window(self):
  beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

  self.reshuffle_unpicklable_in_non_global_window_helper()
  # An exception is raised when running reshuffle on unpicklable objects
  # prior to 2.64.0
  with self.assertRaises(RuntimeError):
    self.reshuffle_unpicklable_in_non_global_window_helper("2.63.0")
class TypedWindowedValue(Generic[T]):
  """Parameterized stand-in for WindowedValue, for type inference only.

  During type inference of WindowedValue, the inner value type needs to be
  passed in. This cannot be achieved directly with the WindowedValue class
  because it is not parameterized. Changing it to a generic class (e.g.
  WindowedValue[T]) could work in theory; however, the class is cythonized and
  it seems that cython does not handle generic classes well.

  The workaround is this separate class, which exists solely for type
  inference and must never be instantiated.
  """
  def __init__(self, *args, **kwargs):
    message = "This class is solely for type inference"
    raise NotImplementedError(message)
def bind_type_variables(self, bindings):
  """Returns this hint with type variables in the inner type bound.

  ``self`` is returned unchanged when binding leaves the inner type equal to
  the current one; otherwise a new ``WindowedValue[...]`` hint is built
  around the bound inner type.
  """
  resolved_inner = bind_type_variables(self.inner_type, bindings)
  return (
      self if resolved_inner == self.inner_type else
      WindowedValue[resolved_inner])


def __repr__(self):
  inner_repr = repr(self.inner_type)
  return 'WindowedValue[%s]' % inner_repr