|
20 | 20 | # pytype: skip-file |
21 | 21 |
|
22 | 22 | import os |
| 23 | +import tempfile |
23 | 24 | import threading |
24 | 25 | import types |
25 | 26 | import unittest |
|
31 | 32 | from apache_beam.internal import module_test |
32 | 33 | from apache_beam.internal.cloudpickle_pickler import dumps |
33 | 34 | from apache_beam.internal.cloudpickle_pickler import loads |
| 35 | +from apache_beam.typehints.schemas import LogicalTypeRegistry |
34 | 36 | from apache_beam.utils import shared |
35 | 37 |
|
36 | 38 | GLOBAL_DICT_REF = module_test.GLOBAL_DICT |
@@ -244,6 +246,24 @@ def sample_func(): |
244 | 246 | unpickled_filename = os.path.abspath(unpickled_code.co_filename) |
245 | 247 | self.assertEqual(unpickled_filename, original_filename) |
246 | 248 |
|
| 249 | + @mock.patch( |
| 250 | + "apache_beam.coders.typecoders.registry.load_custom_type_coder_tuples") |
| 251 | + @mock.patch( |
| 252 | + "apache_beam.typehints.schemas.LogicalType._known_logical_types.load") |
| 253 | + def test_dump_load_session(self, logicaltype_mock, coder_mock): |
| 254 | + session_file = 'pickled' |
| 255 | + |
| 256 | + with tempfile.TemporaryDirectory() as tmp_dirname: |
| 257 | + pickled_session_file = os.path.join(tmp_dirname, session_file) |
| 258 | + beam_cloudpickle.dump_session(pickled_session_file) |
| 259 | + beam_cloudpickle.load_session(pickled_session_file) |
| 260 | + load_logical_types = logicaltype_mock.call_args.args |
| 261 | + load_coders = coder_mock.call_args.args |
| 262 | + self.assertEqual(len(load_logical_types), 1) |
| 263 | + self.assertEqual(len(load_coders), 1) |
| 264 | + self.assertTrue(isinstance(load_logical_types[0], LogicalTypeRegistry)) |
| 265 | + self.assertTrue(isinstance(load_coders[0], list)) |
| 266 | + |
247 | 267 |
|
248 | 268 | if __name__ == '__main__': |
249 | 269 | unittest.main() |
0 commit comments