@@ -258,6 +258,9 @@ def __call__(self, *args, **kwds):
258258class BaseTestCase (object ):
259259
260260 ALLOWED_TYPES = ('processes' , 'manager' , 'threads' )
261+ # If not empty, limit which start method suites run this class.
262+ START_METHODS : set [str ] = set ()
263+ start_method = None # set by install_tests_in_module_dict()
261264
262265 def assertTimingAlmostEqual (self , a , b ):
263266 if CHECK_TIMINGS :
@@ -6403,7 +6406,9 @@ def test_atexit(self):
64036406class _TestSpawnedSysPath (BaseTestCase ):
64046407 """Test that sys.path is setup in forkserver and spawn processes."""
64056408
6406- ALLOWED_TYPES = ('processes' ,)
6409+ ALLOWED_TYPES = {'processes' }
6410+ # Not applicable to fork which inherits everything from the process as is.
6411+ START_METHODS = {"forkserver" , "spawn" }
64076412
64086413 def setUp (self ):
64096414 self ._orig_sys_path = list (sys .path )
@@ -6415,11 +6420,8 @@ def setUp(self):
64156420 sys .path [:] = [p for p in sys .path if p ] # remove any existing ""s
64166421 sys .path .insert (0 , self ._temp_dir )
64176422 sys .path .insert (0 , "" ) # Replaced with an abspath in child.
6418- try :
6419- self ._ctx_forkserver = multiprocessing .get_context ("forkserver" )
6420- except ValueError :
6421- self ._ctx_forkserver = None
6422- self ._ctx_spawn = multiprocessing .get_context ("spawn" )
6423+ self .assertIn (self .start_method , self .START_METHODS )
6424+ self ._ctx = multiprocessing .get_context (self .start_method )
64236425
64246426 def tearDown (self ):
64256427 sys .path [:] = self ._orig_sys_path
@@ -6430,15 +6432,15 @@ def enq_imported_module_names(queue):
64306432 queue .put (tuple (sys .modules ))
64316433
64326434 def test_forkserver_preload_imports_sys_path (self ):
6433- ctx = self ._ctx_forkserver
6434- if not ctx :
6435- self .skipTest ("requires forkserver start method." )
6435+ if self ._ctx .get_start_method () != "forkserver" :
6436+ self .skipTest ("forkserver specific test." )
64366437 self .assertNotIn (self ._mod_name , sys .modules )
64376438 multiprocessing .forkserver ._forkserver ._stop () # Must be fresh.
6438- ctx .set_forkserver_preload (
6439+ self . _ctx .set_forkserver_preload (
64396440 ["test.test_multiprocessing_forkserver" , self ._mod_name ])
6440- q = ctx .Queue ()
6441- proc = ctx .Process (target = self .enq_imported_module_names , args = (q ,))
6441+ q = self ._ctx .Queue ()
6442+ proc = self ._ctx .Process (
6443+ target = self .enq_imported_module_names , args = (q ,))
64426444 proc .start ()
64436445 proc .join ()
64446446 child_imported_modules = q .get ()
@@ -6456,23 +6458,19 @@ def enq_sys_path_and_import(queue, mod_name):
64566458 queue .put (None )
64576459
64586460 def test_child_sys_path (self ):
6459- for ctx in (self ._ctx_spawn , self ._ctx_forkserver ):
6460- if not ctx :
6461- continue
6462- with self .subTest (f"{ ctx .get_start_method ()} start method" ):
6463- q = ctx .Queue ()
6464- proc = ctx .Process (target = self .enq_sys_path_and_import ,
6465- args = (q , self ._mod_name ))
6466- proc .start ()
6467- proc .join ()
6468- child_sys_path = q .get ()
6469- import_error = q .get ()
6470- q .close ()
6471- self .assertNotIn ("" , child_sys_path ) # replaced by an abspath
6472- self .assertIn (self ._temp_dir , child_sys_path ) # our addition
6473- # ignore the first element, it is the absolute "" replacement
6474- self .assertEqual (child_sys_path [1 :], sys .path [1 :])
6475- self .assertIsNone (import_error , msg = f"child could not import { self ._mod_name } " )
6461+ q = self ._ctx .Queue ()
6462+ proc = self ._ctx .Process (
6463+ target = self .enq_sys_path_and_import , args = (q , self ._mod_name ))
6464+ proc .start ()
6465+ proc .join ()
6466+ child_sys_path = q .get ()
6467+ import_error = q .get ()
6468+ q .close ()
6469+ self .assertNotIn ("" , child_sys_path ) # replaced by an abspath
6470+ self .assertIn (self ._temp_dir , child_sys_path ) # our addition
6471+ # ignore the first element, it is the absolute "" replacement
6472+ self .assertEqual (child_sys_path [1 :], sys .path [1 :])
6473+ self .assertIsNone (import_error , msg = f"child could not import { self ._mod_name } " )
64766474
64776475
64786476class MiscTestCase (unittest .TestCase ):
@@ -6669,6 +6667,8 @@ def install_tests_in_module_dict(remote_globs, start_method,
66696667 if base is BaseTestCase :
66706668 continue
66716669 assert set (base .ALLOWED_TYPES ) <= ALL_TYPES , base .ALLOWED_TYPES
6670+ if base .START_METHODS and start_method not in base .START_METHODS :
6671+ continue # class not intended for this start method.
66726672 for type_ in base .ALLOWED_TYPES :
66736673 if only_type and type_ != only_type :
66746674 continue
@@ -6682,6 +6682,7 @@ class Temp(base, Mixin, unittest.TestCase):
66826682 Temp = hashlib_helper .requires_hashdigest ('sha256' )(Temp )
66836683 Temp .__name__ = Temp .__qualname__ = newname
66846684 Temp .__module__ = __module__
6685+ Temp .start_method = start_method
66856686 remote_globs [newname ] = Temp
66866687 elif issubclass (base , unittest .TestCase ):
66876688 if only_type :
0 commit comments