@@ -415,10 +415,11 @@ def interrupt_process(process: subprocess.Popen):
415
415
416
416
417
417
class ParallelTestRunner (TestRunner ):
418
- def __init__ (self , * , num_processes , subprocess_args , ** kwargs ):
418
+ def __init__ (self , * , num_processes , subprocess_args , separate_workers , ** kwargs ):
419
419
super ().__init__ (** kwargs )
420
420
self .num_processes = num_processes
421
421
self .subprocess_args = subprocess_args
422
+ self .separate_workers = separate_workers
422
423
self .stop_event = threading .Event ()
423
424
self .crashes = []
424
425
self .last_out_pos = 0
@@ -432,28 +433,35 @@ def report_result(self, result: TestResult):
432
433
def tests_failed (self ):
433
434
return super ().tests_failed () or bool (self .crashes )
434
435
436
+ def partition_tests_into_processes (self , suites : list ['TestSuite' ]) -> list [list [TestId ]]:
437
+ if self .separate_workers :
438
+ per_file_suites = suites
439
+ unpartitioned = []
440
+ else :
441
+ per_file_suites , unpartitioned = partition_list (suites , lambda suite : suite .config .new_worker_per_file )
442
+ partitions = [suite .collected_tests for suite in per_file_suites ]
443
+ per_partition = int (math .ceil (len (unpartitioned ) / max (1 , self .num_processes )))
444
+ while unpartitioned :
445
+ partitions .append ([test for suite in unpartitioned [:per_partition ] for test in suite .collected_tests ])
446
+ unpartitioned = unpartitioned [per_partition :]
447
+ return partitions
448
+
435
449
def run_tests (self , tests : list ['TestSuite' ]):
450
+ serial_suites , parallel_suites = partition_list (
451
+ tests ,
452
+ lambda suite : suite .test_file .name .removesuffix ('.py' ) in suite .config .serial_tests ,
453
+ )
454
+ parallel_partitions = self .partition_tests_into_processes (parallel_suites )
455
+ serial_partitions = self .partition_tests_into_processes (serial_suites )
456
+
436
457
start_time = time .time ()
437
- if tests :
438
- serial_suites , unpartitioned = partition_list (
439
- tests ,
440
- lambda suite : suite .test_file .name .removesuffix ('.py' ) in suite .config .serial_tests ,
441
- )
442
- per_file_suites , unpartitioned = partition_list (
443
- unpartitioned ,
444
- lambda suite : suite .config .new_worker_per_file ,
445
- )
446
- partitions = [suite .collected_tests for suite in per_file_suites ]
447
- per_partition = int (math .ceil (len (unpartitioned ) / self .num_processes ))
448
- while unpartitioned :
449
- partitions .append ([test for suite in unpartitioned [:per_partition ] for test in suite .collected_tests ])
450
- unpartitioned = unpartitioned [per_partition :]
451
-
452
- num_processes = max (1 , min (self .num_processes , len (partitions )))
458
+ if parallel_partitions :
459
+ num_processes = max (1 , min (self .num_processes , len (parallel_partitions )))
453
460
with concurrent .futures .ThreadPoolExecutor (num_processes ) as executor :
454
- self .run_partitions_in_subprocesses (executor , partitions )
455
- for serial_suite in serial_suites :
456
- self .run_partitions_in_subprocesses (executor , [serial_suite .collected_tests ])
461
+ self .run_partitions_in_subprocesses (executor , parallel_partitions )
462
+ if serial_partitions :
463
+ with concurrent .futures .ThreadPoolExecutor (1 ) as executor :
464
+ self .run_partitions_in_subprocesses (executor , serial_partitions )
457
465
458
466
self .total_duration = time .time () - start_time
459
467
self .display_summary ()
@@ -794,6 +802,8 @@ def main():
794
802
help = "Interpret test file names relative to tagged test directory" )
795
803
parser .add_argument ('-n' , '--num-processes' , type = int ,
796
804
help = "Run tests in N subprocess workers. Adds crash recovery, output capture and timeout handling" )
805
+ parser .add_argument ('--separate-workers' , action = 'store_true' ,
806
+ help = "Create a new worker process for each test file (when -n is specified). Default for tagged unit tests" )
797
807
parser .add_argument ('--ignore' , type = Path , action = 'append' , default = [],
798
808
help = "Ignore path during collection (multi-allowed)" )
799
809
parser .add_argument ('-f' , '--failfast' , action = 'store_true' ,
@@ -865,17 +875,18 @@ def main():
865
875
if not tests :
866
876
sys .exit ("No tests matched\n " )
867
877
868
- runner_args = {
869
- ' failfast' : args .failfast ,
870
- ' report_durations' : args .durations ,
871
- }
878
+ runner_args = dict (
879
+ failfast = args .failfast ,
880
+ report_durations = args .durations ,
881
+ )
872
882
if not args .num_processes :
873
883
runner = TestRunner (** runner_args )
874
884
else :
875
885
runner = ParallelTestRunner (
876
886
** runner_args ,
877
887
num_processes = args .num_processes ,
878
888
subprocess_args = args .subprocess_args ,
889
+ separate_workers = args .separate_workers ,
879
890
)
880
891
881
892
runner .run_tests (tests )
0 commit comments