Skip to content

Commit ed34f27

Browse files
authored
Merge pull request #82 from mgiammar/mdg_match_template_test
Mdg match template test
2 parents f2d2885 + 4fba8fa commit ed34f27

File tree

5 files changed

+362
-16
lines changed

5 files changed

+362
-16
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ coverage.xml
4848
*.cover
4949
.hypothesis/
5050
.pytest_cache/
51+
tests/tmp
52+
benchmark/tmp
5153

5254
# Translations
5355
*.mo
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""Benchmarking script for core_match_template performance.
2+
3+
Benchmark is using a 4096 x 4096 pixel image with a 512 x 512 x 512 template. Smaller
4+
size images will make a bigger performance impact than reducing the template volume.
5+
This script can be modified to benchmark images of other sizes (e.g. K3 images).
6+
7+
NOTE: This benchmark can take up to 10 minutes given the moderate sized search space and
8+
GPU requirements.
9+
"""
10+
11+
import json
12+
import subprocess
13+
import time
14+
from pathlib import Path
15+
from pprint import pprint
16+
from typing import Any, dict
17+
18+
import click
19+
import numpy as np
20+
import torch
21+
22+
from leopard_em.backend.core_match_template import core_match_template
23+
from leopard_em.pydantic_models.managers import MatchTemplateManager
24+
25+
DOWNLOAD_DIR = (Path(__file__).parent / "tmp").resolve()
26+
YAML_PATH = (
27+
Path(DOWNLOAD_DIR) / "test_match_template_xenon_216_000_0.0_DWS_config.yaml"
28+
).resolve()
29+
ZENODO_URL = "https://zenodo.org/records/17069838"
30+
31+
32+
def download_comparison_data() -> None:
    """Fetch the example benchmark data from Zenodo into ``DOWNLOAD_DIR``."""
    # Shell out to the zenodo_get CLI; check=True raises on a non-zero exit code.
    command = ["zenodo_get", f"--output-dir={DOWNLOAD_DIR}", ZENODO_URL]
    subprocess.run(command, check=True)
def setup_match_template_manager() -> MatchTemplateManager:
    """Build a fresh ``MatchTemplateManager`` from the benchmark YAML config."""
    manager = MatchTemplateManager.from_yaml(YAML_PATH)
    return manager
def benchmark_match_template_single_run(
    mt_manager: MatchTemplateManager, orientation_batch_size: int
) -> dict[str, float]:
    """Run a single benchmark and return timing statistics.

    Parameters
    ----------
    mt_manager : MatchTemplateManager
        Freshly constructed manager holding the image/template configuration.
    orientation_batch_size : int
        Number of orientations processed per batch by the backend.

    Returns
    -------
    dict[str, float]
        Per-run statistics: setup/execution/adjustment times, projection
        counts, adjusted throughput (corr/sec), and estimated constant
        dead-time of ``core_match_template``.
    """
    torch.cuda.synchronize()

    ####################################################
    ### 1. Profile the make core backend kwargs time ###
    ####################################################

    start_time = time.perf_counter()

    core_kwargs = mt_manager.make_backend_core_function_kwargs()

    setup_time = time.perf_counter() - start_time

    ############################################
    ### 2. Profile the backend function call ###
    ############################################
    # NOTE(review): no explicit cuda.synchronize() before stopping the timer;
    # presumably core_match_template synchronizes internally before returning
    # its result dict -- confirm against the backend implementation.

    start_time = time.perf_counter()

    result = core_match_template(
        **core_kwargs,
        orientation_batch_size=orientation_batch_size,
        num_cuda_streams=mt_manager.computational_config.num_cpus,
    )
    total_projections = result["total_projections"]  # number of CCGs calculated, N

    execution_time = time.perf_counter() - start_time

    ######################################################
    ### 3. Use an extremely small search to estimate   ###
    ### the constant core_match_template setup cost.   ###
    ######################################################
    # This is using the timing model where the time -- T -- to compute N
    # cross-correlations is dependent on some device rate -- r -- and a constant setup
    # cost in terms of time (this is what we measure):
    #
    #     T_N = N/r + k
    #
    # Taking a large number of cross-correlations -- N -- (the performance profiled
    # above) and a smaller number of cross-correlations -- n -- we can back out
    # the constants along this curve
    #
    #     T_n = n/r + k
    #     --> (T_N - T_n) = (N - n) / r
    #     --> r = (N - n) / (T_N - T_n)
    #     --> k = T_N - N * (T_N - T_n) / (N - n)
    #
    # (The last line previously omitted the leading "T_N -"; the code below was
    # already correct.)

    core_kwargs["euler_angles"] = torch.rand(size=(100, 3)) * 180
    start_time = time.perf_counter()

    result = core_match_template(
        **core_kwargs,
        orientation_batch_size=orientation_batch_size,
        num_cuda_streams=mt_manager.computational_config.num_cpus,
    )
    adjustment_projections = result["total_projections"]  # number of CCGs calculated, n

    adjustment_time = time.perf_counter() - start_time

    # Doing the adjustment computations (solve the timing model above for r and k)
    N = total_projections
    n = adjustment_projections
    T = execution_time
    t = adjustment_time
    throughput = (N - n) / (T - t)  # r, in cross-correlations per second
    core_deadtime = T - N * (T - t) / (N - n)  # k, constant per-call setup cost

    return {
        "setup_time": setup_time,
        "execution_time": execution_time,
        "total_projections": total_projections,
        "adjustment_time": adjustment_time,
        "adjustment_projections": adjustment_projections,
        "throughput": throughput,
        "core_deadtime": core_deadtime,
    }
def run_benchmark(orientation_batch_size: int, num_runs: int) -> dict[str, Any]:
    """Run multiple benchmark iterations and collect statistics.

    Parameters
    ----------
    orientation_batch_size : int
        Batch size forwarded to the backend for each run.
    num_runs : int
        Number of independent benchmark iterations to run.

    Returns
    -------
    dict[str, Any]
        Summary statistics (per-metric means over runs, CUDA device info) plus
        the raw per-run results under ``"all_results"``.
    """
    # Download example data to use for benchmarking
    print("Downloading benchmarking data...")
    download_comparison_data()
    print("Done!")

    # Get CUDA device properties (query once and reuse for all fields)
    device = torch.cuda.get_device_properties(0)
    device_name = str(device.name)
    sm_architecture = device.major * 10 + device.minor
    device_memory = device.total_memory / (1024**3)  # bytes -> GiB
    print(f"GPU device SM architecture: {sm_architecture}")
    print("Running benchmark on device:", device_name)
    print(f"GPU device has {device_memory:.2f} GB of memory")

    results = []

    for run_idx in range(num_runs):
        print(f"Running benchmark iteration {run_idx + 1}/{num_runs}...")

        # Fresh manager per iteration so runs do not share cached state.
        mt_manager = setup_match_template_manager()
        result = benchmark_match_template_single_run(mt_manager, orientation_batch_size)
        results.append(result)

        print()
        print()
        print(f" Setup time : {result['setup_time']:.3f} seconds")
        print(f" Execution time : {result['execution_time']:.3f} seconds")
        print(f" throughput (adj.) : {result['throughput']:.3f} corr/sec")
        print(f" core dead-time : {result['core_deadtime']:.3f} seconds")

        # Release cached GPU memory between iterations.
        torch.cuda.empty_cache()

    execution_times = np.array([r["execution_time"] for r in results])
    setup_times = np.array([r["setup_time"] for r in results])
    throughputs = np.array([r["throughput"] for r in results])
    core_deadtimes = np.array([r["core_deadtime"] for r in results])
    total_projections_list = [r["total_projections"] for r in results]

    mst, sst = setup_times.mean(), setup_times.std()
    mxt, sxt = execution_times.mean(), execution_times.std()
    mtt, stt = throughputs.mean(), throughputs.std()
    mct, sct = core_deadtimes.mean(), core_deadtimes.std()

    print("\nSummary statistics over all runs (mean / std)")
    print("-------------------------------------------------------------")
    print(f" Setup time (seconds) {mst:.3f} / {sst:.3f}")
    print(f" Execution time (seconds) {mxt:.3f} / {sxt:.3f}")
    print(f" Throughput (adj.) (corr/sec) {mtt:.3f} / {stt:.3f}")
    print(f" Core dead-time (seconds) {mct:.3f} / {sct:.3f}")
    print("-------------------------------------------------------------")

    stats = {
        "total_projections": total_projections_list,
        "device_name": device_name,
        "device_sm_arch": sm_architecture,
        "device_memory_gb": device_memory,
        "mean_setup_time": mst,
        "mean_execution_time": mxt,
        "mean_throughput": mtt,
        "mean_core_deadtime": mct,
        # Misspelled key kept for backward compatibility with earlier output.
        "mean_core_deatime": mct,
        "all_results": results,
    }

    return stats
def save_benchmark_results(result: dict, output_file: str) -> None:
    """Save benchmark results to a JSON file.

    Parameters
    ----------
    result : dict
        JSON-serializable benchmark statistics (as returned by ``run_benchmark``).
    output_file : str
        Path of the JSON file to write.
    """
    # Explicit encoding so the output does not depend on the platform locale.
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)

    print(f"\nBenchmark results saved to: {output_file}")
@click.command()
@click.option(
    "--orientation-batch-size",
    default=20,
    type=int,
    help="Batch size for orientation processing (default: 20). Vary based on GPU specs",
)
@click.option(
    "--num-runs",
    default=3,
    type=int,
    help="Number of benchmark runs for statistical analysis (default: 3)",
)
@click.option(
    "--output-file",
    default="benchmark_results.json",
    type=str,
    help="Output file for benchmark results (default: benchmark_results.json)",
)
def main(orientation_batch_size: int, num_runs: int, output_file: str):
    """Main benchmarking function with Click CLI interface."""
    # A CUDA device is required; bail out early when none is present.
    if not torch.cuda.is_available():
        print("CUDA not available exiting...")
        return

    for line in (
        "Benchmark configuration:",
        f" Orientation batch size: {orientation_batch_size}",
        f" Number of runs: {num_runs}",
        f" Output file: {output_file}",
    ):
        print(line)

    stats = run_benchmark(orientation_batch_size, num_runs)
    pprint(stats)
    save_benchmark_results(stats, output_file)


if __name__ == "__main__":
    main()

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ readme = "README.md"
2222
requires-python = ">=3.10"
2323
license = { text = "BSD-3-Clause" }
2424
authors = [
25-
{ name = "Josh Dickerson", email = "[email protected]" },
2625
{ name = "Matthew Giammar", email = "[email protected]" },
26+
{ name = "Josh Dickerson", email = "[email protected]" },
2727
]
2828
# https://pypi.org/classifiers/
2929
classifiers = [
@@ -154,6 +154,8 @@ pretty = true
154154
minversion = "7.0"
155155
testpaths = ["tests"]
156156
filterwarnings = ["error"]
157+
addopts = "-m 'not slow'" # Skip slow tests by default
158+
markers = ["slow: marks test as slow"]
157159

158160
# https://coverage.readthedocs.io/
159161
[tool.coverage.report]

src/leopard_em/backend/utils.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
"""Utility and helper functions associated with the backend of Leopard-EM."""
22

33
import os
4+
import re
45
import warnings
56
from multiprocessing import Manager, Process
67
from typing import Any, Callable, Optional, TypeVar
78

89
import torch
910

11+
# Suppress the specific deprecation warnings from PyTorch internals
12+
warnings.filterwarnings(
13+
"ignore",
14+
category=UserWarning,
15+
message=re.escape("Logical operators 'and' and 'or' are deprecated for non-scalar"),
16+
)
17+
1018
F = TypeVar("F", bound=Callable[..., Any])
1119

1220

@@ -220,21 +228,19 @@ def do_iteration_statistics_updates(
220228
# using torch.where directly
221229
update_mask = max_values > mip
222230

223-
mip = torch.where(update_mask, max_values, mip)
224-
best_phi = torch.where(update_mask, euler_angles[max_orientation_idx, 0], best_phi)
225-
best_theta = torch.where(
226-
update_mask, euler_angles[max_orientation_idx, 1], best_theta
227-
)
228-
best_psi = torch.where(update_mask, euler_angles[max_orientation_idx, 2], best_psi)
229-
best_defocus = torch.where(
230-
update_mask, defocus_values[max_defocus_idx], best_defocus
231-
)
232-
best_pixel_size = torch.where(
233-
update_mask, pixel_values[max_cs_idx], best_pixel_size
234-
)
231+
# pylint: disable=line-too-long
232+
# fmt: off
233+
torch.where(update_mask, max_values, mip, out=mip)
234+
torch.where(update_mask, euler_angles[max_orientation_idx, 0], best_phi, out=best_phi) # noqa: E501
235+
torch.where(update_mask, euler_angles[max_orientation_idx, 1], best_theta, out=best_theta) # noqa: E501
236+
torch.where(update_mask, euler_angles[max_orientation_idx, 2], best_psi, out=best_psi) # noqa: E501
237+
torch.where(update_mask, defocus_values[max_defocus_idx], best_defocus, out=best_defocus) # noqa: E501
238+
torch.where(update_mask, pixel_values[max_cs_idx], best_pixel_size, out=best_pixel_size) # noqa: E501
239+
# fmt: on
240+
# pylint: enable=line-too-long
235241

236-
correlation_sum = correlation_sum + cc_reshaped.sum(dim=0)
237-
correlation_squared_sum = correlation_squared_sum + (cc_reshaped**2).sum(dim=0)
242+
correlation_sum += cc_reshaped.sum(dim=0)
243+
correlation_squared_sum += (cc_reshaped**2).sum(dim=0)
238244

239245

240246
def run_multiprocess_jobs(
@@ -310,5 +316,7 @@ def worker_fn(result_dict, idx, param1, param2):
310316
normalize_template_projection, backend="inductor", mode="default"
311317
)
312318
do_iteration_statistics_updates_compiled = attempt_torch_compilation(
313-
do_iteration_statistics_updates, backend="inductor", mode="max-autotune"
319+
do_iteration_statistics_updates,
320+
backend="inductor",
321+
mode="max-autotune-no-cudagraphs",
314322
)

0 commit comments

Comments
 (0)