|  | 
| 12 | 12 | import sys | 
| 13 | 13 | import os | 
| 14 | 14 | import gc | 
|  | 15 | +import importlib | 
| 15 | 16 | import errno | 
| 16 | 17 | import functools | 
| 17 | 18 | import signal | 
|  | 
| 20 | 21 | import socket | 
| 21 | 22 | import random | 
| 22 | 23 | import logging | 
|  | 24 | +import shutil | 
| 23 | 25 | import subprocess | 
| 24 | 26 | import struct | 
|  | 27 | +import tempfile | 
| 25 | 28 | import operator | 
| 26 | 29 | import pickle | 
| 27 | 30 | import weakref | 
| @@ -6343,6 +6346,81 @@ def test_atexit(self): | 
| 6343 | 6346 |                 self.assertEqual(f.read(), 'deadbeef') | 
| 6344 | 6347 | 
 | 
| 6345 | 6348 | 
 | 
|  | 6349 | +class _TestSpawnedSysPath(BaseTestCase): | 
|  | 6350 | +    """Test that sys.path is setup in forkserver and spawn processes.""" | 
|  | 6351 | + | 
|  | 6352 | +    ALLOWED_TYPES = ('processes',) | 
|  | 6353 | + | 
|  | 6354 | +    def setUp(self): | 
|  | 6355 | +        self._orig_sys_path = list(sys.path) | 
|  | 6356 | +        self._temp_dir = tempfile.mkdtemp(prefix="test_sys_path-") | 
|  | 6357 | +        self._mod_name = "unique_test_mod" | 
|  | 6358 | +        module_path = os.path.join(self._temp_dir, f"{self._mod_name}.py") | 
|  | 6359 | +        with open(module_path, "w", encoding="utf-8") as mod: | 
|  | 6360 | +            mod.write("# A simple test module\n") | 
|  | 6361 | +        sys.path[:] = [p for p in sys.path if p]  # remove any existing ""s | 
|  | 6362 | +        sys.path.insert(0, self._temp_dir) | 
|  | 6363 | +        sys.path.insert(0, "")  # Replaced with an abspath in child. | 
|  | 6364 | +        try: | 
|  | 6365 | +            self._ctx_forkserver = multiprocessing.get_context("forkserver") | 
|  | 6366 | +        except ValueError: | 
|  | 6367 | +            self._ctx_forkserver = None | 
|  | 6368 | +        self._ctx_spawn = multiprocessing.get_context("spawn") | 
|  | 6369 | + | 
|  | 6370 | +    def tearDown(self): | 
|  | 6371 | +        sys.path[:] = self._orig_sys_path | 
|  | 6372 | +        shutil.rmtree(self._temp_dir, ignore_errors=True) | 
|  | 6373 | + | 
|  | 6374 | +    @staticmethod | 
|  | 6375 | +    def enq_imported_module_names(queue): | 
|  | 6376 | +        queue.put(tuple(sys.modules)) | 
|  | 6377 | + | 
|  | 6378 | +    def test_forkserver_preload_imports_sys_path(self): | 
|  | 6379 | +        ctx = self._ctx_forkserver | 
|  | 6380 | +        if not ctx: | 
|  | 6381 | +            self.skipTest("requires forkserver start method.") | 
|  | 6382 | +        self.assertNotIn(self._mod_name, sys.modules) | 
|  | 6383 | +        multiprocessing.forkserver._forkserver._stop()  # Must be fresh. | 
|  | 6384 | +        ctx.set_forkserver_preload( | 
|  | 6385 | +            ["test.test_multiprocessing_forkserver", self._mod_name]) | 
|  | 6386 | +        q = ctx.Queue() | 
|  | 6387 | +        proc = ctx.Process(target=self.enq_imported_module_names, args=(q,)) | 
|  | 6388 | +        proc.start() | 
|  | 6389 | +        proc.join() | 
|  | 6390 | +        child_imported_modules = q.get() | 
|  | 6391 | +        q.close() | 
|  | 6392 | +        self.assertIn(self._mod_name, child_imported_modules) | 
|  | 6393 | + | 
|  | 6394 | +    @staticmethod | 
|  | 6395 | +    def enq_sys_path_and_import(queue, mod_name): | 
|  | 6396 | +        queue.put(sys.path) | 
|  | 6397 | +        try: | 
|  | 6398 | +            importlib.import_module(mod_name) | 
|  | 6399 | +        except ImportError as exc: | 
|  | 6400 | +            queue.put(exc) | 
|  | 6401 | +        else: | 
|  | 6402 | +            queue.put(None) | 
|  | 6403 | + | 
|  | 6404 | +    def test_child_sys_path(self): | 
|  | 6405 | +        for ctx in (self._ctx_spawn, self._ctx_forkserver): | 
|  | 6406 | +            if not ctx: | 
|  | 6407 | +                continue | 
|  | 6408 | +            with self.subTest(f"{ctx.get_start_method()} start method"): | 
|  | 6409 | +                q = ctx.Queue() | 
|  | 6410 | +                proc = ctx.Process(target=self.enq_sys_path_and_import, | 
|  | 6411 | +                                   args=(q, self._mod_name)) | 
|  | 6412 | +                proc.start() | 
|  | 6413 | +                proc.join() | 
|  | 6414 | +                child_sys_path = q.get() | 
|  | 6415 | +                import_error = q.get() | 
|  | 6416 | +                q.close() | 
|  | 6417 | +                self.assertNotIn("", child_sys_path)  # replaced by an abspath | 
|  | 6418 | +                self.assertIn(self._temp_dir, child_sys_path)  # our addition | 
|  | 6419 | +                # ignore the first element, it is the absolute "" replacement | 
|  | 6420 | +                self.assertEqual(child_sys_path[1:], sys.path[1:]) | 
|  | 6421 | +                self.assertIsNone(import_error, msg=f"child could not import {self._mod_name}") | 
|  | 6422 | + | 
|  | 6423 | + | 
| 6346 | 6424 | class MiscTestCase(unittest.TestCase): | 
| 6347 | 6425 |     def test__all__(self): | 
| 6348 | 6426 |         # Just make sure names in not_exported are excluded | 
|  | 
0 commit comments