Skip to content

Commit f4f0d1f

Browse files
authored
Merge pull request #105 from mgiammar/mdg_correlation_table
Sync correlation table updates with `zipfft_experimental` branch
2 parents 69284da + ba0b4eb commit f4f0d1f

File tree

11 files changed

+425
-141
lines changed

11 files changed

+425
-141
lines changed

src/leopard_em/analysis/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
match_template_peaks_to_dict,
77
)
88
from .pvalue_metric import extract_peaks_and_statistics_p_value
9-
from .zscore_metric import extract_peaks_and_statistics_zscore, gaussian_noise_zscore_cutoff
9+
from .zscore_metric import (
10+
extract_peaks_and_statistics_zscore,
11+
gaussian_noise_zscore_cutoff,
12+
)
1013

1114
__all__ = [
1215
"MatchTemplatePeaks",

src/leopard_em/backend/core_match_template.py

Lines changed: 101 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44
# pylint: disable=E1102
55

66
import time
7+
import traceback
78
import warnings
89
from functools import partial
910
from multiprocessing import set_start_method
1011
from typing import Any, Union
1112

1213
import roma
14+
import tensordict
1315
import torch
1416
import tqdm
1517

1618
from leopard_em.backend.cross_correlation import (
1719
do_batched_orientation_cross_correlate,
18-
do_streamed_orientation_cross_correlate,
1920
do_batched_orientation_cross_correlate_zipfft,
21+
do_streamed_orientation_cross_correlate,
2022
)
2123
from leopard_em.backend.distributed import (
2224
MultiprocessWorkIndexQueue,
@@ -25,17 +27,20 @@
2527
from leopard_em.backend.process_results import (
2628
aggregate_distributed_results,
2729
decode_global_search_index,
30+
process_correlation_table,
2831
scale_mip,
2932
)
30-
from leopard_em.backend.utils import do_iteration_statistics_updates_compiled
33+
from leopard_em.backend.utils import do_iteration_and_correlation_table_updates
3134

3235
DEFAULT_STATISTIC_DTYPE = torch.float32
36+
CORRELATION_TABLE_THRESHOLD = 5.5
3337

3438
# Turn off gradient calculations by default
3539
torch.set_grad_enabled(False)
3640

3741
# Set multiprocessing start method to spawn
3842
set_start_method("spawn", force=True)
43+
torch.multiprocessing.set_sharing_strategy("file_system")
3944

4045

4146
def monitor_match_template_progress(
@@ -78,6 +83,7 @@ def monitor_match_template_progress(
7883
time.sleep(poll_interval)
7984
except Exception as e:
8085
print(f"Error occurred: {e}")
86+
traceback.print_exc()
8187
queue.set_error_flag()
8288
raise e
8389
finally:
@@ -156,7 +162,7 @@ def core_match_template(
156162
orientation_batch_size: int = 1,
157163
num_cuda_streams: int = 1,
158164
backend: str = "streamed",
159-
) -> dict[str, torch.Tensor]:
165+
) -> dict[str, torch.Tensor | dict | int]:
160166
"""Core function for performing the whole-orientation search.
161167
162168
With the RFFT, the last dimension (fastest dimension) is half the width
@@ -213,7 +219,7 @@ def core_match_template(
213219
214220
Returns
215221
-------
216-
dict[str, torch.Tensor]
222+
dict[str, torch.Tensor | dict | int]
217223
Dictionary containing the following key, value pairs:
218224
219225
- "mip": Maximum intensity projection of the cross-correlation values across
@@ -223,10 +229,12 @@ def core_match_template(
223229
- "best_theta": Best theta angle for each pixel.
224230
- "best_psi": Best psi angle for each pixel.
225231
- "best_defocus": Best defocus value for each pixel.
226-
- "best_pixel_size": Best pixel size value for each pixel.
227-
- "correlation_sum": Sum of cross-correlation values for each pixel.
228-
- "correlation_squared_sum": Sum of squared cross-correlation values for
232+
- "correlation_mean": Mean of cross-correlation values for each pixel.
233+
- "correlation_variance": Variance of cross-correlation values for
229236
each pixel.
234+
- "correlation_table": Processed correlation table with all points in search
235+
space and image positions where correlation value exceeded a threshold.
237+
- "total_projections": Total number of cross-correlations computed.
230238
- "total_orientations": Total number of orientations searched.
231239
- "total_defocus": Total number of defocus values searched.
232240
"""
@@ -328,7 +336,7 @@ def core_match_template(
328336
correlation_squared_sum = aggregated_results["correlation_squared_sum"]
329337

330338
# Map from global search index to the best defocus & angles
331-
best_phi, best_theta, best_psi, best_defocus = decode_global_search_index(
339+
best_phi, best_theta, best_psi, best_defocus, _ = decode_global_search_index(
332340
best_global_index, pixel_values, defocus_values, euler_angles
333341
)
334342

@@ -341,6 +349,14 @@ def core_match_template(
341349
total_correlation_positions=total_projections,
342350
)
343351

352+
# Process the correlation table into a more interpretable format
353+
correlation_table = process_correlation_table(
354+
aggregated_results["correlation_table"],
355+
pixel_values,
356+
defocus_values,
357+
euler_angles,
358+
)
359+
344360
return {
345361
"mip": mip,
346362
"scaled_mip": mip_scaled,
@@ -350,6 +366,7 @@ def core_match_template(
350366
"best_defocus": best_defocus,
351367
"correlation_mean": correlation_mean,
352368
"correlation_variance": correlation_variance,
369+
"correlation_table": correlation_table,
353370
"total_projections": total_projections,
354371
"total_orientations": euler_angles.shape[0],
355372
"total_defocus": defocus_values.shape[0],
@@ -372,7 +389,9 @@ def _core_match_template_single_gpu(
372389
num_cuda_streams: int,
373390
backend: str,
374391
device: torch.device,
375-
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
392+
) -> tuple[
393+
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, tensordict.TensorDict
394+
]:
376395
"""Single-GPU call for template matching.
377396
378397
Parameters
@@ -422,11 +441,17 @@ def _core_match_template_single_gpu(
422441
- correlation_sum: Sum of cross-correlation values for each pixel.
423442
- correlation_squared_sum: Sum of squared cross-correlation values for
424443
each pixel.
444+
- correlation_table: Table of search indices and image positions where
445+
correlation values exceeded a threshold.
425446
"""
426447
image_shape_real = (image_dft.shape[0], image_dft.shape[1] * 2 - 2) # adj. for RFFT
427-
cross_correlation_shape_valid = (
428-
image_shape_real[0] - template_dft.shape[1] + 1,
429-
image_shape_real[1] - (template_dft.shape[2] * 2 - 2) + 1,
448+
projection_shape_real = (
449+
template_dft.shape[1],
450+
template_dft.shape[2] * 2 - 2, # adj. for RFFT
451+
)
452+
valid_correlation_shape = (
453+
image_shape_real[0] - projection_shape_real[0] + 1,
454+
image_shape_real[1] - projection_shape_real[1] + 1,
430455
)
431456

432457
# Create CUDA streams for parallel computation
@@ -459,52 +484,52 @@ def _core_match_template_single_gpu(
459484
### Initialize the tracked output statistics ###
460485
################################################
461486

487+
# Correlation table built with the 'tensordict' library. Any (x, y) position
488+
# in the correlation map that surpasses the threshold is added to the table.
489+
# Keys in table are:
490+
# - "threshold": float threshold value used for the table.
491+
# - "global_idx": int32 global search index.
492+
# - "pos_x": int32 x position in image where corr value surpassed threshold.
493+
# - "pos_y": int32 y position in image where corr value surpassed threshold.
494+
# - "corr_value": float32 correlation value at (pos_x, pos_y) for the given
495+
# global index.
496+
correlation_table = tensordict.TensorDict(
497+
{
498+
"threshold": CORRELATION_TABLE_THRESHOLD,
499+
"global_idx": torch.tensor([], dtype=torch.int32, device=device),
500+
"pos_x": torch.tensor([], dtype=torch.int32, device=device),
501+
"pos_y": torch.tensor([], dtype=torch.int32, device=device),
502+
"corr_value": torch.tensor([], dtype=torch.float32, device=device),
503+
},
504+
device=device,
505+
)
506+
mip = torch.full(
507+
size=valid_correlation_shape,
508+
fill_value=-float("inf"),
509+
dtype=DEFAULT_STATISTIC_DTYPE,
510+
device=device,
511+
)
512+
best_global_index = torch.full(
513+
valid_correlation_shape,
514+
fill_value=-1,
515+
dtype=torch.int32,
516+
device=device,
517+
)
518+
correlation_sum = torch.zeros(
519+
size=valid_correlation_shape,
520+
dtype=DEFAULT_STATISTIC_DTYPE,
521+
device=device,
522+
)
523+
correlation_squared_sum = torch.zeros(
524+
size=valid_correlation_shape,
525+
dtype=DEFAULT_STATISTIC_DTYPE,
526+
device=device,
527+
)
462528
if backend == "zipfft":
463-
mip = torch.full(
464-
size=cross_correlation_shape_valid,
465-
fill_value=-float("inf"),
466-
dtype=DEFAULT_STATISTIC_DTYPE,
467-
device=device,
468-
)
469-
best_global_index = torch.full(
470-
cross_correlation_shape_valid,
471-
fill_value=-1,
472-
dtype=torch.int32,
473-
device=device,
474-
)
475-
correlation_sum = torch.zeros(
476-
size=cross_correlation_shape_valid,
477-
dtype=DEFAULT_STATISTIC_DTYPE,
478-
device=device,
479-
)
480-
correlation_squared_sum = torch.zeros(
481-
size=cross_correlation_shape_valid,
482-
dtype=DEFAULT_STATISTIC_DTYPE,
483-
device=device,
484-
)
485529
# NOTE: zipFFT expects a pre-transformed, pre-transposed input image FFT
486530
# Transpose the 'image_dft' along last two dimensions into contiguous layout
487531
# with shape (..., W // 2 + 1, H)
488532
image_dft = image_dft.transpose(-2, -1).contiguous()
489-
# NOTE: zipFFT does not apply backwards FFT normalization, so we instead apply
490-
# it to the input image (does not require addtl. multiplications in loop)
491-
image_dft *= (image_shape_real[0] * image_shape_real[1])
492-
else:
493-
mip = torch.full(
494-
size=image_shape_real,
495-
fill_value=-float("inf"),
496-
dtype=DEFAULT_STATISTIC_DTYPE,
497-
device=device,
498-
)
499-
best_global_index = torch.full(
500-
image_shape_real, fill_value=-1, dtype=torch.int32, device=device
501-
)
502-
correlation_sum = torch.zeros(
503-
size=image_shape_real, dtype=DEFAULT_STATISTIC_DTYPE, device=device
504-
)
505-
correlation_squared_sum = torch.zeros(
506-
size=image_shape_real, dtype=DEFAULT_STATISTIC_DTYPE, device=device
507-
)
508533

509534
##################################
510535
### Start the orientation loop ###
@@ -563,25 +588,21 @@ def _core_match_template_single_gpu(
563588
projective_filters=projective_filters,
564589
)
565590

566-
# Update the tracked statistics
567-
do_iteration_statistics_updates_compiled(
591+
# Update tracked statistics and correlation table
592+
do_iteration_and_correlation_table_updates(
568593
cross_correlation=cross_correlation,
569594
current_indexes=batch_search_indices,
595+
correlation_table=correlation_table,
570596
mip=mip,
571597
best_global_index=best_global_index,
572598
correlation_sum=correlation_sum,
573599
correlation_squared_sum=correlation_squared_sum,
574-
img_h=(
575-
image_shape_real[0]
576-
if backend != "zipfft"
577-
else cross_correlation_shape_valid[0]
578-
),
579-
img_w=(
580-
image_shape_real[1]
581-
if backend != "zipfft"
582-
else cross_correlation_shape_valid[1]
583-
),
600+
threshold=CORRELATION_TABLE_THRESHOLD,
601+
valid_shape_h=valid_correlation_shape[0],
602+
valid_shape_w=valid_correlation_shape[1],
603+
needs_valid_cropping=(backend != "zipfft"),
584604
)
605+
585606
except Exception as e:
586607
index_queue.set_error_flag()
587608
print(f"Error occurred in process {rank}: {e}")
@@ -593,7 +614,13 @@ def _core_match_template_single_gpu(
593614

594615
torch.cuda.synchronize(device)
595616

596-
return mip, best_global_index, correlation_sum, correlation_squared_sum
617+
return (
618+
mip,
619+
best_global_index,
620+
correlation_sum,
621+
correlation_squared_sum,
622+
correlation_table,
623+
)
597624

598625

599626
def _core_match_template_multiprocess_wrapper(
@@ -607,9 +634,13 @@ def _core_match_template_multiprocess_wrapper(
607634
608635
See the _core_match_template_single_gpu function for parameter descriptions.
609636
"""
610-
mip, best_global_index, correlation_sum, correlation_squared_sum = (
611-
_core_match_template_single_gpu(rank, **kwargs) # type: ignore[arg-type]
612-
)
637+
(
638+
mip,
639+
best_global_index,
640+
correlation_sum,
641+
correlation_squared_sum,
642+
correlation_table,
643+
) = _core_match_template_single_gpu(rank, **kwargs) # type: ignore[arg-type]
613644

614645
# NOTE: Need to send all tensors back to the CPU as numpy arrays for the shared
615646
# process dictionary. This is a workaround for now
@@ -618,6 +649,7 @@ def _core_match_template_multiprocess_wrapper(
618649
"best_global_index": best_global_index.cpu().numpy(),
619650
"correlation_sum": correlation_sum.cpu().numpy(),
620651
"correlation_squared_sum": correlation_squared_sum.cpu().numpy(),
652+
"correlation_table": correlation_table.cpu(),
621653
}
622654

623655
# Place the results in the shared multi-process manager dictionary so accessible

src/leopard_em/backend/core_match_template_distributed.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -454,21 +454,25 @@ def core_match_template_distributed(
454454
###########################################################
455455

456456
dist.barrier()
457-
(mip, best_global_index, correlation_sum, correlation_squared_sum) = (
458-
_core_match_template_single_gpu(
459-
rank=rank,
460-
index_queue=distributed_queue, # type: ignore
461-
image_dft=image_dft,
462-
template_dft=template_dft,
463-
euler_angles=euler_angles,
464-
projective_filters=projective_filters,
465-
defocus_values=defocus_values,
466-
pixel_values=pixel_values,
467-
orientation_batch_size=orientation_batch_size,
468-
num_cuda_streams=num_cuda_streams,
469-
backend=backend,
470-
device=device,
471-
)
457+
(
458+
mip,
459+
best_global_index,
460+
correlation_sum,
461+
correlation_squared_sum,
462+
_, # TODO: include correlation_table in distributed version
463+
) = _core_match_template_single_gpu(
464+
rank=rank,
465+
index_queue=distributed_queue, # type: ignore
466+
image_dft=image_dft,
467+
template_dft=template_dft,
468+
euler_angles=euler_angles,
469+
projective_filters=projective_filters,
470+
defocus_values=defocus_values,
471+
pixel_values=pixel_values,
472+
orientation_batch_size=orientation_batch_size,
473+
num_cuda_streams=num_cuda_streams,
474+
backend=backend,
475+
device=device,
472476
)
473477
dist.barrier()
474478

@@ -534,7 +538,7 @@ def core_match_template_distributed(
534538

535539
# Map from global search index to the best defocus & angles
536540
# pylint: disable=duplicate-code
537-
best_phi, best_theta, best_psi, best_defocus = decode_global_search_index(
541+
best_phi, best_theta, best_psi, best_defocus, _ = decode_global_search_index(
538542
best_global_index, pixel_values, defocus_values, euler_angles
539543
)
540544

0 commit comments

Comments
 (0)