@@ -61,7 +61,7 @@ def core_refine_template(
     device: torch.device | list[torch.device],
     batch_size: int = 32,
     num_cuda_streams: int = 1,
-    transform_matrix: torch.Tensor | None = None,
+    mag_matrix: torch.Tensor | None = None,
 ) -> dict[str, torch.Tensor]:
     """Core function to refine orientations and defoci of a set of particles.
 
@@ -106,8 +106,8 @@ def core_refine_template(
         The number of cross-correlations to process in one batch, defaults to 32.
     num_cuda_streams : int, optional
         Number of CUDA streams to use for parallel processing. Defaults to 1.
-    transform_matrix : torch.Tensor | None, optional
-        Anisotropic magnification transform matrix of shape (2, 2). If None,
+    mag_matrix : torch.Tensor | None, optional
+        Anisotropic magnification matrix of shape (2, 2). If None,
         no magnification transform is applied. Default is None.

     Returns
@@ -140,7 +140,7 @@ def core_refine_template(
         batch_size=batch_size,
         devices=device,
         num_cuda_streams=num_cuda_streams,
-        transform_matrix=transform_matrix,
+        mag_matrix=mag_matrix,
     )

     results = run_multiprocess_jobs(
@@ -234,7 +234,7 @@ def construct_multi_gpu_refine_template_kwargs(
     batch_size: int,
     devices: list[torch.device],
     num_cuda_streams: int,
-    transform_matrix: torch.Tensor | None = None,
+    mag_matrix: torch.Tensor | None = None,
 ) -> list[dict]:
     """Split particle stack between requested devices.
 
@@ -272,8 +272,8 @@ def construct_multi_gpu_refine_template_kwargs(
         List of devices to split across.
     num_cuda_streams : int
         Number of CUDA streams to use per device.
-    transform_matrix : torch.Tensor | None, optional
-        Anisotropic magnification transform matrix of shape (2, 2). If None,
+    mag_matrix : torch.Tensor | None, optional
+        Anisotropic magnification matrix of shape (2, 2). If None,
         no magnification transform is applied. Default is None.

     Returns
@@ -330,7 +330,7 @@ def construct_multi_gpu_refine_template_kwargs(
             "batch_size": batch_size,
             "num_cuda_streams": num_cuda_streams,
             "device": device,
-            "transform_matrix": transform_matrix,
+            "mag_matrix": mag_matrix,
         }

         kwargs_per_device.append(kwargs)
@@ -360,7 +360,7 @@ def _core_refine_template_single_gpu(
     batch_size: int,
     device: torch.device,
     num_cuda_streams: int = 1,
-    transform_matrix: torch.Tensor | None = None,
+    mag_matrix: torch.Tensor | None = None,
 ) -> None:
     """Run refine template on a subset of particles on a single GPU.
 
@@ -404,8 +404,8 @@ def _core_refine_template_single_gpu(
         Torch device to run this process on.
     num_cuda_streams : int, optional
         Number of CUDA streams to use for parallel processing. Defaults to 1.
-    transform_matrix : torch.Tensor | None, optional
-        Anisotropic magnification transform matrix of shape (2, 2). If None,
+    mag_matrix : torch.Tensor | None, optional
+        Anisotropic magnification matrix of shape (2, 2). If None,
         no magnification transform is applied. Default is None.
     """
     streams = [torch.cuda.Stream(device=device) for _ in range(num_cuda_streams)]
@@ -478,7 +478,7 @@ def _core_refine_template_single_gpu(
             corr_std=corr_std[i],
             projective_filter=projective_filters[i],
             batch_size=batch_size,
-            transform_matrix=transform_matrix,
+            mag_matrix=mag_matrix,
             device_id=device_id,
         )
         refined_statistics.append(refined_stats)
@@ -580,7 +580,7 @@ def _core_refine_template_single_thread(
     projective_filter: torch.Tensor,
     batch_size: int = 32,
     device_id: int = 0,
-    transform_matrix: torch.Tensor | None = None,
+    mag_matrix: torch.Tensor | None = None,
 ) -> dict[str, float | int]:
     """Run the single-threaded core refine template function.
 
@@ -621,8 +621,8 @@ def _core_refine_template_single_thread(
         The number of orientations to cross-correlate at once. Default is 32.
     device_id : int, optional
         The ID of the device/process. Default is 0.
-    transform_matrix : torch.Tensor | None, optional
-        Anisotropic magnification transform matrix of shape (2, 2). If None,
+    mag_matrix : torch.Tensor | None, optional
+        Anisotropic magnification matrix of shape (2, 2). If None,
         no magnification transform is applied. Default is None.

     Returns
@@ -712,15 +712,15 @@ def _core_refine_template_single_thread(
                 template_dft=template_dft,
                 rotation_matrices=rot_matrix_batch,
                 projective_filters=combined_projective_filter,
-                transform_matrix=transform_matrix,
+                mag_matrix=mag_matrix,
             )
         else:
             cross_correlation = do_batched_orientation_cross_correlate_cpu(
                 image_dft=particle_image_dft,
                 template_dft=template_dft,
                 rotation_matrices=rot_matrix_batch,
                 projective_filters=combined_projective_filter,
-                transform_matrix=transform_matrix,
+                mag_matrix=mag_matrix,
             )

         cross_correlation = cross_correlation[..., :crop_h, :crop_w]  # valid crop
@@ -786,7 +786,7 @@ def cross_correlate_particle_stack(
     projective_filters: torch.Tensor,  # (N, h, w)
     mode: Literal["valid", "same"] = "valid",
     batch_size: int = 1024,
-    transform_matrix: torch.Tensor | None = None,
+    mag_matrix: torch.Tensor | None = None,
 ) -> torch.Tensor:
     """Cross-correlate a stack of particle images against a template.
 
@@ -815,8 +815,8 @@ def cross_correlate_particle_stack(
         The number of particle images to cross-correlate at once. Default is 1024.
         Larger sizes will consume more memory. If -1, then the entire stack will be
         cross-correlated at once.
-    transform_matrix : torch.Tensor | None, optional
-        Anisotropic magnification transform matrix of shape (2, 2). If None,
+    mag_matrix : torch.Tensor | None, optional
+        Anisotropic magnification matrix of shape (2, 2). If None,
         no magnification transform is applied. Default is None.

     Returns
@@ -868,14 +868,14 @@ def cross_correlate_particle_stack(
         )
         # Apply anisotropic magnification transform if provided
         # pylint: disable=duplicate-code
-        if transform_matrix is not None:
+        if mag_matrix is not None:
             rfft_shape = (template_h, template_w)
             stack_shape = (batch_rotation_matrices.shape[0],)
             fourier_slice = transform_slice_2d(
                 projection_image_dfts=fourier_slice,
                 rfft_shape=rfft_shape,
                 stack_shape=stack_shape,
-                transform_matrix=transform_matrix,
+                transform_matrix=mag_matrix,
             )
         fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
         fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
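For context, a minimal caller-side sketch of what a `mag_matrix` value could look like. Only the keyword name and its (2, 2) shape come from the signatures above; the ~1% stretch/compression values, the 30-degree axis, and the elided remaining arguments of `core_refine_template` are illustrative assumptions, not part of this diff.

```python
import math

import torch

# Hypothetical anisotropic magnification: ~1% stretch along an axis rotated
# 30 degrees from x, and ~1% compression along the perpendicular axis.
theta = math.radians(30.0)
c, s = math.cos(theta), math.sin(theta)
rotation = torch.tensor([[c, -s], [s, c]], dtype=torch.float32)
scale = torch.diag(torch.tensor([1.01, 0.99], dtype=torch.float32))
mag_matrix = rotation @ scale @ rotation.T  # shape (2, 2)

# Caller-side usage sketch; the remaining arguments of core_refine_template
# are not visible in this diff and are therefore elided.
# result = core_refine_template(
#     ...,
#     device=torch.device("cuda:0"),
#     batch_size=32,
#     num_cuda_streams=1,
#     mag_matrix=mag_matrix,
# )
```

Per the final hunk, a non-None `mag_matrix` is forwarded unchanged to `transform_slice_2d` as its `transform_matrix` keyword and applied to the extracted Fourier slices.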