Skip to content

Commit 463792a

Browse files
authored
Merge pull request #116 from mgiammar/mdg-improved-benchmark
Updated benchmarking scripts for zipFFT backend
2 parents f4f0d1f + 007532b commit 463792a

File tree

2 files changed

+1155
-8
lines changed

2 files changed

+1155
-8
lines changed

benchmark/benchmark_match_template.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
77
NOTE: This benchmark can take up to 10 minutes given the moderate sized search space and
88
GPU requirements.
9+
10+
NOTE: This script _must_ be run from the root of the Leopard-EM repository.
911
"""
1012

1113
import json
@@ -135,12 +137,28 @@ def benchmark_match_template_single_run(
135137
}
136138

137139

138-
def run_benchmark(orientation_batch_size: int, num_runs: int) -> dict[str, Any]:
139-
"""Run multiple benchmark iterations and collect statistics."""
140+
def run_benchmark(
141+
orientation_batch_size: int, num_runs: int, download_data: bool
142+
) -> dict[str, Any]:
143+
"""Run multiple benchmark iterations and collect statistics.
144+
145+
Parameters
146+
----------
147+
orientation_batch_size : int
148+
Number of orientations to process in a single batch (config option for
149+
core_match_template).
150+
num_runs : int
151+
Number of benchmark runs to perform for statistical analysis.
152+
download_data : bool
153+
Whether to download the benchmark data from Zenodo. If False, assumes data
154+
has already been downloaded. This can be useful if you want to modify the
155+
YAML config file to be something different than what's provided from Zenodo.
156+
"""
140157
# Download example data to use for benchmarking
141-
print("Downloading benchmarking data...")
142-
# download_comparison_data()
143-
print("Done!")
158+
if download_data:
159+
print("Downloading benchmarking data...")
160+
download_comparison_data()
161+
print("Done!")
144162

145163
# Get CUDA device properties
146164
device = torch.cuda.get_device_properties(0)
@@ -230,7 +248,15 @@ def save_benchmark_results(result: dict, output_file: str) -> None:
230248
type=str,
231249
help="Output file for benchmark results (default: benchmark_results.json)",
232250
)
233-
def main(orientation_batch_size: int, num_runs: int, output_file: str):
251+
@click.option(
252+
"--download-data/--no-download-data",
253+
default=True,
254+
help="Whether to download benchmark data from Zenodo (default: --download-data). "
255+
"Use --no-download-data if you want to use existing files on disk.",
256+
)
257+
def main(
258+
orientation_batch_size: int, num_runs: int, output_file: str, download_data: bool
259+
) -> None:
234260
"""Main benchmarking function with Click CLI interface."""
235261
if not torch.cuda.is_available():
236262
print("CUDA not available exiting...")
@@ -240,9 +266,9 @@ def main(orientation_batch_size: int, num_runs: int, output_file: str):
240266
print(f" Orientation batch size: {orientation_batch_size}")
241267
print(f" Number of runs: {num_runs}")
242268
print(f" Output file: {output_file}")
269+
print(f" Download data: {download_data}")
243270

244-
result = run_benchmark(orientation_batch_size, num_runs)
245-
# pprint(result)
271+
result = run_benchmark(orientation_batch_size, num_runs, download_data)
246272
save_benchmark_results(result, output_file)
247273

248274

0 commit comments

Comments
 (0)