Skip to content

Commit 66967d8

Browse files
authored
[tuner] Convert from parallel baseline benchmark to serial run (#1265)
convert from parallel baseline benchmark to serial run
1 parent 0886ba7 commit 66967d8

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

tuner/tuner/libtuner.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -631,21 +631,21 @@ def multiprocess_progress_wrapper(
631631
results = []
632632
initializer_inputs = initializer_inputs or ()
633633

634-
# Create a multiprocessing pool
634+
# Create a multiprocessing pool.
635635
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
636636
with multiprocessing.Pool(
637637
num_worker, initializer, initializer_inputs
638638
) as worker_pool:
639639
signal.signal(signal.SIGINT, sigint_handler)
640-
# Use tqdm to create a progress bar
640+
# Use tqdm to create a progress bar.
641641
with tqdm(total=len(task_list)) as pbar:
642642
try:
643-
# Use imap_unordered to asynchronously execute the worker function on each task
643+
# Use imap_unordered to asynchronously execute the worker function on each task.
644644
for result in worker_pool.imap_unordered(function, task_list):
645645
pbar.update(1) # Update progress bar
646646
results.append(result)
647647
except KeyboardInterrupt:
648-
# If Ctrl+C is pressed, terminate all child processes
648+
# If Ctrl+C is pressed, terminate all child processes.
649649
worker_pool.terminate()
650650
worker_pool.join()
651651
sys.exit(1) # Exit the script
@@ -850,22 +850,31 @@ def benchmark_baseline(
850850
tuning_client: TuningClient,
851851
candidate_tracker: CandidateTracker,
852852
) -> list[BenchmarkResult]:
853-
task_list = [
854-
BenchmarkPack(
855-
iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
856-
benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
857-
candidate_tracker=candidate_tracker,
858-
)
859-
] * len(devices)
860853

861-
worker_context_queue = create_worker_context_queue(devices)
862-
baseline_results = multiprocess_progress_wrapper(
863-
num_worker=len(devices),
864-
task_list=task_list,
865-
function=run_iree_benchmark_module_command,
866-
initializer=init_worker_context,
867-
initializer_inputs=(worker_context_queue,),
868-
)
854+
global worker_id, device_id
855+
856+
baseline_results = list()
857+
858+
# Use tqdm to create a progress bar.
859+
with tqdm(total=len(devices)) as pbar:
860+
try:
861+
for worker_id_, device_id_ in enumerate(devices):
862+
worker_id = worker_id_
863+
device_id = device_id_
864+
result = run_iree_benchmark_module_command(
865+
BenchmarkPack(
866+
iree_benchmark_module_flags=tuning_client.get_iree_benchmark_module_flags(),
867+
benchmark_timeout=tuning_client.get_benchmark_timeout_s(),
868+
candidate_tracker=candidate_tracker,
869+
)
870+
)
871+
872+
baseline_results.append(result)
873+
pbar.update(1) # Update progress bar
874+
except KeyboardInterrupt:
875+
# If Ctrl+C is pressed, terminate all child processes.
876+
sys.exit(1) # Exit the script.
877+
869878
return baseline_results
870879

871880

0 commit comments

Comments
 (0)