Skip to content

Commit e1ddd7c

Browse files
authored
Merge pull request #110 from Lucaslab-Berkeley/jd_ctf_aberrations
CTF aberrations now included
2 parents 9bcef75 + 38f115d commit e1ddd7c

23 files changed

+732
-553
lines changed

pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,16 @@ dependencies = [
4848
"types-PyYAML",
4949
"roma",
5050
"tqdm",
51-
"torch-fourier-slice>=v0.2.0",
51+
"torch-fourier-slice>=v0.4.0",
5252
"torch-fourier-filter>=v0.2.6",
5353
"torch-so3>=v0.2.0",
5454
"ttsim3d>=v0.4.0",
5555
"lmfit",
5656
"zenodo-get",
5757
"torch-fourier-shift",
5858
"torch-motion-correction>=0.0.4",
59-
"torch-grid-utils>=v0.0.9"
59+
"torch-grid-utils>=v0.0.9",
60+
"torch-ctf"
6061
]
6162

6263
[tool.hatch.metadata]
@@ -156,7 +157,10 @@ pretty = true
156157
[tool.pytest.ini_options]
157158
minversion = "7.0"
158159
testpaths = ["tests"]
159-
filterwarnings = ["error"]
160+
filterwarnings = [
161+
"error",
162+
"ignore::FutureWarning",
163+
]
160164
addopts = "-m 'not slow'" # Skip slow tests on default
161165
markers = ["slow: marks test as slow"]
162166

src/leopard_em/analysis/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
match_template_peaks_to_dict,
77
)
88
from .pvalue_metric import extract_peaks_and_statistics_p_value
9-
from .zscore_metric import extract_peaks_and_statistics_zscore, gaussian_noise_zscore_cutoff
9+
from .zscore_metric import (
10+
extract_peaks_and_statistics_zscore,
11+
gaussian_noise_zscore_cutoff,
12+
)
1013

