44# pylint: disable=E1102
55
66import time
7+ import traceback
78import warnings
89from functools import partial
910from multiprocessing import set_start_method
1011from typing import Any , Union
1112
1213import roma
14+ import tensordict
1315import torch
1416import tqdm
1517
1618from leopard_em .backend .cross_correlation import (
1719 do_batched_orientation_cross_correlate ,
18- do_streamed_orientation_cross_correlate ,
1920 do_batched_orientation_cross_correlate_zipfft ,
21+ do_streamed_orientation_cross_correlate ,
2022)
2123from leopard_em .backend .distributed import (
2224 MultiprocessWorkIndexQueue ,
2527from leopard_em .backend .process_results import (
2628 aggregate_distributed_results ,
2729 decode_global_search_index ,
30+ process_correlation_table ,
2831 scale_mip ,
2932)
30- from leopard_em .backend .utils import do_iteration_statistics_updates_compiled
33+ from leopard_em .backend .utils import do_iteration_and_correlation_table_updates
3134
3235DEFAULT_STATISTIC_DTYPE = torch .float32
36+ CORRELATION_TABLE_THRESHOLD = 5.5
3337
3438# Turn off gradient calculations by default
3539torch .set_grad_enabled (False )
3640
3741# Set multiprocessing start method to spawn
3842set_start_method ("spawn" , force = True )
43+ torch .multiprocessing .set_sharing_strategy ("file_system" )
3944
4045
4146def monitor_match_template_progress (
@@ -78,6 +83,7 @@ def monitor_match_template_progress(
7883 time .sleep (poll_interval )
7984 except Exception as e :
8085 print (f"Error occurred: { e } " )
86+ traceback .print_exc ()
8187 queue .set_error_flag ()
8288 raise e
8389 finally :
@@ -156,7 +162,7 @@ def core_match_template(
156162 orientation_batch_size : int = 1 ,
157163 num_cuda_streams : int = 1 ,
158164 backend : str = "streamed" ,
159- ) -> dict [str , torch .Tensor ]:
165+ ) -> dict [str , torch .Tensor | dict | int ]:
160166 """Core function for performing the whole-orientation search.
161167
162168 With the RFFT, the last dimension (fastest dimension) is half the width
@@ -213,7 +219,7 @@ def core_match_template(
213219
214220 Returns
215221 -------
216- dict[str, torch.Tensor]
222+ dict[str, torch.Tensor | dict | int ]
217223 Dictionary containing the following key, value pairs:
218224
219225 - "mip": Maximum intensity projection of the cross-correlation values across
@@ -223,10 +229,12 @@ def core_match_template(
223229 - "best_theta": Best theta angle for each pixel.
224230 - "best_psi": Best psi angle for each pixel.
225231 - "best_defocus": Best defocus value for each pixel.
226- - "best_pixel_size": Best pixel size value for each pixel.
227- - "correlation_sum": Sum of cross-correlation values for each pixel.
228- - "correlation_squared_sum": Sum of squared cross-correlation values for
232+ - "correlation_mean": Mean of cross-correlation values for each pixel.
233+ - "correlation_variance": Variance of cross-correlation values for each pixel.
234+ - "correlation_table": Processed correlation table with all points in search
235+ space and image positions where the correlation value exceeded the
229236 threshold (CORRELATION_TABLE_THRESHOLD).
237+ - "total_projections": Total number of cross-correlations computed.
230238 - "total_orientations": Total number of orientations searched.
231239 - "total_defocus": Total number of defocus values searched.
232240 """
@@ -328,7 +336,7 @@ def core_match_template(
328336 correlation_squared_sum = aggregated_results ["correlation_squared_sum" ]
329337
330338 # Map from global search index to the best defocus & angles
331- best_phi , best_theta , best_psi , best_defocus = decode_global_search_index (
339+ best_phi , best_theta , best_psi , best_defocus , _ = decode_global_search_index (
332340 best_global_index , pixel_values , defocus_values , euler_angles
333341 )
334342
@@ -341,6 +349,14 @@ def core_match_template(
341349 total_correlation_positions = total_projections ,
342350 )
343351
352+ # Process the correlation table into a more interpretable format
353+ correlation_table = process_correlation_table (
354+ aggregated_results ["correlation_table" ],
355+ pixel_values ,
356+ defocus_values ,
357+ euler_angles ,
358+ )
359+
344360 return {
345361 "mip" : mip ,
346362 "scaled_mip" : mip_scaled ,
@@ -350,6 +366,7 @@ def core_match_template(
350366 "best_defocus" : best_defocus ,
351367 "correlation_mean" : correlation_mean ,
352368 "correlation_variance" : correlation_variance ,
369+ "correlation_table" : correlation_table ,
353370 "total_projections" : total_projections ,
354371 "total_orientations" : euler_angles .shape [0 ],
355372 "total_defocus" : defocus_values .shape [0 ],
@@ -372,7 +389,9 @@ def _core_match_template_single_gpu(
372389 num_cuda_streams : int ,
373390 backend : str ,
374391 device : torch .device ,
375- ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor ]:
392+ ) -> tuple [
393+ torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , tensordict .TensorDict
394+ ]:
376395 """Single-GPU call for template matching.
377396
378397 Parameters
@@ -422,11 +441,17 @@ def _core_match_template_single_gpu(
422441 - correlation_sum: Sum of cross-correlation values for each pixel.
423442 - correlation_squared_sum: Sum of squared cross-correlation values for
424443 each pixel.
444+ - correlation_table: Table of search indices and image positions where
445+ correlation values exceeded a threshold.
425446 """
426447 image_shape_real = (image_dft .shape [0 ], image_dft .shape [1 ] * 2 - 2 ) # adj. for RFFT
427- cross_correlation_shape_valid = (
428- image_shape_real [0 ] - template_dft .shape [1 ] + 1 ,
429- image_shape_real [1 ] - (template_dft .shape [2 ] * 2 - 2 ) + 1 ,
448+ projection_shape_real = (
449+ template_dft .shape [1 ],
450+ template_dft .shape [2 ] * 2 - 2 , # adj. for RFFT
451+ )
452+ valid_correlation_shape = (
453+ image_shape_real [0 ] - projection_shape_real [0 ] + 1 ,
454+ image_shape_real [1 ] - projection_shape_real [1 ] + 1 ,
430455 )
431456
432457 # Create CUDA streams for parallel computation
@@ -459,52 +484,52 @@ def _core_match_template_single_gpu(
459484 ### Initialize the tracked output statistics ###
460485 ################################################
461486
487+ # Correlation table built from 'tensordict' library where any (x, y) positions
488+ # in correlation map which surpass the threshold will be added to the table.
489+ # Keys in table are:
490+ # - "threshold": float threshold value used for the table.
491+ # - "global_idx": int32 global search index.
492+ # - "pos_x": int32 x position in image where corr value surpassed threshold.
493+ # - "pos_y": int32 y position in image where corr value surpassed threshold.
494+ # - "corr_value": float32 correlation value at (pos_x, pos_y) for the given
495+ # global index.
496+ correlation_table = tensordict .TensorDict (
497+ {
498+ "threshold" : CORRELATION_TABLE_THRESHOLD ,
499+ "global_idx" : torch .tensor ([], dtype = torch .int32 , device = device ),
500+ "pos_x" : torch .tensor ([], dtype = torch .int32 , device = device ),
501+ "pos_y" : torch .tensor ([], dtype = torch .int32 , device = device ),
502+ "corr_value" : torch .tensor ([], dtype = torch .float32 , device = device ),
503+ },
504+ device = device ,
505+ )
506+ mip = torch .full (
507+ size = valid_correlation_shape ,
508+ fill_value = - float ("inf" ),
509+ dtype = DEFAULT_STATISTIC_DTYPE ,
510+ device = device ,
511+ )
512+ best_global_index = torch .full (
513+ valid_correlation_shape ,
514+ fill_value = - 1 ,
515+ dtype = torch .int32 ,
516+ device = device ,
517+ )
518+ correlation_sum = torch .zeros (
519+ size = valid_correlation_shape ,
520+ dtype = DEFAULT_STATISTIC_DTYPE ,
521+ device = device ,
522+ )
523+ correlation_squared_sum = torch .zeros (
524+ size = valid_correlation_shape ,
525+ dtype = DEFAULT_STATISTIC_DTYPE ,
526+ device = device ,
527+ )
462528 if backend == "zipfft" :
463- mip = torch .full (
464- size = cross_correlation_shape_valid ,
465- fill_value = - float ("inf" ),
466- dtype = DEFAULT_STATISTIC_DTYPE ,
467- device = device ,
468- )
469- best_global_index = torch .full (
470- cross_correlation_shape_valid ,
471- fill_value = - 1 ,
472- dtype = torch .int32 ,
473- device = device ,
474- )
475- correlation_sum = torch .zeros (
476- size = cross_correlation_shape_valid ,
477- dtype = DEFAULT_STATISTIC_DTYPE ,
478- device = device ,
479- )
480- correlation_squared_sum = torch .zeros (
481- size = cross_correlation_shape_valid ,
482- dtype = DEFAULT_STATISTIC_DTYPE ,
483- device = device ,
484- )
485529 # NOTE: zipFFT expects a pre-transformed, pre-transposed input image FFT
486530 # Transpose the 'image_dft' along last two dimensions into contiguous layout
487531 # with shape (..., W // 2 + 1, H)
488532 image_dft = image_dft .transpose (- 2 , - 1 ).contiguous ()
489- # NOTE: zipFFT does not apply backwards FFT normalization, so we instead apply
490- # it to the input image (does not require addtl. multiplications in loop)
491- image_dft *= (image_shape_real [0 ] * image_shape_real [1 ])
492- else :
493- mip = torch .full (
494- size = image_shape_real ,
495- fill_value = - float ("inf" ),
496- dtype = DEFAULT_STATISTIC_DTYPE ,
497- device = device ,
498- )
499- best_global_index = torch .full (
500- image_shape_real , fill_value = - 1 , dtype = torch .int32 , device = device
501- )
502- correlation_sum = torch .zeros (
503- size = image_shape_real , dtype = DEFAULT_STATISTIC_DTYPE , device = device
504- )
505- correlation_squared_sum = torch .zeros (
506- size = image_shape_real , dtype = DEFAULT_STATISTIC_DTYPE , device = device
507- )
508533
509534 ##################################
510535 ### Start the orientation loop ###
@@ -563,25 +588,21 @@ def _core_match_template_single_gpu(
563588 projective_filters = projective_filters ,
564589 )
565590
566- # Update the tracked statistics
567- do_iteration_statistics_updates_compiled (
591+ # Update tracked statistics and correlation table
592+ do_iteration_and_correlation_table_updates (
568593 cross_correlation = cross_correlation ,
569594 current_indexes = batch_search_indices ,
595+ correlation_table = correlation_table ,
570596 mip = mip ,
571597 best_global_index = best_global_index ,
572598 correlation_sum = correlation_sum ,
573599 correlation_squared_sum = correlation_squared_sum ,
574- img_h = (
575- image_shape_real [0 ]
576- if backend != "zipfft"
577- else cross_correlation_shape_valid [0 ]
578- ),
579- img_w = (
580- image_shape_real [1 ]
581- if backend != "zipfft"
582- else cross_correlation_shape_valid [1 ]
583- ),
600+ threshold = CORRELATION_TABLE_THRESHOLD ,
601+ valid_shape_h = valid_correlation_shape [0 ],
602+ valid_shape_w = valid_correlation_shape [1 ],
603+ needs_valid_cropping = (backend != "zipfft" ),
584604 )
605+
585606 except Exception as e :
586607 index_queue .set_error_flag ()
587608 print (f"Error occurred in process { rank } : { e } " )
@@ -593,7 +614,13 @@ def _core_match_template_single_gpu(
593614
594615 torch .cuda .synchronize (device )
595616
596- return mip , best_global_index , correlation_sum , correlation_squared_sum
617+ return (
618+ mip ,
619+ best_global_index ,
620+ correlation_sum ,
621+ correlation_squared_sum ,
622+ correlation_table ,
623+ )
597624
598625
599626def _core_match_template_multiprocess_wrapper (
@@ -607,9 +634,13 @@ def _core_match_template_multiprocess_wrapper(
607634
608635 See the _core_match_template_single_gpu function for parameter descriptions.
609636 """
610- mip , best_global_index , correlation_sum , correlation_squared_sum = (
611- _core_match_template_single_gpu (rank , ** kwargs ) # type: ignore[arg-type]
612- )
637+ (
638+ mip ,
639+ best_global_index ,
640+ correlation_sum ,
641+ correlation_squared_sum ,
642+ correlation_table ,
643+ ) = _core_match_template_single_gpu (rank , ** kwargs ) # type: ignore[arg-type]
613644
614645 # NOTE: Results must be sent back to the CPU for the shared process dictionary:
615646 # tensors as numpy arrays, the correlation table only moved with .cpu() (not converted to numpy). This is a workaround for now.
@@ -618,6 +649,7 @@ def _core_match_template_multiprocess_wrapper(
618649 "best_global_index" : best_global_index .cpu ().numpy (),
619650 "correlation_sum" : correlation_sum .cpu ().numpy (),
620651 "correlation_squared_sum" : correlation_squared_sum .cpu ().numpy (),
652+ "correlation_table" : correlation_table .cpu (),
621653 }
622654
623655 # Place the results in the shared multi-process manager dictionary so accessible
0 commit comments