@@ -631,21 +631,21 @@ def multiprocess_progress_wrapper(
631631 results = []
632632 initializer_inputs = initializer_inputs or ()
633633
634- # Create a multiprocessing pool
634+ # Create a multiprocessing pool.
635635 sigint_handler = signal .signal (signal .SIGINT , signal .SIG_IGN )
636636 with multiprocessing .Pool (
637637 num_worker , initializer , initializer_inputs
638638 ) as worker_pool :
639639 signal .signal (signal .SIGINT , sigint_handler )
640- # Use tqdm to create a progress bar
640+ # Use tqdm to create a progress bar.
641641 with tqdm (total = len (task_list )) as pbar :
642642 try :
643- # Use imap_unordered to asynchronously execute the worker function on each task
643+ # Use imap_unordered to asynchronously execute the worker function on each task.
644644 for result in worker_pool .imap_unordered (function , task_list ):
645645 pbar .update (1 ) # Update progress bar
646646 results .append (result )
647647 except KeyboardInterrupt :
648- # If Ctrl+C is pressed, terminate all child processes
648+ # If Ctrl+C is pressed, terminate all child processes.
649649 worker_pool .terminate ()
650650 worker_pool .join ()
651651 sys .exit (1 ) # Exit the script
@@ -850,22 +850,31 @@ def benchmark_baseline(
850850 tuning_client : TuningClient ,
851851 candidate_tracker : CandidateTracker ,
852852) -> list [BenchmarkResult ]:
853- task_list = [
854- BenchmarkPack (
855- iree_benchmark_module_flags = tuning_client .get_iree_benchmark_module_flags (),
856- benchmark_timeout = tuning_client .get_benchmark_timeout_s (),
857- candidate_tracker = candidate_tracker ,
858- )
859- ] * len (devices )
860853
861- worker_context_queue = create_worker_context_queue (devices )
862- baseline_results = multiprocess_progress_wrapper (
863- num_worker = len (devices ),
864- task_list = task_list ,
865- function = run_iree_benchmark_module_command ,
866- initializer = init_worker_context ,
867- initializer_inputs = (worker_context_queue ,),
868- )
854+ global worker_id , device_id
855+
856+ baseline_results = list ()
857+
858+ # Use tqdm to create a progress bar.
859+ with tqdm (total = len (devices )) as pbar :
860+ try :
861+ for worker_id_ , device_id_ in enumerate (devices ):
862+ worker_id = worker_id_
863+ device_id = device_id_
864+ result = run_iree_benchmark_module_command (
865+ BenchmarkPack (
866+ iree_benchmark_module_flags = tuning_client .get_iree_benchmark_module_flags (),
867+ benchmark_timeout = tuning_client .get_benchmark_timeout_s (),
868+ candidate_tracker = candidate_tracker ,
869+ )
870+ )
871+
872+ baseline_results .append (result )
873+ pbar .update (1 ) # Update progress bar
874+ except KeyboardInterrupt :
875+ # If Ctrl+C is pressed, terminate all child processes.
876+ sys .exit (1 ) # Exit the script.
877+
869878 return baseline_results
870879
871880
0 commit comments