From 277fa2c98bf796a81a93c006e30ca4727fa3bbe6 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 13 Aug 2025 23:13:59 -0400 Subject: [PATCH 1/2] Fix a bug in cogbk for not using registered coder. --- sdks/python/apache_beam/transforms/util.py | 4 +- .../apache_beam/transforms/util_test.py | 65 +++++++++++-------- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 5e1e5b06fcbd..c47d75e34d4d 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -266,7 +266,9 @@ def collect_values(key, tagged_values): ] | Flatten(pipeline=self.pipeline) | GroupByKey() - | MapTuple(collect_values)) + | MapTuple(collect_values).with_input_types( + tuple[K, Iterable[tuple[str, V]]]).with_output_types( + tuple[K, dict[str, list[V]]])) @ptransform_fn diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index d08c7a860210..ad185ac6a6d1 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -87,6 +87,31 @@ 'ignore', category=FutureWarning, module='apache_beam.transform.util_test') +class _Unpicklable(object): + def __init__(self, value): + self.value = value + + def __getstate__(self): + raise NotImplementedError() + + def __setstate__(self, state): + raise NotImplementedError() + + +class _UnpicklableCoder(beam.coders.Coder): + def encode(self, value): + return str(value.value).encode() + + def decode(self, encoded): + return _Unpicklable(int(encoded.decode())) + + def to_type_hint(self): + return _Unpicklable + + def is_deterministic(self): + return True + + class CoGroupByKeyTest(unittest.TestCase): def test_co_group_by_key_on_tuple(self): with TestPipeline() as pipeline: @@ -186,6 +211,20 @@ def test_co_group_by_key_on_one(self): equal_to(expected), label='AssertOneDict') + def test_co_group_by_key_on_unpickled(self): + beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) + values = [_Unpicklable(i) for i in range(5)] + with TestPipeline() as pipeline: + xs = pipeline | beam.Create(values) | beam.WithKeys(lambda x: x) + pcoll = ({ + 'x': xs + } + | beam.CoGroupByKey() + | beam.FlatMapTuple( + lambda k, tagged: (k.value, tagged['x'][0].value * 2))) + expected = [0, 0, 1, 2, 2, 4, 3, 6, 4, 8] + assert_that(pcoll, equal_to(expected)) + class FakeClock(object): def __init__(self, now=time.time()): @@ -1205,32 +1244,6 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): equal_to(expected_data), label="formatted_after_reshuffle") - global _Unpicklable - global _UnpicklableCoder - - class _Unpicklable(object): - def __init__(self, value): - self.value = value - - def __getstate__(self): - raise NotImplementedError() - - def __setstate__(self, state): - raise NotImplementedError() - - class _UnpicklableCoder(beam.coders.Coder): - def encode(self, value): - return str(value.value).encode() - - def decode(self, encoded): - return _Unpicklable(int(encoded.decode())) - - def to_type_hint(self): - return _Unpicklable - - def is_deterministic(self): - return True - def reshuffle_unpicklable_in_global_window_helper( self, update_compatibility_version=None): with TestPipeline(options=PipelineOptions( From e7316161ff11df3b4b71ff7c4af88712f8989f18 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Wed, 13 Aug 2025 23:27:32 -0400 Subject: [PATCH 2/2] Allow custom tag type in typehint. --- sdks/python/apache_beam/transforms/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index c47d75e34d4d..c60ded52df26 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -106,6 +106,7 @@ K = TypeVar('K') V = TypeVar('V') T = TypeVar('T') +U = TypeVar('U') RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION = "2.64.0" @@ -267,8 +268,8 @@ def collect_values(key, tagged_values): | Flatten(pipeline=self.pipeline) | GroupByKey() | MapTuple(collect_values).with_input_types( - tuple[K, Iterable[tuple[str, V]]]).with_output_types( - tuple[K, dict[str, list[V]]])) + tuple[K, Iterable[tuple[U, V]]]).with_output_types( + tuple[K, dict[U, list[V]]])) @ptransform_fn