9 | 9 | import roma |
10 | 10 | import torch |
11 | 11 | import tqdm |
12 | | -from torch_fourier_slice import extract_central_slices_rfft_3d |
| 12 | +from torch_fourier_slice import extract_central_slices_rfft_3d, transform_slice_2d |
13 | 13 |
14 | 14 | from leopard_em.backend.cross_correlation import ( |
15 | 15 | do_batched_orientation_cross_correlate, |
@@ -61,6 +61,7 @@ def core_refine_template( |
61 | 61 | device: torch.device | list[torch.device], |
62 | 62 | batch_size: int = 32, |
63 | 63 | num_cuda_streams: int = 1, |
| 64 | + transform_matrix: torch.Tensor | None = None, |
64 | 65 | ) -> dict[str, torch.Tensor]: |
65 | 66 | """Core function to refine orientations and defoci of a set of particles. |
66 | 67 |
@@ -105,6 +106,9 @@ def core_refine_template( |
105 | 106 | The number of cross-correlations to process in one batch, defaults to 32. |
106 | 107 | num_cuda_streams : int, optional |
107 | 108 | Number of CUDA streams to use for parallel processing. Defaults to 1. |
| 109 | + transform_matrix : torch.Tensor | None, optional |
| 110 | + Anisotropic magnification transform matrix of shape (2, 2). If None, |
| 111 | + no magnification transform is applied. Default is None. |
108 | 112 |
109 | 113 | Returns |
110 | 114 | ------- |
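
For reference, the new `transform_matrix` argument is a plain (2, 2) tensor. A minimal sketch of constructing one for a slightly anisotropic magnification (the ~1% stretch/shrink values are illustrative, not taken from this change):

import torch

# Illustrative anisotropic magnification: ~1% stretch along one axis,
# ~1% shrink along the other (hypothetical values).
transform_matrix = torch.tensor([[1.01, 0.00],
                                 [0.00, 0.99]])

The tensor is threaded unchanged through every layer below, down to the per-orientation Fourier-slice transform.
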
@@ -136,6 +140,7 @@ def core_refine_template( |
136 | 140 | batch_size=batch_size, |
137 | 141 | devices=device, |
138 | 142 | num_cuda_streams=num_cuda_streams, |
| 143 | + transform_matrix=transform_matrix, |
139 | 144 | ) |
140 | 145 |
141 | 146 | results = run_multiprocess_jobs( |
@@ -229,6 +234,7 @@ def construct_multi_gpu_refine_template_kwargs( |
229 | 234 | batch_size: int, |
230 | 235 | devices: list[torch.device], |
231 | 236 | num_cuda_streams: int, |
| 237 | + transform_matrix: torch.Tensor | None = None, |
232 | 238 | ) -> list[dict]: |
233 | 239 | """Split particle stack between requested devices. |
234 | 240 |
@@ -266,6 +272,9 @@ def construct_multi_gpu_refine_template_kwargs( |
266 | 272 | List of devices to split across. |
267 | 273 | num_cuda_streams : int |
268 | 274 | Number of CUDA streams to use per device. |
| 275 | + transform_matrix : torch.Tensor | None, optional |
| 276 | + Anisotropic magnification transform matrix of shape (2, 2). If None, |
| 277 | + no magnification transform is applied. Default is None. |
269 | 278 |
270 | 279 | Returns |
271 | 280 | ------- |
@@ -321,6 +330,7 @@ def construct_multi_gpu_refine_template_kwargs( |
321 | 330 | "batch_size": batch_size, |
322 | 331 | "num_cuda_streams": num_cuda_streams, |
323 | 332 | "device": device, |
| 333 | + "transform_matrix": transform_matrix, |
324 | 334 | } |
325 | 335 |
326 | 336 | kwargs_per_device.append(kwargs) |
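
The same tensor is simply replicated into each device's kwargs dict. A toy sketch of the fan-out, assuming two GPUs and showing only the keys touched by this hunk:

import torch

transform_matrix = torch.eye(2)  # identity: no anisotropic correction
devices = [torch.device("cuda:0"), torch.device("cuda:1")]
kwargs_per_device = [
    {
        "device": device,
        "transform_matrix": transform_matrix,  # same tensor for every device
        # ... plus the per-device particle split and the other keys above
    }
    for device in devices
]
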
@@ -350,6 +360,7 @@ def _core_refine_template_single_gpu( |
350 | 360 | batch_size: int, |
351 | 361 | device: torch.device, |
352 | 362 | num_cuda_streams: int = 1, |
| 363 | + transform_matrix: torch.Tensor | None = None, |
353 | 364 | ) -> None: |
354 | 365 | """Run refine template on a subset of particles on a single GPU. |
355 | 366 |
@@ -393,6 +404,9 @@ def _core_refine_template_single_gpu( |
393 | 404 | Torch device to run this process on. |
394 | 405 | num_cuda_streams : int, optional |
395 | 406 | Number of CUDA streams to use for parallel processing. Defaults to 1. |
| 407 | + transform_matrix : torch.Tensor | None, optional |
| 408 | + Anisotropic magnification transform matrix of shape (2, 2). If None, |
| 409 | + no magnification transform is applied. Default is None. |
396 | 410 | """ |
397 | 411 | streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)] |
398 | 412 |
@@ -464,6 +478,7 @@ def _core_refine_template_single_gpu( |
464 | 478 | corr_std=corr_std[i], |
465 | 479 | projective_filter=projective_filters[i], |
466 | 480 | batch_size=batch_size, |
| 481 | + transform_matrix=transform_matrix, |
467 | 482 | device_id=device_id, |
468 | 483 | ) |
469 | 484 | refined_statistics.append(refined_stats) |
@@ -565,6 +580,7 @@ def _core_refine_template_single_thread( |
565 | 580 | projective_filter: torch.Tensor, |
566 | 581 | batch_size: int = 32, |
567 | 582 | device_id: int = 0, |
| 583 | + transform_matrix: torch.Tensor | None = None, |
568 | 584 | ) -> dict[str, float | int]: |
569 | 585 | """Run the single-threaded core refine template function. |
570 | 586 |
@@ -605,6 +621,9 @@ def _core_refine_template_single_thread( |
605 | 621 | The number of orientations to cross-correlate at once. Default is 32. |
606 | 622 | device_id : int, optional |
607 | 623 | The ID of the device/process. Default is 0. |
| 624 | + transform_matrix : torch.Tensor | None, optional |
| 625 | + Anisotropic magnification transform matrix of shape (2, 2). If None, |
| 626 | + no magnification transform is applied. Default is None. |
608 | 627 |
609 | 628 | Returns |
610 | 629 | ------- |
@@ -693,13 +712,15 @@ def _core_refine_template_single_thread( |
693 | 712 | template_dft=template_dft, |
694 | 713 | rotation_matrices=rot_matrix_batch, |
695 | 714 | projective_filters=combined_projective_filter, |
| 715 | + transform_matrix=transform_matrix, |
696 | 716 | ) |
697 | 717 | else: |
698 | 718 | cross_correlation = do_batched_orientation_cross_correlate_cpu( |
699 | 719 | image_dft=particle_image_dft, |
700 | 720 | template_dft=template_dft, |
701 | 721 | rotation_matrices=rot_matrix_batch, |
702 | 722 | projective_filters=combined_projective_filter, |
| 723 | + transform_matrix=transform_matrix, |
703 | 724 | ) |
704 | 725 |
705 | 726 | cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop |
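
For orientation: crop_h and crop_w are computed outside this hunk; for "valid" mode they presumably follow the standard valid-correlation output size (an assumption, not shown in this diff):

# Assumption: standard "valid" cross-correlation output size, with
# image_h/image_w and template_h/template_w defined earlier in the function.
crop_h = image_h - template_h + 1
crop_w = image_w - template_w + 1
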
@@ -765,6 +786,7 @@ def cross_correlate_particle_stack( |
765 | 786 | projective_filters: torch.Tensor, # (N, h, w) |
766 | 787 | mode: Literal["valid", "same"] = "valid", |
767 | 788 | batch_size: int = 1024, |
| 789 | + transform_matrix: torch.Tensor | None = None, |
768 | 790 | ) -> torch.Tensor: |
769 | 791 | """Cross-correlate a stack of particle images against a template. |
770 | 792 |
@@ -793,6 +815,9 @@ def cross_correlate_particle_stack( |
793 | 815 | The number of particle images to cross-correlate at once. Default is 1024. |
794 | 816 | Larger sizes will consume more memory. If -1, then the entire stack will be |
795 | 817 | cross-correlated at once. |
| 818 | + transform_matrix : torch.Tensor | None, optional |
| 819 | + Anisotropic magnification transform matrix of shape (2, 2). If None, |
| 820 | + no magnification transform is applied. Default is None. |
796 | 821 |
797 | 822 | Returns |
798 | 823 | ------- |
@@ -839,9 +864,19 @@ def cross_correlate_particle_stack( |
839 | 864 | # Extract the Fourier slice and apply the projective filters |
840 | 865 | fourier_slice = extract_central_slices_rfft_3d( |
841 | 866 | volume_rfft=template_dft, |
842 | | - image_shape=(template_h,) * 3, |
843 | 867 | rotation_matrices=batch_rotation_matrices, |
844 | 868 | ) |
| 869 | + # Apply anisotropic magnification transform if provided |
| 870 | + # pylint: disable=duplicate-code |
| 871 | + if transform_matrix is not None: |
| 872 | + rfft_shape = (template_h, template_w) |
| 873 | + stack_shape = (batch_rotation_matrices.shape[0],) |
| 874 | + fourier_slice = transform_slice_2d( |
| 875 | + projection_image_dfts=fourier_slice, |
| 876 | + rfft_shape=rfft_shape, |
| 877 | + stack_shape=stack_shape, |
| 878 | + transform_matrix=transform_matrix, |
| 879 | + ) |
845 | 880 | fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,)) |
846 | 881 | fourier_slice[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero) |
847 | 882 | fourier_slice *= -1 # flip contrast |
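
A hedged sanity check for the new branch, assuming transform_slice_2d resamples on the original grid so that an identity matrix is a near no-op (the call signature mirrors the hunk above; box size, stack size, and tolerances are made up):

import torch
from torch_fourier_slice import transform_slice_2d

template_h = template_w = 64  # hypothetical template box size
n_orientations = 8
# Stand-in slice stack in the rfft layout used above: (n, h, w // 2 + 1).
fourier_slice = torch.randn(
    n_orientations, template_h, template_w // 2 + 1, dtype=torch.complex64
)

transformed = transform_slice_2d(
    projection_image_dfts=fourier_slice,
    rfft_shape=(template_h, template_w),
    stack_shape=(n_orientations,),
    transform_matrix=torch.eye(2),  # identity transform
)
# Expect near-equality up to interpolation error (tolerances are guesses).
torch.testing.assert_close(transformed, fourier_slice, rtol=1e-3, atol=1e-4)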