1114
__all__ = [
1215
"MatchTemplatePeaks",

src/leopard_em/backend/core_differentiable_refine.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def core_differentiable_refine(
6868
device: torch.device | list[torch.device],
6969
batch_size: int = 32,
7070
num_cuda_streams: int = 1,
71+
transform_matrix: torch.Tensor | None = None,
7172
) -> dict[torch.Tensor, torch.Tensor]:
7273
"""Core function to refine orientations and defoci of a set of particles.
7374
@@ -112,6 +113,9 @@ def core_differentiable_refine(
112113
The number of cross-correlations to process in one batch, defaults to 32.
113114
num_cuda_streams : int, optional
114115
Number of CUDA streams to use for parallel processing. Defaults to 1.
116+
transform_matrix : torch.Tensor | None, optional
117+
Anisotropic magnification transform matrix of shape (2, 2). If None,
118+
no magnification transform is applied. Default is None.
115119
116120
Returns
117121
-------
@@ -143,6 +147,7 @@ def core_differentiable_refine(
143147
batch_size=batch_size,
144148
devices=device,
145149
num_cuda_streams=num_cuda_streams,
150+
transform_matrix=transform_matrix,
146151
)
147152

148153
results = {}
@@ -230,6 +235,7 @@ def _core_refine_template_single_gpu(
230235
batch_size: int,
231236
device: torch.device,
232237
num_cuda_streams: int = 1,
238+
transform_matrix: torch.Tensor | None = None,
233239
) -> None:
234240
"""Run refine template on a subset of particles on a single GPU.
235241
@@ -273,6 +279,9 @@ def _core_refine_template_single_gpu(
273279
Torch device to run this process on.
274280
num_cuda_streams : int, optional
275281
Number of CUDA streams to use for parallel processing. Defaults to 1.
282+
transform_matrix : torch.Tensor | None, optional
283+
Anisotropic magnification transform matrix of shape (2, 2). If None,
284+
no magnification transform is applied. Default is None.
276285
"""
277286
streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)]
278287

@@ -293,6 +302,8 @@ def _core_refine_template_single_gpu(
293302
corr_mean = corr_mean.to(device)
294303
corr_std = corr_std.to(device)
295304
projective_filters = projective_filters.to(device)
305+
if transform_matrix is not None:
306+
transform_matrix = transform_matrix.to(device)
296307

297308
########################################
298309
### Setup constants and progress bar ###
@@ -345,6 +356,7 @@ def _core_refine_template_single_gpu(
345356
projective_filter=projective_filters[i],
346357
batch_size=batch_size,
347358
device_id=device_id,
359+
transform_matrix=transform_matrix,
348360
)
349361
refined_statistics.append(refined_stats)
350362

@@ -457,6 +469,7 @@ def _core_refine_template_single_thread(
457469
projective_filter: torch.Tensor,
458470
batch_size: int = 32,
459471
device_id: int = 0,
472+
transform_matrix: torch.Tensor | None = None,
460473
) -> dict[str, float | int]:
461474
"""Run the single-threaded core refine template function.
462475
@@ -497,6 +510,9 @@ def _core_refine_template_single_thread(
497510
The number of orientations to cross-correlate at once. Default is 32.
498511
device_id : int, optional
499512
The ID of the device/process. Default is 0.
513+
transform_matrix : torch.Tensor | None, optional
514+
Anisotropic magnification transform matrix of shape (2, 2). If None,
515+
no magnification transform is applied. Default is None.
500516
501517
Returns
502518
-------
@@ -600,13 +616,15 @@ def _core_refine_template_single_thread(
600616
rotation_matrices=rot_matrix_batch,
601617
projective_filters=combined_projective_filter,
602618
requires_grad=True,
619+
transform_matrix=transform_matrix,
603620
)
604621
else:
605622
cross_correlation = do_batched_orientation_cross_correlate_cpu(
606623
image_dft=particle_image_dft,
607624
template_dft=template_dft,
608625
rotation_matrices=rot_matrix_batch,
609626
projective_filters=combined_projective_filter,
627+
transform_matrix=transform_matrix,
610628
)
611629

612630
cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop

src/leopard_em/backend/core_match_template.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def core_match_template(
155155
orientation_batch_size: int = 1,
156156
num_cuda_streams: int = 1,
157157
backend: str = "streamed",
158+
transform_matrix: torch.Tensor | None = None,
158159
) -> dict[str, torch.Tensor]:
159160
"""Core function for performing the whole-orientation search.
160161
@@ -209,6 +210,9 @@ def core_match_template(
209210
backend : str, optional
210211
The backend to use for computation. Defaults to 'streamed'.
211212
Must be 'streamed' or 'batched'.
213+
transform_matrix : torch.Tensor | None, optional
214+
Anisotropic magnification transform matrix of shape (2, 2). If None,
215+
no magnification transform is applied. Default is None.
212216
213217
Returns
214218
-------
@@ -259,6 +263,9 @@ def core_match_template(
259263
defocus_values = defocus_values.cpu()
260264
pixel_values = pixel_values.cpu()
261265
euler_angles = euler_angles.cpu()
266+
# Move transform_matrix to CPU if it's not None
267+
if transform_matrix is not None:
268+
transform_matrix = transform_matrix.cpu()
262269

263270
##############################################################
264271
### Pre-multiply the whitening filter with the CTF filters ###
@@ -308,6 +315,7 @@ def core_match_template(
308315
"num_cuda_streams": num_cuda_streams,
309316
"backend": backend,
310317
"device": d,
318+
"transform_matrix": transform_matrix,
311319
}
312320

313321
kwargs_per_device.append(kwargs)
@@ -371,6 +379,7 @@ def _core_match_template_single_gpu(
371379
num_cuda_streams: int,
372380
backend: str,
373381
device: torch.device,
382+
transform_matrix: torch.Tensor | None = None,
374383
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
375384
"""Single-GPU call for template matching.
376385
@@ -410,6 +419,9 @@ def _core_match_template_single_gpu(
410419
Defaults to 'streamed'. Must be 'streamed' or 'batched'.
411420
device : torch.device
412421
Device to run the computation on. All tensors must be allocated on this device.
422+
transform_matrix : torch.Tensor | None, optional
423+
Anisotropic magnification transform matrix of shape (2, 2). If None,
424+
no magnification transform is applied. Default is None.
413425
414426
Returns
415427
-------
@@ -435,6 +447,9 @@ def _core_match_template_single_gpu(
435447
template_dft = template_dft.to(device)
436448
euler_angles = euler_angles.to(device)
437449
projective_filters = projective_filters.to(device)
450+
# Move transform_matrix to device if it's not None
451+
if transform_matrix is not None:
452+
transform_matrix = transform_matrix.to(device)
438453

439454
num_orientations = euler_angles.shape[0]
440455
num_defocus = defocus_values.shape[0]
@@ -510,6 +525,7 @@ def _core_match_template_single_gpu(
510525
template_dft=template_dft,
511526
rotation_matrices=rot_matrix,
512527
projective_filters=projective_filters,
528+
transform_matrix=transform_matrix,
513529
)
514530
else:
515531
cross_correlation = do_streamed_orientation_cross_correlate(
@@ -518,6 +534,7 @@ def _core_match_template_single_gpu(
518534
rotation_matrices=rot_matrix,
519535
projective_filters=projective_filters,
520536
streams=streams,
537+
transform_matrix=transform_matrix,
521538
)
522539

523540
# Update the tracked statistics

src/leopard_em/backend/core_refine_template.py

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import roma
1010
import torch
1111
import tqdm
12-
from torch_fourier_slice import extract_central_slices_rfft_3d
12+
from torch_fourier_slice import extract_central_slices_rfft_3d, transform_slice_2d
1313

1414
from leopard_em.backend.cross_correlation import (
1515
do_batched_orientation_cross_correlate,
@@ -61,6 +61,7 @@ def core_refine_template(
6161
device: torch.device | list[torch.device],
6262
batch_size: int = 32,
6363
num_cuda_streams: int = 1,
64+
transform_matrix: torch.Tensor | None = None,
6465
) -> dict[str, torch.Tensor]:
6566
"""Core function to refine orientations and defoci of a set of particles.
6667
@@ -105,6 +106,9 @@ def core_refine_template(
105106
The number of cross-correlations to process in one batch, defaults to 32.
106107
num_cuda_streams : int, optional
107108
Number of CUDA streams to use for parallel processing. Defaults to 1.
109+
transform_matrix : torch.Tensor | None, optional
110+
Anisotropic magnification transform matrix of shape (2, 2). If None,
111+
no magnification transform is applied. Default is None.
108112
109113
Returns
110114
-------
@@ -136,6 +140,7 @@ def core_refine_template(
136140
batch_size=batch_size,
137141
devices=device,
138142
num_cuda_streams=num_cuda_streams,
143+
transform_matrix=transform_matrix,
139144
)
140145

141146
results = run_multiprocess_jobs(
@@ -229,6 +234,7 @@ def construct_multi_gpu_refine_template_kwargs(
229234
batch_size: int,
230235
devices: list[torch.device],
231236
num_cuda_streams: int,
237+
transform_matrix: torch.Tensor | None = None,
232238
) -> list[dict]:
233239
"""Split particle stack between requested devices.
234240
@@ -266,6 +272,9 @@ def construct_multi_gpu_refine_template_kwargs(
266272
List of devices to split across.
267273
num_cuda_streams : int
268274
Number of CUDA streams to use per device.
275+
transform_matrix : torch.Tensor | None, optional
276+
Anisotropic magnification transform matrix of shape (2, 2). If None,
277+
no magnification transform is applied. Default is None.
269278
270279
Returns
271280
-------
@@ -321,6 +330,7 @@ def construct_multi_gpu_refine_template_kwargs(
321330
"batch_size": batch_size,
322331
"num_cuda_streams": num_cuda_streams,
323332
"device": device,
333+
"transform_matrix": transform_matrix,
324334
}
325335

326336
kwargs_per_device.append(kwargs)
@@ -350,6 +360,7 @@ def _core_refine_template_single_gpu(
350360
batch_size: int,
351361
device: torch.device,
352362
num_cuda_streams: int = 1,
363+
transform_matrix: torch.Tensor | None = None,
353364
) -> None:
354365
"""Run refine template on a subset of particles on a single GPU.
355366
@@ -393,6 +404,9 @@ def _core_refine_template_single_gpu(
393404
Torch device to run this process on.
394405
num_cuda_streams : int, optional
395406
Number of CUDA streams to use for parallel processing. Defaults to 1.
407+
transform_matrix : torch.Tensor | None, optional
408+
Anisotropic magnification transform matrix of shape (2, 2). If None,
409+
no magnification transform is applied. Default is None.
396410
"""
397411
streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)]
398412

@@ -464,6 +478,7 @@ def _core_refine_template_single_gpu(
464478
corr_std=corr_std[i],
465479
projective_filter=projective_filters[i],
466480
batch_size=batch_size,
481+
transform_matrix=transform_matrix,
467482
device_id=device_id,
468483
)
469484
refined_statistics.append(refined_stats)
@@ -565,6 +580,7 @@ def _core_refine_template_single_thread(
565580
projective_filter: torch.Tensor,
566581
batch_size: int = 32,
567582
device_id: int = 0,
583+
transform_matrix: torch.Tensor | None = None,
568584
) -> dict[str, float | int]:
569585
"""Run the single-threaded core refine template function.
570586
@@ -605,6 +621,9 @@ def _core_refine_template_single_thread(
605621
The number of orientations to cross-correlate at once. Default is 32.
606622
device_id : int, optional
607623
The ID of the device/process. Default is 0.
624+
transform_matrix : torch.Tensor | None, optional
625+
Anisotropic magnification transform matrix of shape (2, 2). If None,
626+
no magnification transform is applied. Default is None.
608627
609628
Returns
610629
-------
@@ -693,13 +712,15 @@ def _core_refine_template_single_thread(
693712
template_dft=template_dft,
694713
rotation_matrices=rot_matrix_batch,
695714
projective_filters=combined_projective_filter,
715+
transform_matrix=transform_matrix,
696716
)
697717
else:
698718
cross_correlation = do_batched_orientation_cross_correlate_cpu(
699719
image_dft=particle_image_dft,
700720
template_dft=template_dft,
701721
rotation_matrices=rot_matrix_batch,
702722
projective_filters=combined_projective_filter,
723+
transform_matrix=transform_matrix,
703724
)
704725

705726
cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop
@@ -765,6 +786,7 @@ def cross_correlate_particle_stack(
765786
projective_filters: torch.Tensor, # (N, h, w)
766787
mode: Literal["valid", "same"] = "valid",
767788
batch_size: int = 1024,
789+
transform_matrix: torch.Tensor | None = None,
768790
) -> torch.Tensor:
769791
"""Cross-correlate a stack of particle images against a template.
770792
@@ -793,6 +815,9 @@ def cross_correlate_particle_stack(
793815
The number of particle images to cross-correlate at once. Default is 1024.
794816
Larger sizes will consume more memory. If -1, then the entire stack will be
795817
cross-correlated at once.
818+
transform_matrix : torch.Tensor | None, optional
819+
Anisotropic magnification transform matrix of shape (2, 2). If None,
820+
no magnification transform is applied. Default is None.
796821
797822
Returns
798823
-------
@@ -839,9 +864,19 @@ def cross_correlate_particle_stack(
839864
# Extract the Fourier slice and apply the projective filters
840865
fourier_slice = extract_central_slices_rfft_3d(
841866
volume_rfft=template_dft,
842-
image_shape=(template_h,) * 3,
843867
rotation_matrices=batch_rotation_matrices,
844868
)
869+
# Apply anisotropic magnification transform if provided
870+
# pylint: disable=duplicate-code
871+
if transform_matrix is not None:
872+
rfft_shape = (template_h, template_w)
873+
stack_shape = (batch_rotation_matrices.shape[0],)
874+
fourier_slice = transform_slice_2d(
875+
projection_image_dfts=fourier_slice,
876+
rfft_shape=rfft_shape,
877+
stack_shape=stack_shape,
878+
transform_matrix=transform_matrix,
879+
)
845880
fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
846881
fourier_slice[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero)
847882
fourier_slice *= -1 # flip contrast

0 commit comments

Comments
 (0)