Skip to content

Commit d9e404d

Browse files
authored
Merge pull request #59 from mgiammar/refine_template_stream
Add back batched cross-correlation for `refine_template` program
2 parents 312a2f5 + d3c85b3 commit d9e404d

File tree

12 files changed

+1370
-312
lines changed

12 files changed

+1370
-312
lines changed

docs/programs/refine_template.md

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ All these parameters are discussed in more detail on the [match template program
125125

126126
### Configuring GPUs for a match template run
127127

128-
Template refinement can run across multiple GPUs and is controlled in the same way as match template.
129-
Note that the `num_cpus` field is currently unused for the refine template program and can just be set to one.
130-
Like [configuring GPUs for a match template run](match_template.md#configuring-gpus-for-a-match-template-run), GPUs are targeted by their device index.
128+
The refine template program parallelizes across multiple GPUs by dividing the particles to be refined among the configured list of GPU devices.
129+
The `num_cpus` field controls how many concurrent streams of work are submitted to each GPU; in most cases, a value of `1` or `2` will saturate the GPU and give the best performance, although your mileage may vary.
130+
Like [configuring GPUs for a match template run](match_template.md#configuring-gpus-for-a-match-template-run), GPUs are targeted by their device index or the special string `"all"`.
131131
The following configuration will run `refine_template` on GPU zero.
132132

133133
```yaml
@@ -136,6 +136,14 @@ computational_config:
136136
num_cpus: 1
137137
```
138138

139+
The following configuration will run `refine_template` on all available GPUs with two streams per GPU.
140+
141+
```yaml
142+
computational_config:
143+
gpu_ids: "all"
144+
num_cpus: 2
145+
```
146+
139147
## Running the refine template program
140148

141149
Once you've configured a YAML file, running the refine template program is fairly simple.

src/leopard_em/backend/core_match_template.py

Lines changed: 12 additions & 193 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,19 @@
1010
import roma
1111
import torch
1212
import tqdm
13-
from torch_fourier_slice import extract_central_slices_rfft_3d
1413

14+
from leopard_em.backend.cross_correlation import (
15+
do_streamed_orientation_cross_correlate,
16+
)
1517
from leopard_em.backend.process_results import (
1618
aggregate_distributed_results,
1719
scale_mip,
1820
)
1921
from leopard_em.backend.utils import (
20-
do_iteration_statistics_updates,
21-
normalize_template_projection,
22+
do_iteration_statistics_updates_compiled,
2223
run_multiprocess_jobs,
2324
)
2425

25-
COMPILE_BACKEND = "inductor"
2626
DEFAULT_STATISTIC_DTYPE = torch.float32
2727

2828
# Turn off gradient calculations by default
@@ -31,13 +31,6 @@
3131
# Set multiprocessing start method to spawn
3232
set_start_method("spawn", force=True)
3333

34-
normalize_template_projection_compiled = torch.compile(
35-
normalize_template_projection, backend=COMPILE_BACKEND
36-
)
37-
do_iteration_statistics_updates_compiled = torch.compile(
38-
do_iteration_statistics_updates, backend=COMPILE_BACKEND
39-
)
40-
4134

4235
###########################################################
4336
### Main function for whole orientation search ###
@@ -396,6 +389,10 @@ def _core_match_template_single_gpu(
396389
### Setup iterator object with tqdm for progress bar ###
397390
########################################################
398391

392+
total_projections = (
393+
euler_angles.shape[0] * defocus_values.shape[0] * pixel_values.shape[0]
394+
)
395+
399396
num_batches = math.ceil(euler_angles.shape[0] / orientation_batch_size)
400397
orientation_batch_iterator = tqdm.tqdm(
401398
range(num_batches),
@@ -405,10 +402,9 @@ def _core_match_template_single_gpu(
405402
dynamic_ncols=True,
406403
position=device.index,
407404
mininterval=1, # Slow down to reduce number of lines written
408-
)
409-
410-
total_projections = (
411-
euler_angles.shape[0] * defocus_values.shape[0] * pixel_values.shape[0]
405+
smoothing=0.05,
406+
unit="corr",
407+
unit_scale=total_projections / num_batches,
412408
)
413409

414410
##################################
@@ -423,7 +419,7 @@ def _core_match_template_single_gpu(
423419
"ZYZ", euler_angles_batch, degrees=True, device=device
424420
)
425421

426-
cross_correlation = _do_bached_orientation_cross_correlate(
422+
cross_correlation = do_streamed_orientation_cross_correlate(
427423
image_dft=image_dft,
428424
template_dft=template_dft,
429425
rotation_matrices=rot_matrix,
@@ -466,180 +462,3 @@ def _core_match_template_single_gpu(
466462
# Place the results in the shared multi-process manager dictionary so accessible
467463
# by the main process.
468464
result_dict[device_id] = result
469-
470-
471-
def _do_bached_orientation_cross_correlate(
472-
image_dft: torch.Tensor,
473-
template_dft: torch.Tensor,
474-
rotation_matrices: torch.Tensor,
475-
projective_filters: torch.Tensor,
476-
streams: list[torch.cuda.Stream],
477-
) -> torch.Tensor:
478-
"""Batched projection and cross-correlation with fixed (batched) filters.
479-
480-
Note that this function returns a cross-correlogram with "same" mode (i.e. the
481-
same size as the input image). See numpy correlate docs for more information.
482-
483-
Parameters
484-
----------
485-
image_dft : torch.Tensor
486-
Real-fourier transform (RFFT) of the image with large image filters
487-
already applied. Has shape (H, W // 2 + 1).
488-
template_dft : torch.Tensor
489-
Real-fourier transform (RFFT) of the template volume to take Fourier
490-
slices from. Has shape (l, h, w // 2 + 1) where (l, h, w) is the original
491-
real-space shape of the template volume.
492-
rotation_matrices : torch.Tensor
493-
Rotation matrices to apply to the template volume. Has shape
494-
(num_orientations, 3, 3).
495-
projective_filters : torch.Tensor
496-
Multiplied 'ctf_filters' with 'whitening_filter_template'. Has shape
497-
(num_Cs, num_defocus, h, w // 2 + 1). Is RFFT and not fftshifted.
498-
streams : list[torch.cuda.Stream]
499-
List of CUDA streams to use for parallel computation. Each stream will
500-
handle a separate cross-correlation.
501-
502-
Returns
503-
-------
504-
torch.Tensor
505-
Cross-correlation of the image with the template volume for each
506-
orientation and defocus value. Will have shape
507-
(num_Cs, num_defocus, num_orientations, H, W).
508-
"""
509-
# Accounting for RFFT shape
510-
projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
511-
image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)
512-
513-
num_orientations = rotation_matrices.shape[0]
514-
num_Cs = projective_filters.shape[0] # pylint: disable=invalid-name
515-
num_defocus = projective_filters.shape[1]
516-
517-
cross_correlation = torch.empty(
518-
size=(num_Cs, num_defocus, num_orientations, *image_shape_real),
519-
dtype=DEFAULT_STATISTIC_DTYPE,
520-
device=image_dft.device,
521-
)
522-
523-
# Do a batched Fourier slice extraction for all the orientations at once.
524-
fourier_slices = extract_central_slices_rfft_3d(
525-
volume_rfft=template_dft,
526-
image_shape=(projection_shape_real[0],) * 3,
527-
rotation_matrices=rotation_matrices,
528-
)
529-
fourier_slices = torch.fft.ifftshift(fourier_slices, dim=(-2,))
530-
fourier_slices[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero)
531-
fourier_slices *= -1 # flip contrast
532-
533-
# Iterate over the orientations
534-
for i in range(num_orientations):
535-
fourier_slice = fourier_slices[i]
536-
537-
# Iterate over the different pixel sizes (Cs) and defocus values for this
538-
# particular orientation
539-
for j in range(num_defocus):
540-
for k in range(num_Cs):
541-
# Use a round-robin scheduling for the streams
542-
job_idx = (i * num_defocus * num_Cs) + (j * num_Cs) + k
543-
stream_idx = job_idx % len(streams)
544-
stream = streams[stream_idx]
545-
546-
with torch.cuda.stream(stream):
547-
# Apply the projective filter and do template normalization
548-
fourier_slice_filtered = fourier_slice * projective_filters[k, j]
549-
projection = torch.fft.irfft2(fourier_slice_filtered)
550-
projection = torch.fft.ifftshift(projection, dim=(-2, -1))
551-
projection = normalize_template_projection_compiled(
552-
projection,
553-
projection_shape_real,
554-
image_shape_real,
555-
)
556-
557-
# Padded forward Fourier transform for cross-correlation
558-
projection_dft = torch.fft.rfft2(projection, s=image_shape_real)
559-
projection_dft[0, 0] = 0 + 0j
560-
561-
# Cross correlation step by element-wise multiplication
562-
projection_dft = image_dft * projection_dft.conj()
563-
torch.fft.irfft2(
564-
projection_dft,
565-
s=image_shape_real,
566-
out=cross_correlation[k, j, i],
567-
)
568-
569-
# Wait for all streams to finish
570-
for stream in streams:
571-
stream.synchronize()
572-
573-
# shape is (num_Cs, num_defocus, num_orientations, H, W)
574-
return cross_correlation
575-
576-
577-
def _do_bached_orientation_cross_correlate_cpu(
578-
image_dft: torch.Tensor,
579-
template_dft: torch.Tensor,
580-
rotation_matrices: torch.Tensor,
581-
projective_filters: torch.Tensor,
582-
) -> torch.Tensor:
583-
"""Same as `_do_bached_orientation_cross_correlate` but on the CPU.
584-
585-
The only difference is that this function does not call into a compiled torch
586-
function for normalization.
587-
588-
TODO: Figure out a better way to split up CPU/GPU functions while remaining
589-
performant and not duplicating code.
590-
591-
Parameters
592-
----------
593-
image_dft : torch.Tensor
594-
Real-fourier transform (RFFT) of the image with large image filters
595-
already applied. Has shape (H, W // 2 + 1).
596-
template_dft : torch.Tensor
597-
Real-fourier transform (RFFT) of the template volume to take Fourier
598-
slices from. Has shape (l, h, w // 2 + 1).
599-
rotation_matrices : torch.Tensor
600-
Rotation matrices to apply to the template volume. Has shape
601-
(orientations, 3, 3).
602-
projective_filters : torch.Tensor
603-
Multiplied 'ctf_filters' with 'whitening_filter_template'. Has shape
604-
(defocus_batch, h, w // 2 + 1). Is RFFT and not fftshifted.
605-
606-
Returns
607-
-------
608-
torch.Tensor
609-
Cross-correlation for the batch of orientations and defocus values.
610-
"""
611-
# Accounting for RFFT shape
612-
projection_shape_real = (template_dft.shape[1], template_dft.shape[2] * 2 - 2)
613-
image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2)
614-
615-
# Extract central slice(s) from the template volume
616-
fourier_slice = extract_central_slices_rfft_3d(
617-
volume_rfft=template_dft,
618-
image_shape=(projection_shape_real[0],) * 3, # NOTE: requires cubic template
619-
rotation_matrices=rotation_matrices,
620-
)
621-
fourier_slice = torch.fft.ifftshift(fourier_slice, dim=(-2,))
622-
fourier_slice[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero)
623-
fourier_slice *= -1 # flip contrast
624-
625-
# Apply the projective filters on a new batch dimension
626-
fourier_slice = fourier_slice[None, None, ...] * projective_filters[:, :, None, ...]
627-
628-
# Inverse Fourier transform into real space and normalize
629-
projections = torch.fft.irfftn(fourier_slice, dim=(-2, -1))
630-
projections = torch.fft.ifftshift(projections, dim=(-2, -1))
631-
projections = normalize_template_projection(
632-
projections,
633-
projection_shape_real,
634-
image_shape_real,
635-
)
636-
637-
# Padded forward Fourier transform for cross-correlation
638-
projections_dft = torch.fft.rfftn(projections, dim=(-2, -1), s=image_shape_real)
639-
projections_dft[..., 0, 0] = 0 + 0j # zero out the DC component (mean zero)
640-
641-
# Cross correlation step by element-wise multiplication
642-
projections_dft = image_dft[None, None, None, ...] * projections_dft.conj()
643-
cross_correlation = torch.fft.irfftn(projections_dft, dim=(-2, -1))
644-
645-
return cross_correlation

0 commit comments

Comments
 (0)