Skip to content

Commit bb9ab00

Browse files
authored
Fix a bug in cogbk for not using registered coder. (#35862)
* Fix a bug in cogbk for not using registered coder.
* Allow custom tag type in typehint.
1 parent 1a000da commit bb9ab00

File tree

2 files changed

+43
-27
lines changed

2 files changed

+43
-27
lines changed

sdks/python/apache_beam/transforms/util.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,7 @@
106106
K = TypeVar('K')
107107
V = TypeVar('V')
108108
T = TypeVar('T')
109+
U = TypeVar('U')
109110

110111
RESHUFFLE_TYPEHINT_BREAKING_CHANGE_VERSION = "2.64.0"
111112

@@ -266,7 +267,9 @@ def collect_values(key, tagged_values):
266267
]
267268
| Flatten(pipeline=self.pipeline)
268269
| GroupByKey()
269-
| MapTuple(collect_values))
270+
| MapTuple(collect_values).with_input_types(
271+
tuple[K, Iterable[tuple[U, V]]]).with_output_types(
272+
tuple[K, dict[U, list[V]]]))
270273

271274

272275
@ptransform_fn

sdks/python/apache_beam/transforms/util_test.py

Lines changed: 39 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,31 @@
8787
'ignore', category=FutureWarning, module='apache_beam.transform.util_test')
8888

8989

90+
class _Unpicklable(object):
91+
def __init__(self, value):
92+
self.value = value
93+
94+
def __getstate__(self):
95+
raise NotImplementedError()
96+
97+
def __setstate__(self, state):
98+
raise NotImplementedError()
99+
100+
101+
class _UnpicklableCoder(beam.coders.Coder):
102+
def encode(self, value):
103+
return str(value.value).encode()
104+
105+
def decode(self, encoded):
106+
return _Unpicklable(int(encoded.decode()))
107+
108+
def to_type_hint(self):
109+
return _Unpicklable
110+
111+
def is_deterministic(self):
112+
return True
113+
114+
90115
class CoGroupByKeyTest(unittest.TestCase):
91116
def test_co_group_by_key_on_tuple(self):
92117
with TestPipeline() as pipeline:
@@ -186,6 +211,20 @@ def test_co_group_by_key_on_one(self):
186211
equal_to(expected),
187212
label='AssertOneDict')
188213

214+
def test_co_group_by_key_on_unpickled(self):
215+
beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
216+
values = [_Unpicklable(i) for i in range(5)]
217+
with TestPipeline() as pipeline:
218+
xs = pipeline | beam.Create(values) | beam.WithKeys(lambda x: x)
219+
pcoll = ({
220+
'x': xs
221+
}
222+
| beam.CoGroupByKey()
223+
| beam.FlatMapTuple(
224+
lambda k, tagged: (k.value, tagged['x'][0].value * 2)))
225+
expected = [0, 0, 1, 2, 2, 4, 3, 6, 4, 8]
226+
assert_that(pcoll, equal_to(expected))
227+
189228

190229
class FakeClock(object):
191230
def __init__(self, now=time.time()):
@@ -1205,32 +1244,6 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam):
12051244
equal_to(expected_data),
12061245
label="formatted_after_reshuffle")
12071246

1208-
global _Unpicklable
1209-
global _UnpicklableCoder
1210-
1211-
class _Unpicklable(object):
1212-
def __init__(self, value):
1213-
self.value = value
1214-
1215-
def __getstate__(self):
1216-
raise NotImplementedError()
1217-
1218-
def __setstate__(self, state):
1219-
raise NotImplementedError()
1220-
1221-
class _UnpicklableCoder(beam.coders.Coder):
1222-
def encode(self, value):
1223-
return str(value.value).encode()
1224-
1225-
def decode(self, encoded):
1226-
return _Unpicklable(int(encoded.decode()))
1227-
1228-
def to_type_hint(self):
1229-
return _Unpicklable
1230-
1231-
def is_deterministic(self):
1232-
return True
1233-
12341247
def reshuffle_unpicklable_in_global_window_helper(
12351248
self, update_compatibility_version=None):
12361249
with TestPipeline(options=PipelineOptions(

0 commit comments

Comments
 (0)