1010import roma
1111import torch
1212import tqdm
13- from torch_fourier_slice import extract_central_slices_rfft_3d
1413
14+ from leopard_em .backend .cross_correlation import (
15+ do_streamed_orientation_cross_correlate ,
16+ )
1517from leopard_em .backend .process_results import (
1618 aggregate_distributed_results ,
1719 scale_mip ,
1820)
1921from leopard_em .backend .utils import (
20- do_iteration_statistics_updates ,
21- normalize_template_projection ,
22+ do_iteration_statistics_updates_compiled ,
2223 run_multiprocess_jobs ,
2324)
2425
25- COMPILE_BACKEND = "inductor"
2626DEFAULT_STATISTIC_DTYPE = torch .float32
2727
2828# Turn off gradient calculations by default
3131# Set multiprocessing start method to spawn
3232set_start_method ("spawn" , force = True )
3333
34- normalize_template_projection_compiled = torch .compile (
35- normalize_template_projection , backend = COMPILE_BACKEND
36- )
37- do_iteration_statistics_updates_compiled = torch .compile (
38- do_iteration_statistics_updates , backend = COMPILE_BACKEND
39- )
40-
4134
4235###########################################################
4336### Main function for whole orientation search ###
@@ -396,6 +389,10 @@ def _core_match_template_single_gpu(
396389 ### Setup iterator object with tqdm for progress bar ###
397390 ########################################################
398391
392+ total_projections = (
393+ euler_angles .shape [0 ] * defocus_values .shape [0 ] * pixel_values .shape [0 ]
394+ )
395+
399396 num_batches = math .ceil (euler_angles .shape [0 ] / orientation_batch_size )
400397 orientation_batch_iterator = tqdm .tqdm (
401398 range (num_batches ),
@@ -405,10 +402,9 @@ def _core_match_template_single_gpu(
405402 dynamic_ncols = True ,
406403 position = device .index ,
407404 mininterval = 1 , # Slow down to reduce number of lines written
408- )
409-
410- total_projections = (
411- euler_angles .shape [0 ] * defocus_values .shape [0 ] * pixel_values .shape [0 ]
405+ smoothing = 0.05 ,
406+ unit = "corr" ,
407+ unit_scale = total_projections / num_batches ,
412408 )
413409
414410 ##################################
@@ -423,7 +419,7 @@ def _core_match_template_single_gpu(
423419 "ZYZ" , euler_angles_batch , degrees = True , device = device
424420 )
425421
426- cross_correlation = _do_bached_orientation_cross_correlate (
422+ cross_correlation = do_streamed_orientation_cross_correlate (
427423 image_dft = image_dft ,
428424 template_dft = template_dft ,
429425 rotation_matrices = rot_matrix ,
@@ -466,180 +462,3 @@ def _core_match_template_single_gpu(
466462 # Place the results in the shared multi-process manager dictionary so accessible
467463 # by the main process.
468464 result_dict [device_id ] = result
469-
470-
def _do_bached_orientation_cross_correlate(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    projective_filters: torch.Tensor,
    streams: list[torch.cuda.Stream],
) -> torch.Tensor:
    """Batched projection and cross-correlation with fixed (batched) filters.

    Fourier slices for every orientation are extracted in one batched call. Each
    (orientation, defocus, Cs) combination is then filtered, normalized, and
    cross-correlated against the image on one of the supplied CUDA streams,
    assigned round-robin.

    Note that this function returns a cross-correlogram with "same" mode (i.e. the
    same size as the input image). See numpy correlate docs for more information.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real-fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real-fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original
        real-space shape of the template volume.
    rotation_matrices : torch.Tensor
        Rotation matrices to apply to the template volume. Has shape
        (num_orientations, 3, 3).
    projective_filters : torch.Tensor
        Multiplied 'ctf_filters' with 'whitening_filter_template'. Has shape
        (num_Cs, num_defocus, h, w // 2 + 1). Is RFFT and not fftshifted.
    streams : list[torch.cuda.Stream]
        List of CUDA streams to use for parallel computation. Each stream will
        handle a separate cross-correlation. Must not be empty when there is at
        least one cross-correlation to compute.

    Returns
    -------
    torch.Tensor
        Cross-correlation of the image with the template volume for each
        orientation and defocus value. Will have shape
        (num_Cs, num_defocus, num_orientations, H, W).
    """
    # Accounting for RFFT shape (assumes even-sized last real-space dimension)
    projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
    image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)

    num_orientations = rotation_matrices.shape[0]
    num_Cs = projective_filters.shape[0]  # pylint: disable=invalid-name
    num_defocus = projective_filters.shape[1]

    cross_correlation = torch.empty(
        size=(num_Cs, num_defocus, num_orientations, *image_shape_real),
        dtype=DEFAULT_STATISTIC_DTYPE,
        device=image_dft.device,
    )

    # Do a batched Fourier slice extraction for all the orientations at once.
    fourier_slices = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        image_shape=(projection_shape_real[0],) * 3,  # NOTE: requires cubic template
        rotation_matrices=rotation_matrices,
    )
    fourier_slices = torch.fft.ifftshift(fourier_slices, dim=(-2,))
    fourier_slices[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slices *= -1  # flip contrast

    # The slices above (and the caller-provided inputs) were queued on the current
    # stream. Side streams are not implicitly ordered against it, so without this
    # barrier a side stream could read `fourier_slices` before it is fully written.
    if fourier_slices.is_cuda:
        producer_stream = torch.cuda.current_stream(fourier_slices.device)
        for stream in streams:
            stream.wait_stream(producer_stream)

    # Iterate over the orientations
    for i in range(num_orientations):
        fourier_slice = fourier_slices[i]

        # Iterate over the different pixel sizes (Cs) and defocus values for this
        # particular orientation
        for j in range(num_defocus):
            for k in range(num_Cs):
                # Use a round-robin scheduling for the streams
                job_idx = (i * num_defocus * num_Cs) + (j * num_Cs) + k
                stream = streams[job_idx % len(streams)]

                with torch.cuda.stream(stream):
                    # Apply the projective filter and do template normalization
                    fourier_slice_filtered = fourier_slice * projective_filters[k, j]
                    projection = torch.fft.irfft2(fourier_slice_filtered)
                    projection = torch.fft.ifftshift(projection, dim=(-2, -1))
                    projection = normalize_template_projection_compiled(
                        projection,
                        projection_shape_real,
                        image_shape_real,
                    )

                    # Padded forward Fourier transform for cross-correlation
                    projection_dft = torch.fft.rfft2(projection, s=image_shape_real)
                    projection_dft[0, 0] = 0 + 0j

                    # Cross correlation step by element-wise multiplication,
                    # written directly into the pre-allocated output slot.
                    projection_dft = image_dft * projection_dft.conj()
                    torch.fft.irfft2(
                        projection_dft,
                        s=image_shape_real,
                        out=cross_correlation[k, j, i],
                    )

    # Wait for all streams to finish before handing the result back to the caller
    for stream in streams:
        stream.synchronize()

    # shape is (num_Cs, num_defocus, num_orientations, H, W)
    return cross_correlation
575-
576-
def _do_bached_orientation_cross_correlate_cpu(
    image_dft: torch.Tensor,
    template_dft: torch.Tensor,
    rotation_matrices: torch.Tensor,
    projective_filters: torch.Tensor,
) -> torch.Tensor:
    """Same as `_do_bached_orientation_cross_correlate` but on the CPU.

    The differences are that this function does not call into a compiled torch
    function for normalization, takes no CUDA streams, and processes all
    orientations and filters in a single broadcasted batch.

    TODO: Figure out a better way to split up CPU/GPU functions while remaining
    performant and not duplicating code.

    Parameters
    ----------
    image_dft : torch.Tensor
        Real-fourier transform (RFFT) of the image with large image filters
        already applied. Has shape (H, W // 2 + 1).
    template_dft : torch.Tensor
        Real-fourier transform (RFFT) of the template volume to take Fourier
        slices from. Has shape (l, h, w // 2 + 1).
    rotation_matrices : torch.Tensor
        Rotation matrices to apply to the template volume. Has shape
        (num_orientations, 3, 3).
    projective_filters : torch.Tensor
        Multiplied 'ctf_filters' with 'whitening_filter_template'. Has shape
        (num_Cs, num_defocus, h, w // 2 + 1). Is RFFT and not fftshifted.

    Returns
    -------
    torch.Tensor
        Cross-correlation for the batch of orientations and defocus values. Will
        have shape (num_Cs, num_defocus, num_orientations, H, W).
    """
    # Accounting for RFFT shape (assumes even-sized last real-space dimension)
    projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
    image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)

    # Extract central slice(s) from the template volume
    fourier_slice = extract_central_slices_rfft_3d(
        volume_rfft=template_dft,
        image_shape=(projection_shape_real[0],) * 3,  # NOTE: requires cubic template
        rotation_matrices=rotation_matrices,
    )
    fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
    fourier_slice[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)
    fourier_slice *= -1  # flip contrast

    # Apply the projective filters on new leading batch dimensions:
    # (1, 1, num_orientations, h, w') * (num_Cs, num_defocus, 1, h, w')
    fourier_slice = fourier_slice[None, None, ...] * projective_filters[:, :, None, ...]

    # Inverse Fourier transform into real space and normalize
    projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
    projections = torch.fft.ifftshift(projections, dim=(-2, -1))
    projections = normalize_template_projection(
        projections,
        projection_shape_real,
        image_shape_real,
    )

    # Padded forward Fourier transform for cross-correlation
    projections_dft = torch.fft.rfftn(projections, dim=(-2, -1), s=image_shape_real)
    projections_dft[..., 0, 0] = 0 + 0j  # zero out the DC component (mean zero)

    # Cross correlation step by element-wise multiplication. The output size is
    # passed explicitly so it always matches the padded forward transform above.
    projections_dft = image_dft[None, None, None, ...] * projections_dft.conj()
    cross_correlation = torch.fft.irfftn(
        projections_dft, dim=(-2, -1), s=image_shape_real
    )

    return cross_correlation
0 commit comments