|
19 | 19 |
|
20 | 20 | # pytype: skip-file |
21 | 21 |
|
| 22 | +import os |
22 | 23 | import threading |
23 | 24 | import types |
24 | 25 | import unittest |
25 | 26 |
|
26 | 27 | from apache_beam.coders import proto2_coder_test_messages_pb2 |
| 28 | +from apache_beam.internal import cloudpickle_pickler as beam_cloudpickle |
| 29 | +from apache_beam.internal import code_object_pickler |
27 | 30 | from apache_beam.internal import module_test |
28 | 31 | from apache_beam.internal.cloudpickle_pickler import dumps |
29 | 32 | from apache_beam.internal.cloudpickle_pickler import loads |
@@ -220,6 +223,36 @@ def test_best_effort_determinism_not_implemented(self): |
220 | 223 | 'Ignoring unsupported option: enable_best_effort_determinism', |
221 | 224 | '\n'.join(l.output)) |
222 | 225 |
|
| 226 | + @unittest.mock.patch.object( |
| 227 | + code_object_pickler, |
| 228 | + 'get_normalized_path', |
| 229 | + wraps=code_object_pickler.get_normalized_path) |
| 230 | + def test_default_config_interceptor(self, mock_get_normalized_path): |
| 231 | + """Tests config.filepath_interceptor is called for CodeType pickling.""" |
| 232 | + |
| 233 | + def sample_func(): |
| 234 | + return "Beam" |
| 235 | + |
| 236 | + code_obj = sample_func.__code__ |
| 237 | + original_filename = os.path.abspath(code_obj.co_filename) |
| 238 | + |
| 239 | + try: |
| 240 | + pickled_code = beam_cloudpickle.dumps(code_obj) |
| 241 | + unpickled_code = beam_cloudpickle.loads(pickled_code) |
| 242 | + |
| 243 | + mock_get_normalized_path.assert_called() |
| 244 | + |
| 245 | + unpickled_filename = os.path.abspath(unpickled_code.co_filename) |
| 246 | + self.assertEqual(unpickled_filename, original_filename) |
| 247 | + |
| 248 | + except AttributeError as e: |
| 249 | + if 'get_code_object_params' in str(e): |
| 250 | + self.fail( |
| 251 | + "Vendored cloudpickle BUG: AttributeError 'get_code_object_params' " |
| 252 | + f"raised during CodeType pickling. Error: {e}") |
| 253 | + else: |
| 254 | + raise |
| 255 | + |
223 | 256 |
|
224 | 257 | if __name__ == '__main__': |
225 | 258 | unittest.main() |
0 commit comments