|
87 | 87 | 'ignore', category=FutureWarning, module='apache_beam.transform.util_test') |
88 | 88 |
|
89 | 89 |
|
| 90 | +class _Unpicklable(object): |
| 91 | + def __init__(self, value): |
| 92 | + self.value = value |
| 93 | + |
| 94 | + def __getstate__(self): |
| 95 | + raise NotImplementedError() |
| 96 | + |
| 97 | + def __setstate__(self, state): |
| 98 | + raise NotImplementedError() |
| 99 | + |
| 100 | + |
| 101 | +class _UnpicklableCoder(beam.coders.Coder): |
| 102 | + def encode(self, value): |
| 103 | + return str(value.value).encode() |
| 104 | + |
| 105 | + def decode(self, encoded): |
| 106 | + return _Unpicklable(int(encoded.decode())) |
| 107 | + |
| 108 | + def to_type_hint(self): |
| 109 | + return _Unpicklable |
| 110 | + |
| 111 | + def is_deterministic(self): |
| 112 | + return True |
| 113 | + |
| 114 | + |
90 | 115 | class CoGroupByKeyTest(unittest.TestCase): |
91 | 116 | def test_co_group_by_key_on_tuple(self): |
92 | 117 | with TestPipeline() as pipeline: |
@@ -186,6 +211,20 @@ def test_co_group_by_key_on_one(self): |
186 | 211 | equal_to(expected), |
187 | 212 | label='AssertOneDict') |
188 | 213 |
|
| 214 | + def test_co_group_by_key_on_unpickled(self): |
| 215 | + beam.coders.registry.register_coder(_Unpicklable, _UnpicklableCoder) |
| 216 | + values = [_Unpicklable(i) for i in range(5)] |
| 217 | + with TestPipeline() as pipeline: |
| 218 | + xs = pipeline | beam.Create(values) | beam.WithKeys(lambda x: x) |
| 219 | + pcoll = ({ |
| 220 | + 'x': xs |
| 221 | + } |
| 222 | + | beam.CoGroupByKey() |
| 223 | + | beam.FlatMapTuple( |
| 224 | + lambda k, tagged: (k.value, tagged['x'][0].value * 2))) |
| 225 | + expected = [0, 0, 1, 2, 2, 4, 3, 6, 4, 8] |
| 226 | + assert_that(pcoll, equal_to(expected)) |
| 227 | + |
189 | 228 |
|
190 | 229 | class FakeClock(object): |
191 | 230 | def __init__(self, now=time.time()): |
@@ -1205,32 +1244,6 @@ def format_with_timestamp(element, timestamp=beam.DoFn.TimestampParam): |
1205 | 1244 | equal_to(expected_data), |
1206 | 1245 | label="formatted_after_reshuffle") |
1207 | 1246 |
|
1208 | | - global _Unpicklable |
1209 | | - global _UnpicklableCoder |
1210 | | - |
1211 | | - class _Unpicklable(object): |
1212 | | - def __init__(self, value): |
1213 | | - self.value = value |
1214 | | - |
1215 | | - def __getstate__(self): |
1216 | | - raise NotImplementedError() |
1217 | | - |
1218 | | - def __setstate__(self, state): |
1219 | | - raise NotImplementedError() |
1220 | | - |
1221 | | - class _UnpicklableCoder(beam.coders.Coder): |
1222 | | - def encode(self, value): |
1223 | | - return str(value.value).encode() |
1224 | | - |
1225 | | - def decode(self, encoded): |
1226 | | - return _Unpicklable(int(encoded.decode())) |
1227 | | - |
1228 | | - def to_type_hint(self): |
1229 | | - return _Unpicklable |
1230 | | - |
1231 | | - def is_deterministic(self): |
1232 | | - return True |
1233 | | - |
1234 | 1247 | def reshuffle_unpicklable_in_global_window_helper( |
1235 | 1248 | self, update_compatibility_version=None): |
1236 | 1249 | with TestPipeline(options=PipelineOptions( |
|
0 commit comments