CHANGES.md (3 additions, 0 deletions)
@@ -66,9 +66,12 @@

## New Features / Improvements

* Support custom coders in Reshuffle ([#29908](https://github.com/apache/beam/issues/29908), [#33356](https://github.com/apache/beam/issues/33356)).

* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).

## Breaking Changes
* [Python] Reshuffle now correctly respects user-specified type hints, fixing a bug where it could wrongly fall back to FastPrimitivesCoder. This change may break pipelines whose Reshuffle type hints are incorrect. If you see issues after upgrading, temporarily set update_compatibility_version to a previous Beam version to keep the old behavior; the recommended fix is to correct the type hints in your code. ([#33932](https://github.com/apache/beam/pull/33932))

* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).

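To pin the old Reshuffle behavior while type hints are being fixed, the compatibility version can be passed through pipeline options. A minimal sketch using the keyword form that this PR's own tests use (the equivalent `--update_compatibility_version` command-line flag should also work); the pipeline body is a placeholder:

```python
import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions

# Keep the pre-2.64.0 Reshuffle type-hint behavior until the hints are fixed.
options = PipelineOptions(update_compatibility_version='2.63.0')

with beam.Pipeline(options=options) as p:
  _ = (
      p
      | beam.Create([('key', 1), ('key', 2)])
      | beam.Reshuffle()
      | beam.Map(print))
```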
sdks/python/apache_beam/coders/coders.py (11 additions, 0 deletions)
@@ -1445,6 +1445,17 @@ def __hash__(self):
return hash(
(self.wrapped_value_coder, self.timestamp_coder, self.window_coder))

@classmethod
def from_type_hint(cls, typehint, registry):
# type: (Any, CoderRegistry) -> WindowedValueCoder
# Ideally this'd take two parameters so that one could hint at
# the window type as well instead of falling back to the
# pickle coders.
return cls(registry.get_coder(typehint.inner_type))

def to_type_hint(self):
return typehints.WindowedValue[self.wrapped_value_coder.to_type_hint()]


Coder.register_structured_urn(
common_urns.coders.WINDOWED_VALUE.urn, WindowedValueCoder)
sdks/python/apache_beam/coders/coders_test.py (6 additions, 0 deletions)
@@ -276,6 +276,12 @@ def test_numpy_int(self):
_ = indata | "CombinePerKey" >> beam.CombinePerKey(sum)


class WindowedValueCoderTest(unittest.TestCase):
def test_to_type_hint(self):
coder = coders.WindowedValueCoder(coders.VarIntCoder())
self.assertEqual(coder.to_type_hint(), typehints.WindowedValue[int]) # type: ignore[misc]


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()
sdks/python/apache_beam/coders/typecoders.py (2 additions, 0 deletions)
@@ -94,6 +94,8 @@ def register_standard_coders(self, fallback_coder):
self._register_coder_internal(str, coders.StrUtf8Coder)
self._register_coder_internal(typehints.TupleConstraint, coders.TupleCoder)
self._register_coder_internal(typehints.DictConstraint, coders.MapCoder)
self._register_coder_internal(
typehints.WindowedTypeConstraint, coders.WindowedValueCoder)
# Default fallback coders are applied in this order until the first matching
# coder is found.
default_fallback_coders = [
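As a rough sketch of what this registration enables (behavior inferred from the entries above; the concrete coder classes come from Beam's standard registrations), a `WindowedValue` type hint now resolves to a `WindowedValueCoder` wrapping the inner type's coder instead of falling through to the fallback coders:

```python
from apache_beam.coders import coders
from apache_beam.coders.typecoders import registry
from apache_beam.typehints import typehints

# WindowedValue[int] is now handled by WindowedValueCoder.from_type_hint,
# which wraps the registered coder for the inner type (VarIntCoder for int).
coder = registry.get_coder(typehints.WindowedValue[int])
assert isinstance(coder, coders.WindowedValueCoder)
assert isinstance(coder.wrapped_value_coder, coders.VarIntCoder)

# to_type_hint round-trips back to the parameterized hint.
assert coder.to_type_hint() == typehints.WindowedValue[int]
```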
sdks/python/apache_beam/transforms/util.py (50 additions, 2 deletions)
@@ -32,6 +32,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING
from typing import Any
from typing import Optional
from typing import TypeVar
from typing import Union

@@ -40,6 +41,7 @@
from apache_beam import pvalue
from apache_beam import typehints
from apache_beam.metrics import Metrics
from apache_beam.options import pipeline_options
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsSideInput
@@ -71,11 +73,13 @@
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import trivial_inference
from apache_beam.typehints.decorators import get_signature
from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.sharded_key_type import ShardedKeyType
from apache_beam.utils import shared
from apache_beam.utils import windowed_value
from apache_beam.utils.annotations import deprecated
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import Timestamp

if TYPE_CHECKING:
from apache_beam.runners.pipeline_context import PipelineContext
@@ -102,6 +106,8 @@
V = TypeVar('V')
T = TypeVar('T')

RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION = "2.64.0"


class CoGroupByKey(PTransform):
"""Groups results across several PCollections by key.
@@ -922,6 +928,27 @@ def get_window_coder(self):
return self._window_coder


def is_compat_version_prior_to(options, breaking_change_version):
# This function is used in a branch statement to determine whether we should
# keep the old behavior prior to a breaking change or use the new behavior.
# - If update_compatibility_version < breaking_change_version, we will return
# True and keep the old behavior.
# - If update_compatibility_version is None or >= breaking_change_version, we
# will return False and use the behavior from the breaking change.
update_compatibility_version = options.view_as(
pipeline_options.StreamingOptions).update_compatibility_version

if update_compatibility_version is None:
return False

compat_version = tuple(map(int, update_compatibility_version.split('.')[0:3]))
change_version = tuple(map(int, breaking_change_version.split('.')[0:3]))
# Compare lexicographically over the shared prefix so that, for example,
# "2.63" and "2.63.0" compare as equal rather than as prior.
num_parts = min(len(compat_version), len(change_version))
return compat_version[:num_parts] < change_version[:num_parts]


@typehints.with_input_types(tuple[K, V])
@typehints.with_output_types(tuple[K, V])
class ReshufflePerKey(PTransform):
@@ -951,6 +978,14 @@ def restore_timestamps(element):
window.GlobalWindows.windowed_value((key, value), timestamp)
for (value, timestamp) in values
]

if is_compat_version_prior_to(pcoll.pipeline.options,
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
pre_gbk_map = Map(reify_timestamps).with_output_types(Any)
else:
pre_gbk_map = Map(reify_timestamps).with_input_types(
tuple[K, V]).with_output_types(
tuple[K, tuple[V, Optional[Timestamp]]])
else:

# typing: All conditional function variants must have identical signatures
@@ -964,7 +999,14 @@ def restore_timestamps(element):
key, windowed_values = element
return [wv.with_value((key, wv.value)) for wv in windowed_values]

ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
if is_compat_version_prior_to(pcoll.pipeline.options,
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
pre_gbk_map = Map(reify_timestamps).with_output_types(Any)
else:
pre_gbk_map = Map(reify_timestamps).with_input_types(
tuple[K, V]).with_output_types(tuple[K, TypedWindowedValue[V]])

ungrouped = pcoll | pre_gbk_map

# TODO(https://github.com/apache/beam/issues/19785) Using global window as
# one of the standard window. This is to mitigate the Dataflow Java Runner
@@ -1012,11 +1054,17 @@ def __init__(self, num_buckets=None):

def expand(self, pcoll):
# type: (pvalue.PValue) -> pvalue.PCollection
if is_compat_version_prior_to(pcoll.pipeline.options,
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION):
reshuffle_step = ReshufflePerKey()
else:
reshuffle_step = ReshufflePerKey().with_input_types(
tuple[int, T]).with_output_types(tuple[int, T])
return (
pcoll | 'AddRandomKeys' >>
Map(lambda t: (random.randrange(0, self.num_buckets), t)
).with_input_types(T).with_output_types(tuple[int, T])
| ReshufflePerKey()
| reshuffle_step
| 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
tuple[int, T]).with_output_types(T))

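A small sketch of how the `is_compat_version_prior_to` gate added above behaves with the `2.64.0` breaking-change version (option handling mirrors the tests in this PR):

```python
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms.util import is_compat_version_prior_to

# No compatibility version set: use the new typed Reshuffle behavior.
assert not is_compat_version_prior_to(PipelineOptions(), '2.64.0')

# Pinned to an earlier release: keep the old Any-typed behavior.
old_opts = PipelineOptions(update_compatibility_version='2.63.0')
assert is_compat_version_prior_to(old_opts, '2.64.0')

# Pinned at or after the breaking change: new behavior.
new_opts = PipelineOptions(update_compatibility_version='2.64.0')
assert not is_compat_version_prior_to(new_opts, '2.64.0')
```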
sdks/python/apache_beam/transforms/util_test.py (76 additions, 0 deletions)
@@ -1010,6 +1010,82 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
equal_to(expected_data),
label="formatted_after_reshuffle")

global _Unpicklable
global _UnpicklableCoder

class _Unpicklable(object):
def __init__(self, value):
self.value = value

def __getstate__(self):
raise NotImplementedError()

def __setstate__(self, state):
raise NotImplementedError()

class _UnpicklableCoder(beam.coders.Coder):
def encode(self, value):
return str(value.value).encode()

def decode(self, encoded):
return _Unpicklable(int(encoded.decode()))

def to_type_hint(self):
return _Unpicklable

def is_deterministic(self):
return True

def reshuffle_unpicklable_in_global_window_helper(
self, update_compatibility_version=None):
with TestPipeline(options=PipelineOptions(
update_compatibility_version=update_compatibility_version)) as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 10, 20, 30, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(GlobalWindows())
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))

def test_reshuffle_unpicklable_in_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

self.reshuffle_unpicklable_in_global_window_helper()
# An exception is raised when running reshuffle on unpicklable objects
# prior to 2.64.0
self.assertRaises(
RuntimeError,
self.reshuffle_unpicklable_in_global_window_helper,
"2.63.0")

def reshuffle_unpicklable_in_non_global_window_helper(
self, update_compatibility_version=None):
with TestPipeline(options=PipelineOptions(
update_compatibility_version=update_compatibility_version)) as pipeline:
data = [_Unpicklable(i) for i in range(5)]
expected_data = [0, 0, 0, 10, 10, 10, 20, 20, 20, 30, 30, 30, 40, 40, 40]
result = (
pipeline
| beam.Create(data)
| beam.WindowInto(window.SlidingWindows(size=3, period=1))
| beam.Reshuffle()
| beam.Map(lambda u: u.value * 10))
assert_that(result, equal_to(expected_data))

def test_reshuffle_unpicklable_in_non_global_window(self):
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)

self.reshuffle_unpicklable_in_non_global_window_helper()
# An exception is raised when running reshuffle on unpicklable objects
# prior to 2.64.0
self.assertRaises(
RuntimeError,
self.reshuffle_unpicklable_in_non_global_window_helper,
"2.63.0")


class WithKeysTest(unittest.TestCase):
def setUp(self):
sdks/python/apache_beam/typehints/native_type_compatibility.py (20 additions, 0 deletions)
@@ -25,9 +25,13 @@
import sys
import types
import typing
from typing import Generic
from typing import TypeVar

from apache_beam.typehints import typehints

T = TypeVar('T')

_LOGGER = logging.getLogger(__name__)

# Describes an entry in the type map in convert_to_beam_type.
@@ -277,6 +281,18 @@ def is_builtin(typ):
return getattr(typ, '__origin__', None) in _BUILTINS


# During type inference of WindowedValue we need to pass in the inner value
# type. This cannot be done directly with the WindowedValue class because it
# is not parameterized. Making it a generic class (e.g. WindowedValue[T])
# could work in theory, but the class is cythonized and Cython does not seem
# to handle generic classes well.
# The workaround is a separate class used solely for type inference; it must
# never be used to create instances.
class TypedWindowedValue(Generic[T]):
def __init__(self, *args, **kwargs):
raise NotImplementedError("This class is solely for type inference")


def convert_to_beam_type(typ):
"""Convert a given typing type to a Beam type.

@@ -385,6 +401,10 @@ def convert_to_beam_type(typ):
match=_match_is_exactly_collection,
arity=1,
beam_type=typehints.Collection),
_TypeMapEntry(
match=_match_issubclass(TypedWindowedValue),
arity=1,
beam_type=typehints.WindowedValue),
]

# Find the first matching entry.
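To make the intent of `TypedWindowedValue` concrete, here is a minimal sketch (based on the type-map entry added above) showing that the typing-only stand-in converts to the parameterized Beam constraint:

```python
from apache_beam.typehints import typehints
from apache_beam.typehints.native_type_compatibility import TypedWindowedValue
from apache_beam.typehints.native_type_compatibility import convert_to_beam_type

# The typing-only stand-in maps onto the Beam WindowedValue constraint.
beam_type = convert_to_beam_type(TypedWindowedValue[bytes])
assert beam_type == typehints.WindowedValue[bytes]
```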
sdks/python/apache_beam/typehints/typehints.py (9 additions, 0 deletions)
@@ -1213,6 +1213,15 @@ def type_check(self, instance):
repr(self.inner_type),
instance.value.__class__.__name__))

def bind_type_variables(self, bindings):
bound_inner_type = bind_type_variables(self.inner_type, bindings)
if bound_inner_type == self.inner_type:
return self
return WindowedValue[bound_inner_type]

def __repr__(self):
return 'WindowedValue[%s]' % repr(self.inner_type)


class GeneratorHint(IteratorHint):
"""A Generator type hint.
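A short sketch of the new `bind_type_variables` behavior on `WindowedValue` hints (names taken from Beam's `typehints` module):

```python
from apache_beam.typehints import typehints

T = typehints.TypeVariable('T')

# Binding the inner TypeVariable yields a concretely parameterized hint.
generic_hint = typehints.WindowedValue[T]
assert generic_hint.bind_type_variables({T: int}) == typehints.WindowedValue[int]

# A hint whose inner type is already concrete is returned unchanged.
concrete_hint = typehints.WindowedValue[int]
assert concrete_hint.bind_type_variables({T: int}) is concrete_hint
```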