1515)
1616from leopard_em .backend .cross_correlation import (
1717 do_batched_orientation_cross_correlate ,
18- do_batched_orientation_cross_correlate_cpu ,
1918)
19+ from leopard_em .backend .utils import EULER_ANGLE_FMT , combine_euler_angles
2020from leopard_em .utils .ctf_utils import calculate_ctf_filter_stack_full_args
2121
22- # This is assuming the Euler angles are in the ZYZ intrinsic format
23- # AND that the angles are ordered in (phi, theta, psi)
24- EULER_ANGLE_FMT = "ZYZ"
25-
26-
27- def combine_euler_angles (angle_a : torch .Tensor , angle_b : torch .Tensor ) -> torch .Tensor :
28- """Helper function for composing rotations defined by two sets of Euler angles."""
29- # Ensure both input angles have the same dtype
30- common_dtype = angle_a .dtype
31- if angle_b .dtype != common_dtype :
32- angle_b = angle_b .to (common_dtype )
33-
34- rotmat_a = roma .euler_to_rotmat (
35- EULER_ANGLE_FMT , angle_a , degrees = True , device = angle_a .device
36- )
37- rotmat_b = roma .euler_to_rotmat (
38- EULER_ANGLE_FMT , angle_b , degrees = True , device = angle_b .device
39- )
40- # Ensure both rotation matrices have the same dtype
41- if rotmat_b .dtype != rotmat_a .dtype :
42- rotmat_b = rotmat_b .to (rotmat_a .dtype )
43- rotmat_c = roma .rotmat_composition ((rotmat_a , rotmat_b ))
44- euler_angles_c = roma .rotmat_to_euler (EULER_ANGLE_FMT , rotmat_c , degrees = True )
45-
46- return euler_angles_c
47-
4822
4923# NOTE: Disabling pylint for too many arguments because we are taking a data-oriented
5024# approach where each argument is independent and explicit.
@@ -69,7 +43,7 @@ def core_differentiable_refine(
6943 batch_size : int = 32 ,
7044 num_cuda_streams : int = 1 ,
7145 mag_matrix : torch .Tensor | None = None ,
72- ) -> dict [torch . Tensor , torch .Tensor ]:
46+ ) -> dict [str , torch .Tensor ]:
7347 """Core function to refine orientations and defoci of a set of particles.
7448
7549 Parameters
@@ -119,13 +93,22 @@ def core_differentiable_refine(
11993
12094 Returns
12195 -------
122- dict[torch.Tensor , torch.Tensor]
96+ dict[str , torch.Tensor]
12397 Tensor containing the refined parameters for all particles.
12498 """
12599 # Convert single device to list for consistent handling
126100 if isinstance (device , torch .device ):
127101 device = [device ]
128102
103+ # Check that all devices are GPU devices (CUDA)
104+ # Differentiable refinement requires GPU for gradient computation
105+ for dev in device :
106+ if dev .type != "cuda" :
107+ raise ValueError (
108+ f"Differentiable refinement can only be run on GPU devices. "
109+ f"Got device type: { dev .type } . Please use a CUDA device."
110+ )
111+
129112 ###########################################
130113 ### Split particle stack across devices ###
131114 ###########################################
@@ -607,25 +590,14 @@ def _core_refine_template_single_thread(
607590 )
608591
609592 # Calculate the cross-correlation
610- if particle_image_dft .device .type == "cuda" :
611- # NOTE: Here we are setting to only a single stream, but this can easily
612- # be extended to multiple streams if needed.
613- cross_correlation = do_batched_orientation_cross_correlate (
614- image_dft = particle_image_dft ,
615- template_dft = template_dft ,
616- rotation_matrices = rot_matrix_batch ,
617- projective_filters = combined_projective_filter ,
618- requires_grad = True ,
619- mag_matrix = mag_matrix ,
620- )
621- else :
622- cross_correlation = do_batched_orientation_cross_correlate_cpu (
623- image_dft = particle_image_dft ,
624- template_dft = template_dft ,
625- rotation_matrices = rot_matrix_batch ,
626- projective_filters = combined_projective_filter ,
627- mag_matrix = mag_matrix ,
628- )
593+ cross_correlation = do_batched_orientation_cross_correlate (
594+ image_dft = particle_image_dft ,
595+ template_dft = template_dft ,
596+ rotation_matrices = rot_matrix_batch ,
597+ projective_filters = combined_projective_filter ,
598+ requires_grad = True ,
599+ mag_matrix = mag_matrix ,
600+ )
629601
630602 cross_correlation = cross_correlation [..., :crop_h , :crop_w ] # valid crop
631603 # Scale cross_correlation to be "z-score"-like
0 commit comments