Skip to content

Commit 3ae5fc2

Browse files
committed
Update backend based on code review
1 parent ca6fde6 commit 3ae5fc2

File tree

3 files changed

+63
-67
lines changed

3 files changed

+63
-67
lines changed

src/leopard_em/backend/core_differentiable_refine.py

Lines changed: 20 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -15,36 +15,10 @@
1515
)
1616
from leopard_em.backend.cross_correlation import (
1717
do_batched_orientation_cross_correlate,
18-
do_batched_orientation_cross_correlate_cpu,
1918
)
19+
from leopard_em.backend.utils import EULER_ANGLE_FMT, combine_euler_angles
2020
from leopard_em.utils.ctf_utils import calculate_ctf_filter_stack_full_args
2121

22-
# This is assuming the Euler angles are in the ZYZ intrinsic format
23-
# AND that the angles are ordered in (phi, theta, psi)
24-
EULER_ANGLE_FMT = "ZYZ"
25-
26-
27-
def combine_euler_angles(angle_a: torch.Tensor, angle_b: torch.Tensor) -> torch.Tensor:
28-
"""Helper function for composing rotations defined by two sets of Euler angles."""
29-
# Ensure both input angles have the same dtype
30-
common_dtype = angle_a.dtype
31-
if angle_b.dtype != common_dtype:
32-
angle_b = angle_b.to(common_dtype)
33-
34-
rotmat_a = roma.euler_to_rotmat(
35-
EULER_ANGLE_FMT, angle_a, degrees=True, device=angle_a.device
36-
)
37-
rotmat_b = roma.euler_to_rotmat(
38-
EULER_ANGLE_FMT, angle_b, degrees=True, device=angle_b.device
39-
)
40-
# Ensure both rotation matrices have the same dtype
41-
if rotmat_b.dtype != rotmat_a.dtype:
42-
rotmat_b = rotmat_b.to(rotmat_a.dtype)
43-
rotmat_c = roma.rotmat_composition((rotmat_a, rotmat_b))
44-
euler_angles_c = roma.rotmat_to_euler(EULER_ANGLE_FMT, rotmat_c, degrees=True)
45-
46-
return euler_angles_c
47-
4822

4923
# NOTE: Disabling pylint for too many arguments because we are taking a data-oriented
5024
# approach where each argument is independent and explicit.
@@ -69,7 +43,7 @@ def core_differentiable_refine(
6943
batch_size: int = 32,
7044
num_cuda_streams: int = 1,
7145
mag_matrix: torch.Tensor | None = None,
72-
) -> dict[torch.Tensor, torch.Tensor]:
46+
) -> dict[str, torch.Tensor]:
7347
"""Core function to refine orientations and defoci of a set of particles.
7448
7549
Parameters
@@ -119,13 +93,22 @@ def core_differentiable_refine(
11993
12094
Returns
12195
-------
122-
dict[torch.Tensor, torch.Tensor]
96+
dict[str, torch.Tensor]
12397
Tensor containing the refined parameters for all particles.
12498
"""
12599
# Convert single device to list for consistent handling
126100
if isinstance(device, torch.device):
127101
device = [device]
128102

103+
# Check that all devices are GPU devices (CUDA)
104+
# Differentiable refinement requires GPU for gradient computation
105+
for dev in device:
106+
if dev.type != "cuda":
107+
raise ValueError(
108+
f"Differentiable refinement can only be run on GPU devices. "
109+
f"Got device type: {dev.type}. Please use a CUDA device."
110+
)
111+
129112
###########################################
130113
### Split particle stack across devices ###
131114
###########################################
@@ -607,25 +590,14 @@ def _core_refine_template_single_thread(
607590
)
608591

609592
# Calculate the cross-correlation
610-
if particle_image_dft.device.type == "cuda":
611-
# NOTE: Here we are setting to only a single stream, but this can easily
612-
# be extended to multiple streams if needed.
613-
cross_correlation = do_batched_orientation_cross_correlate(
614-
image_dft=particle_image_dft,
615-
template_dft=template_dft,
616-
rotation_matrices=rot_matrix_batch,
617-
projective_filters=combined_projective_filter,
618-
requires_grad=True,
619-
mag_matrix=mag_matrix,
620-
)
621-
else:
622-
cross_correlation = do_batched_orientation_cross_correlate_cpu(
623-
image_dft=particle_image_dft,
624-
template_dft=template_dft,
625-
rotation_matrices=rot_matrix_batch,
626-
projective_filters=combined_projective_filter,
627-
mag_matrix=mag_matrix,
628-
)
593+
cross_correlation = do_batched_orientation_cross_correlate(
594+
image_dft=particle_image_dft,
595+
template_dft=template_dft,
596+
rotation_matrices=rot_matrix_batch,
597+
projective_filters=combined_projective_filter,
598+
requires_grad=True,
599+
mag_matrix=mag_matrix,
600+
)
629601

