Skip to content

Commit ca6fde6

Browse files
committed
change transform_matrix name
1 parent cdd9d1c commit ca6fde6

File tree

7 files changed

+78
-78
lines changed

7 files changed

+78
-78
lines changed

src/leopard_em/backend/core_differentiable_refine.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def core_differentiable_refine(
6868
device: torch.device | list[torch.device],
6969
batch_size: int = 32,
7070
num_cuda_streams: int = 1,
71-
transform_matrix: torch.Tensor | None = None,
71+
mag_matrix: torch.Tensor | None = None,
7272
) -> dict[torch.Tensor, torch.Tensor]:
7373
"""Core function to refine orientations and defoci of a set of particles.
7474
@@ -113,8 +113,8 @@ def core_differentiable_refine(
113113
The number of cross-correlations to process in one batch, defaults to 32.
114114
num_cuda_streams : int, optional
115115
Number of CUDA streams to use for parallel processing. Defaults to 1.
116-
transform_matrix : torch.Tensor | None, optional
117-
Anisotropic magnification transform matrix of shape (2, 2). If None,
116+
mag_matrix : torch.Tensor | None, optional
117+
Anisotropic magnification matrix of shape (2, 2). If None,
118118
no magnification transform is applied. Default is None.
119119
120120
Returns
@@ -147,7 +147,7 @@ def core_differentiable_refine(
147147
batch_size=batch_size,
148148
devices=device,
149149
num_cuda_streams=num_cuda_streams,
150-
transform_matrix=transform_matrix,
150+
mag_matrix=mag_matrix,
151151
)
152152

153153
results = {}
@@ -235,7 +235,7 @@ def _core_refine_template_single_gpu(
235235
batch_size: int,
236236
device: torch.device,
237237
num_cuda_streams: int = 1,
238-
transform_matrix: torch.Tensor | None = None,
238+
mag_matrix: torch.Tensor | None = None,
239239
) -> None:
240240
"""Run refine template on a subset of particles on a single GPU.
241241
@@ -279,8 +279,8 @@ def _core_refine_template_single_gpu(
279279
Torch device to run this process on.
280280
num_cuda_streams : int, optional
281281
Number of CUDA streams to use for parallel processing. Defaults to 1.
282-
transform_matrix : torch.Tensor | None, optional
283-
Anisotropic magnification transform matrix of shape (2, 2). If None,
282+
mag_matrix : torch.Tensor | None, optional
283+
Anisotropic magnification matrix of shape (2, 2). If None,
284284
no magnification transform is applied. Default is None.
285285
"""
286286
streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)]
@@ -302,8 +302,8 @@ def _core_refine_template_single_gpu(
302302
corr_mean = corr_mean.to(device)
303303
corr_std = corr_std.to(device)
304304
projective_filters = projective_filters.to(device)
305-
if transform_matrix is not None:
306-
transform_matrix = transform_matrix.to(device)
305+
if mag_matrix is not None:
306+
mag_matrix = mag_matrix.to(device)
307307

308308
########################################
309309
### Setup constants and progress bar ###
@@ -356,7 +356,7 @@ def _core_refine_template_single_gpu(
356356
projective_filter=projective_filters[i],
357357
batch_size=batch_size,
358358
device_id=device_id,
359-
transform_matrix=transform_matrix,
359+
mag_matrix=mag_matrix,
360360
)
361361
refined_statistics.append(refined_stats)
362362

@@ -469,7 +469,7 @@ def _core_refine_template_single_thread(
469469
projective_filter: torch.Tensor,
470470
batch_size: int = 32,
471471
device_id: int = 0,
472-
transform_matrix: torch.Tensor | None = None,
472+
mag_matrix: torch.Tensor | None = None,
473473
) -> dict[str, float | int]:
474474
"""Run the single-threaded core refine template function.
475475
@@ -510,8 +510,8 @@ def _core_refine_template_single_thread(
510510
The number of orientations to cross-correlate at once. Default is 32.
511511
device_id : int, optional
512512
The ID of the device/process. Default is 0.
513-
transform_matrix : torch.Tensor | None, optional
514-
Anisotropic magnification transform matrix of shape (2, 2). If None,
513+
mag_matrix : torch.Tensor | None, optional
514+
Anisotropic magnification matrix of shape (2, 2). If None,
515515
no magnification transform is applied. Default is None.
516516
517517
Returns
@@ -616,15 +616,15 @@ def _core_refine_template_single_thread(
616616
rotation_matrices=rot_matrix_batch,
617617
projective_filters=combined_projective_filter,
618618
requires_grad=True,
619-
transform_matrix=transform_matrix,
619+
mag_matrix=mag_matrix,
620620
)
621621
else:
622622
cross_correlation = do_batched_orientation_cross_correlate_cpu(
623623
image_dft=particle_image_dft,
624624
template_dft=template_dft,
625625
rotation_matrices=rot_matrix_batch,
626626
projective_filters=combined_projective_filter,
627-
transform_matrix=transform_matrix,
627+
mag_matrix=mag_matrix,
628628
)
629629

630630
cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop

src/leopard_em/backend/core_match_template.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def core_match_template(
155155
orientation_batch_size: int = 1,
156156
num_cuda_streams: int = 1,
157157
backend: str = "streamed",
158-
transform_matrix: torch.Tensor | None = None,
158+
mag_matrix: torch.Tensor | None = None,
159159
) -> dict[str, torch.Tensor]:
160160
"""Core function for performing the whole-orientation search.
161161
@@ -210,8 +210,8 @@ def core_match_template(
210210
backend : str, optional
211211
The backend to use for computation. Defaults to 'streamed'.
212212
Must be 'streamed' or 'batched'.
213-
transform_matrix : torch.Tensor | None, optional
214-
Anisotropic magnification transform matrix of shape (2, 2). If None,
213+
mag_matrix : torch.Tensor | None, optional
214+
Anisotropic magnification matrix of shape (2, 2). If None,
215215
no magnification transform is applied. Default is None.
216216
217217
Returns
@@ -263,9 +263,9 @@ def core_match_template(
263263
defocus_values = defocus_values.cpu()
264264
pixel_values = pixel_values.cpu()
265265
euler_angles = euler_angles.cpu()
266-
# Move transform_matrix to CPU if it's not None
267-
if transform_matrix is not None:
268-
transform_matrix = transform_matrix.cpu()
266+
# Move mag_matrix to CPU if it's not None
267+
if mag_matrix is not None:
268+
mag_matrix = mag_matrix.cpu()
269269

270270
##############################################################
271271
### Pre-multiply the whitening filter with the CTF filters ###
@@ -315,7 +315,7 @@ def core_match_template(
315315
"num_cuda_streams": num_cuda_streams,
316316
"backend": backend,
317317
"device": d,
318-
"transform_matrix": transform_matrix,
318+
"mag_matrix": mag_matrix,
319319
}
320320

321321
kwargs_per_device.append(kwargs)
@@ -379,7 +379,7 @@ def _core_match_template_single_gpu(
379379
num_cuda_streams: int,
380380
backend: str,
381381
device: torch.device,
382-
transform_matrix: torch.Tensor | None = None,
382+
mag_matrix: torch.Tensor | None = None,
383383
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
384384
"""Single-GPU call for template matching.
385385
@@ -419,8 +419,8 @@ def _core_match_template_single_gpu(
419419
Defaults to 'streamed'. Must be 'streamed' or 'batched'.
420420
device : torch.device
421421
Device to run the computation on. All tensors must be allocated on this device.
422-
transform_matrix : torch.Tensor | None, optional
423-
Anisotropic magnification transform matrix of shape (2, 2). If None,
422+
mag_matrix : torch.Tensor | None, optional
423+
Anisotropic magnification matrix of shape (2, 2). If None,
424424
no magnification transform is applied. Default is None.
425425
426426
Returns
@@ -447,9 +447,9 @@ def _core_match_template_single_gpu(
447447
template_dft = template_dft.to(device)
448448
euler_angles = euler_angles.to(device)
449449
projective_filters = projective_filters.to(device)
450-
# Move transform_matrix to device if it's not None
451-
if transform_matrix is not None:
452-
transform_matrix = transform_matrix.to(device)
450+
# Move mag_matrix to device if it's not None
451+
if mag_matrix is not None:
452+
mag_matrix = mag_matrix.to(device)
453453

454454
num_orientations = euler_angles.shape[0]
455455
num_defocus = defocus_values.shape[0]
@@ -525,7 +525,7 @@ def _core_match_template_single_gpu(
525525
template_dft=template_dft,
526526
rotation_matrices=rot_matrix,
527527
projective_filters=projective_filters,
528-
transform_matrix=transform_matrix,
528+
mag_matrix=mag_matrix,
529529
)
530530
else:
531531
cross_correlation = do_streamed_orientation_cross_correlate(
@@ -534,7 +534,7 @@ def _core_match_template_single_gpu(
534534
rotation_matrices=rot_matrix,
535535
projective_filters=projective_filters,
536536
streams=streams,
537-
transform_matrix=transform_matrix,
537+
mag_matrix=mag_matrix,
538538
)
539539

540540
# Update the tracked statistics

src/leopard_em/backend/core_refine_template.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def core_refine_template(
6161
device: torch.device | list[torch.device],
6262
batch_size: int = 32,
6363
num_cuda_streams: int = 1,
64-
transform_matrix: torch.Tensor | None = None,
64+
mag_matrix: torch.Tensor | None = None,
6565
) -> dict[str, torch.Tensor]:
6666
"""Core function to refine orientations and defoci of a set of particles.
6767
@@ -106,8 +106,8 @@ def core_refine_template(
106106
The number of cross-correlations to process in one batch, defaults to 32.
107107
num_cuda_streams : int, optional
108108
Number of CUDA streams to use for parallel processing. Defaults to 1.
109-
transform_matrix : torch.Tensor | None, optional
110-
Anisotropic magnification transform matrix of shape (2, 2). If None,
109+
mag_matrix : torch.Tensor | None, optional
110+
Anisotropic magnification matrix of shape (2, 2). If None,
111111
no magnification transform is applied. Default is None.
112112
113113
Returns
@@ -140,7 +140,7 @@ def core_refine_template(
140140
batch_size=batch_size,
141141
devices=device,
142142
num_cuda_streams=num_cuda_streams,
143-
transform_matrix=transform_matrix,
143+
mag_matrix=mag_matrix,
144144
)
145145

146146
results = run_multiprocess_jobs(
@@ -234,7 +234,7 @@ def construct_multi_gpu_refine_template_kwargs(
234234
batch_size: int,
235235
devices: list[torch.device],
236236
num_cuda_streams: int,
237-
transform_matrix: torch.Tensor | None = None,
237+
mag_matrix: torch.Tensor | None = None,
238238
) -> list[dict]:
239239
"""Split particle stack between requested devices.
240240
@@ -272,8 +272,8 @@ def construct_multi_gpu_refine_template_kwargs(
272272
List of devices to split across.
273273
num_cuda_streams : int
274274
Number of CUDA streams to use per device.
275-
transform_matrix : torch.Tensor | None, optional
276-
Anisotropic magnification transform matrix of shape (2, 2). If None,
275+
mag_matrix : torch.Tensor | None, optional
276+
Anisotropic magnification matrix of shape (2, 2). If None,
277277
no magnification transform is applied. Default is None.
278278
279279
Returns
@@ -330,7 +330,7 @@ def construct_multi_gpu_refine_template_kwargs(
330330
"batch_size": batch_size,
331331
"num_cuda_streams": num_cuda_streams,
332332
"device": device,
333-
"transform_matrix": transform_matrix,
333+
"mag_matrix": mag_matrix,
334334
}
335335

336336
kwargs_per_device.append(kwargs)
@@ -360,7 +360,7 @@ def _core_refine_template_single_gpu(
360360
batch_size: int,
361361
device: torch.device,
362362
num_cuda_streams: int = 1,
363-
transform_matrix: torch.Tensor | None = None,
363+
mag_matrix: torch.Tensor | None = None,
364364
) -> None:
365365
"""Run refine template on a subset of particles on a single GPU.
366366
@@ -404,8 +404,8 @@ def _core_refine_template_single_gpu(
404404
Torch device to run this process on.
405405
num_cuda_streams : int, optional
406406
Number of CUDA streams to use for parallel processing. Defaults to 1.
407-
transform_matrix : torch.Tensor | None, optional
408-
Anisotropic magnification transform matrix of shape (2, 2). If None,
407+
mag_matrix : torch.Tensor | None, optional
408+
Anisotropic magnification matrix of shape (2, 2). If None,
409409
no magnification transform is applied. Default is None.
410410
"""
411411
streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)]
@@ -478,7 +478,7 @@ def _core_refine_template_single_gpu(
478478
corr_std=corr_std[i],
479479
projective_filter=projective_filters[i],
480480
batch_size=batch_size,
481-
transform_matrix=transform_matrix,
481+
mag_matrix=mag_matrix,
482482
device_id=device_id,
483483
)
484484
refined_statistics.append(refined_stats)
@@ -580,7 +580,7 @@ def _core_refine_template_single_thread(
580580
projective_filter: torch.Tensor,
581581
batch_size: int = 32,
582582
device_id: int = 0,
583-
transform_matrix: torch.Tensor | None = None,
583+
mag_matrix: torch.Tensor | None = None,
584584
) -> dict[str, float | int]:
585585
"""Run the single-threaded core refine template function.
586586
@@ -621,8 +621,8 @@ def _core_refine_template_single_thread(
621621
The number of orientations to cross-correlate at once. Default is 32.
622622
device_id : int, optional
623623
The ID of the device/process. Default is 0.
624-
transform_matrix : torch.Tensor | None, optional
625-
Anisotropic magnification transform matrix of shape (2, 2). If None,
624+
mag_matrix : torch.Tensor | None, optional
625+
Anisotropic magnification matrix of shape (2, 2). If None,
626626
no magnification transform is applied. Default is None.
627627
628628
Returns
@@ -712,15 +712,15 @@ def _core_refine_template_single_thread(
712712
template_dft=template_dft,
713713
rotation_matrices=rot_matrix_batch,
714714
projective_filters=combined_projective_filter,
715-
transform_matrix=transform_matrix,
715+
mag_matrix=mag_matrix,
716716
)
717717
else:
718718
cross_correlation = do_batched_orientation_cross_correlate_cpu(
719719
image_dft=particle_image_dft,
720720
template_dft=template_dft,
721721
rotation_matrices=rot_matrix_batch,
722722
projective_filters=combined_projective_filter,
723-
transform_matrix=transform_matrix,
723+
mag_matrix=mag_matrix,
724724
)
725725

726726
cross_correlation = cross_correlation[..., :crop_h, :crop_w] # valid crop
@@ -786,7 +786,7 @@ def cross_correlate_particle_stack(
786786
projective_filters: torch.Tensor, # (N, h, w)
787787
mode: Literal["valid", "same"] = "valid",
788788
batch_size: int = 1024,
789-
transform_matrix: torch.Tensor | None = None,
789+
mag_matrix: torch.Tensor | None = None,
790790
) -> torch.Tensor:
791791
"""Cross-correlate a stack of particle images against a template.
792792
@@ -815,8 +815,8 @@ def cross_correlate_particle_stack(
815815
The number of particle images to cross-correlate at once. Default is 1024.
816816
Larger sizes will consume more memory. If -1, then the entire stack will be
817817
cross-correlated at once.
818-
transform_matrix : torch.Tensor | None, optional
819-
Anisotropic magnification transform matrix of shape (2, 2). If None,
818+
mag_matrix : torch.Tensor | None, optional
819+
Anisotropic magnification matrix of shape (2, 2). If None,
820820
no magnification transform is applied. Default is None.
821821
822822
Returns
@@ -868,14 +868,14 @@ def cross_correlate_particle_stack(
868868
)
869869
# Apply anisotropic magnification transform if provided
870870
# pylint: disable=duplicate-code
871-
if transform_matrix is not None:
871+
if mag_matrix is not None:
872872
rfft_shape = (template_h, template_w)
873873
stack_shape = (batch_rotation_matrices.shape[0],)
874874
fourier_slice = transform_slice_2d(
875875
projection_image_dfts=fourier_slice,
876876
rfft_shape=rfft_shape,
877877
stack_shape=stack_shape,
878-
transform_matrix=transform_matrix,
878+
transform_matrix=mag_matrix,
879879
)
880880
fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
881881
fourier_slice[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero)

0 commit comments

Comments
 (0)