630602
cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop
631603
# Scale cross_correlation to be "z-score"-like

src/leopard_em/backend/core_refine_template.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,28 +16,14 @@
1616
do_batched_orientation_cross_correlate_cpu,
1717
)
1818
from leopard_em.backend.distributed import run_multiprocess_jobs
19-
from leopard_em.backend.utils import normalize_template_projection
19+
from leopard_em.backend.utils import (
20+
EULER_ANGLE_FMT,
21+
combine_euler_angles,
22+
normalize_template_projection,
23+
)
2024
from leopard_em.utils.cross_correlation import handle_correlation_mode
2125
from leopard_em.utils.ctf_utils import calculate_ctf_filter_stack_full_args
2226

23-
# This is assuming the Euler angles are in the ZYZ intrinsic format
24-
# AND that the angles are ordered in (phi, theta, psi)
25-
EULER_ANGLE_FMT = "ZYZ"
26-
27-
28-
def combine_euler_angles(angle_a: torch.Tensor, angle_b: torch.Tensor) -> torch.Tensor:
29-
"""Helper function for composing rotations defined by two sets of Euler angles."""
30-
rotmat_a = roma.euler_to_rotmat(
31-
EULER_ANGLE_FMT, angle_a, degrees=True, device=angle_a.device
32-
)
33-
rotmat_b = roma.euler_to_rotmat(
34-
EULER_ANGLE_FMT, angle_b, degrees=True, device=angle_b.device
35-
)
36-
rotmat_c = roma.rotmat_composition((rotmat_a, rotmat_b))
37-
euler_angles_c = roma.rotmat_to_euler(EULER_ANGLE_FMT, rotmat_c, degrees=True)
38-
39-
return euler_angles_c
40-
4127

4228
# NOTE: Disabling pylint for too many arguments because we are taking a data-oriented
4329
# approach where each argument is independent and explicit.
@@ -427,6 +413,7 @@ def _core_refine_template_single_gpu(
427413
corr_mean = corr_mean.to(device)
428414
corr_std = corr_std.to(device)
429415
projective_filters = projective_filters.to(device)
416+
mag_matrix = mag_matrix.to(device) if mag_matrix is not None else None
430417

431418
########################################
432419
### Setup constants and progress bar ###

src/leopard_em/backend/utils.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import warnings
66
from typing import Any, Callable, TypeVar
77

8+
import roma
89
import torch
910

1011
# Suppress the specific deprecation warnings from PyTorch internals
@@ -16,6 +17,42 @@
1617

1718
F = TypeVar("F", bound=Callable[..., Any])
1819

20+
# This is assuming the Euler angles are in the ZYZ intrinsic format
21+
# AND that the angles are ordered in (phi, theta, psi)
22+
EULER_ANGLE_FMT = "ZYZ"
23+
24+
25+
def combine_euler_angles(angle_a: torch.Tensor, angle_b: torch.Tensor) -> torch.Tensor:
26+
"""Helper function for composing rotations defined by two sets of Euler angles.
27+
28+
Parameters
29+
----------
30+
angle_a : torch.Tensor
31+
First set of Euler angles in ZYZ convention.
32+
angle_b : torch.Tensor
33+
Second set of Euler angles in ZYZ convention.
34+
35+
Returns
36+
-------
37+
torch.Tensor
38+
Composed Euler angles representing the combined rotation.
39+
"""
40+
# Ensure both input angles have the same dtype
41+
common_dtype = angle_a.dtype
42+
if angle_b.dtype != common_dtype:
43+
angle_b = angle_b.to(common_dtype)
44+
45+
rotmat_a = roma.euler_to_rotmat(
46+
EULER_ANGLE_FMT, angle_a, degrees=True, device=angle_a.device
47+
)
48+
rotmat_b = roma.euler_to_rotmat(
49+
EULER_ANGLE_FMT, angle_b, degrees=True, device=angle_b.device
50+
)
51+
rotmat_c = roma.rotmat_composition((rotmat_a, rotmat_b))
52+
euler_angles_c = roma.rotmat_to_euler(EULER_ANGLE_FMT, rotmat_c, degrees=True)
53+
54+
return euler_angles_c
55+
1956

2057
def attempt_torch_compilation(
2158
target_func: F, backend: str = "inductor", mode: str = "default"

0 commit comments

Comments
 (0)