@@ -251,17 +251,25 @@ def _get_cropped_image_regions_torch(
251251
252252 regions = []
253253 for y , x in zip (pos_y , pos_x ):
254- # Check bounds and clamp to edges if out of bounds
255- # Convert to Python ints for comparison and clamping
254+ # Convert to Python ints for comparison
256255 y = int (y .item () if hasattr (y , "item" ) else y )
257256 x = int (x .item () if hasattr (x , "item" ) else x )
258257 original_y , original_x = y , x
258+
259+ # Check bounds
259260 if (
260261 y < 0
261262 or x < 0
262263 or y + box_size [0 ] > image .shape [0 ]
263264 or x + box_size [1 ] > image .shape [1 ]
264265 ):
266+ if handle_bounds == "error" :
267+ raise IndexError (
268+ f"Region bounds [{ original_y } :{ original_y + box_size [0 ]} , "
269+ f"{ original_x } :{ original_x + box_size [1 ]} ] exceed "
270+ f"image dimensions { image .shape } "
271+ )
272+ # For "pad" mode, warn and clamp coordinates
265273 warnings .warn (
266274 f"Region bounds [{ original_y } :{ original_y + box_size [0 ]} , "
267275 f"{ original_x } :{ original_x + box_size [1 ]} ] exceed "
@@ -874,120 +882,138 @@ def get_dataframe_copy(self) -> pd.DataFrame:
874882 @staticmethod
875883 # pylint: disable=too-many-arguments
876884 # pylint: disable=too-many-positional-arguments
877- def _process_single_frame_for_checkpoint (
885+ def _process_single_frame_with_shifts_checkpoint (
886+ movie_frame : torch .Tensor ,
887+ shifts : torch .Tensor , # (N, 2) -> (dy, dx)
888+ pos_y : torch .Tensor ,
889+ pos_x : torch .Tensor ,
890+ extracted_box_size : tuple [int , int ],
891+ handle_bounds : Literal ["pad" , "error" ],
892+ padding_mode : Literal ["constant" , "reflect" , "replicate" ],
893+ padding_value : float ,
894+ ) -> torch .Tensor :
895+ """
896+ Process a single frame using *precomputed particle shifts*.
897+
898+ This function is safe for gradient checkpointing and contains no
899+ deformation-field evaluation.
900+
901+ Parameters
902+ ----------
903+ movie_frame : torch.Tensor
904+ Single movie frame (H, W)
905+ shifts : torch.Tensor
906+ Per-particle shifts with shape (N, 2) as (dy, dx)
907+ pos_y, pos_x : torch.Tensor
908+ Top-left extraction positions
909+ extracted_box_size : tuple[int, int]
910+ (box_h, box_w)
911+ handle_bounds, padding_mode, padding_value
912+ Passed through to cropping
913+
914+ Returns
915+ -------
916+ torch.Tensor
917+ Shifted FFTs with shape (N, box_h, box_w//2 + 1)
918+ """
919+ box_h , box_w = extracted_box_size
920+
921+ # Extract particle images
922+ cropped_images = get_cropped_image_regions (
923+ movie_frame ,
924+ pos_y ,
925+ pos_x ,
926+ extracted_box_size ,
927+ pos_reference = "top-left" ,
928+ handle_bounds = handle_bounds ,
929+ padding_mode = padding_mode ,
930+ padding_value = padding_value ,
931+ )
932+
933+ # FFT
934+ cropped_images_dft = torch .fft .rfftn ( # pylint: disable=not-callable
935+ cropped_images , dim = (- 2 , - 1 )
936+ )
937+
938+ # Fourier shift
939+ shifted_fft = fourier_shift_dft_2d (
940+ dft = cropped_images_dft ,
941+ image_shape = (box_h , box_w ),
942+ shifts = shifts ,
943+ rfft = True ,
944+ fftshifted = False ,
945+ )
946+
947+ return shifted_fft
948+
949+ def compute_frame_particle_shifts_from_deformation (
950+ self ,
878951 movie_frame : torch .Tensor ,
879952 deformation_field : CubicCatmullRomGrid3d ,
880953 normalized_t_value : torch .Tensor ,
881954 pixel_grid : torch .Tensor ,
882955 pixel_spacing : float ,
883956 pos_y_center : torch .Tensor ,
884957 pos_x_center : torch .Tensor ,
885- pos_y : torch .Tensor ,
886- pos_x : torch .Tensor ,
887- extracted_box_size : tuple [int , int ],
888958 gh : int ,
889959 gw : int ,
890- handle_bounds : Literal ["pad" , "error" ],
891- padding_mode : Literal ["constant" , "reflect" , "replicate" ],
892- padding_value : float ,
893960 ) -> torch .Tensor :
894- """Process a single frame with gradient checkpointing.
895-
896- This function extracts particles from a single movie frame, computes
897- the deformation-based shifts, and returns the Fourier-shifted FFTs.
961+ """
962+ Compute per-particle shifts for a single frame from a deformation field.
898963
899964 Parameters
900965 ----------
901966 movie_frame : torch.Tensor
902- Single frame from the movie.
967+ Single movie frame (H, W)
903968 deformation_field : CubicCatmullRomGrid3d
904969 The deformation field grid.
905970 normalized_t_value : torch.Tensor
906- The normalized time value for this frame (0 to 1) .
971+ The normalized time value for the frame.
907972 pixel_grid : torch.Tensor
908- Coordinate grid for the full image .
973+ The pixel grid tensor .
909974 pixel_spacing : float
910- Pixel size for this particle .
975+ The pixel spacing .
911976 pos_y_center : torch.Tensor
912- Y positions of particle centers .
977+ The center y position .
913978 pos_x_center : torch.Tensor
914- X positions of particle centers.
915- pos_y : torch.Tensor
916- Y positions for extraction (top-left).
917- pos_x : torch.Tensor
918- X positions for extraction (top-left).
919- extracted_box_size : tuple[int, int]
920- Size of boxes to extract (height, width).
979+ The center x position.
921980 gh : int
922- Grid height of deformation field.
981+ The height of the deformation field grid .
923982 gw : int
924- Grid width of deformation field.
925- handle_bounds : Literal["pad", "error"]
926- How to handle image bounds.
927- padding_mode : Literal["constant", "reflect", "replicate"]
928- Padding mode for extraction.
929- padding_value : float
930- Value for constant padding.
983+ The width of the deformation field grid.
931984
932985 Returns
933986 -------
934987 torch.Tensor
935- Shifted FFTs for all particles in this frame.
988+ Shifts with shape (N, 2) as (dy, dx)
936989 """
937- box_h , box_w = extracted_box_size
938-
939- # Evaluate deformation field at this time point
940990 frame_deformation_field = evaluate_deformation_field_at_t (
941991 deformation_field = deformation_field ,
942- t = normalized_t_value .item (), # Convert to float
992+ t = normalized_t_value .item (),
943993 grid_shape = (10 * gh , 10 * gw ),
944994 )
945995
946- # Compute pixel shifts
947996 pixel_shifts = get_pixel_shifts (
948997 frame = movie_frame ,
949998 pixel_spacing = pixel_spacing ,
950999 frame_deformation_grid = frame_deformation_field ,
9511000 pixel_grid = pixel_grid ,
9521001 )
9531002
954- # Extract shifts at particle centers
9551003 y_shifts = - pixel_shifts [pos_y_center , pos_x_center , 0 ]
9561004 x_shifts = - pixel_shifts [pos_y_center , pos_x_center , 1 ]
9571005
958- # Extract particles from this frame
959- cropped_images = get_cropped_image_regions (
960- movie_frame ,
961- pos_y ,
962- pos_x ,
963- extracted_box_size ,
964- pos_reference = "top-left" ,
965- handle_bounds = handle_bounds ,
966- padding_mode = padding_mode ,
967- padding_value = padding_value ,
968- )
969-
970- # Compute FFT of cropped images
971- cropped_images_dft = torch .fft .rfftn (cropped_images , dim = (- 2 , - 1 )) # pylint: disable=not-callable
972-
973- # Apply Fourier shift
974- shifted_fft = fourier_shift_dft_2d (
975- dft = cropped_images_dft ,
976- image_shape = (box_h , box_w ),
977- shifts = torch .stack ((y_shifts , x_shifts ), dim = - 1 ),
978- rfft = True ,
979- fftshifted = False ,
980- )
981-
982- return shifted_fft
1006+ return torch .stack ((y_shifts , x_shifts ), dim = - 1 )
9831007
9841008 # pylint: disable=too-many-arguments
9851009 # pylint: disable=too-many-positional-arguments
9861010 # pylint: disable=too-many-statements
1011+ # pylint: disable=too-many-branches
9871012 def construct_image_stack_from_movie (
9881013 self ,
9891014 movie : torch .Tensor ,
990- deformation_field : CubicCatmullRomGrid3d ,
1015+ deformation_field : CubicCatmullRomGrid3d | None = None ,
1016+ particle_shifts : torch .Tensor | None = None ,
9911017 pos_reference : Literal ["center" , "top-left" ] = "top-left" ,
9921018 handle_bounds : Literal ["pad" , "error" ] = "pad" ,
9931019 padding_mode : Literal ["constant" , "reflect" , "replicate" ] = "constant" ,
@@ -1003,8 +1029,13 @@ def construct_image_stack_from_movie(
10031029 ----------
10041030 movie : torch.Tensor
10051031 The movie tensor.
1006- deformation_field : CubicCatmullRomGrid3d
1032+ deformation_field : CubicCatmullRomGrid3d | None, optional
10071033 The deformation field grid.
1034+ particle_shifts : torch.Tensor | None, optional
1035+ The particle shifts to apply to the movie. If None, the particle shifts
1036+ are computed from the deformation field. If provided, the particle shifts
1037+ are used to shift the movie. One must be provided.
1038+ Shape is (T, N, 2) where T = number of frames, N = number of particles,
10081039 pos_reference : Literal["center", "top-left"], optional
10091040 The reference point for the positions, by default "top-left". If "center",
10101041 the boxes extracted are image[y - box_size // 2 : y + box_size // 2, ...].
@@ -1042,14 +1073,21 @@ def construct_image_stack_from_movie(
10421073 The stack of images with shape (N, H, W) where N is the number of particles
10431074 and (H, W) is the extracted box size.
10441075 """
1076+ if (deformation_field is None ) == (particle_shifts is None ):
1077+ raise ValueError (
1078+ "One of `deformation_field` or `particle_shifts` must be provided."
1079+ )
10451080 pixel_sizes = self .get_pixel_size ()
10461081 # Determine which position columns to use (refined if available)
10471082 y_col , x_col = self ._get_position_reference_columns ()
10481083 # Create an empty tensor to store the image stack
10491084 h , w = self .original_template_size
10501085 box_h , box_w = self .extracted_box_size
10511086 t , img_h , img_w = movie .shape
1052- _ , _ , gh , gw = deformation_field .data .shape
1087+ if deformation_field is not None :
1088+ _ , _ , gh , gw = deformation_field .data .shape
1089+ else :
1090+ gh = gw = 0
10531091 normalized_t = torch .linspace (0 , 1 , steps = t , device = movie .device )
10541092 pixel_grid = coordinate_grid (
10551093 image_shape = (img_h , img_w ),
@@ -1091,68 +1129,52 @@ def construct_image_stack_from_movie(
10911129 movie = movie - torch .mean (movie , dim = (- 2 , - 1 ), keepdim = True )
10921130
10931131 for frame_index , movie_frame in enumerate (movie ):
1094- if use_gradient_checkpointing :
1095- # Use gradient checkpointing to save memory
1096- shifted_fft = checkpoint (
1097- self ._process_single_frame_for_checkpoint ,
1132+ # ------------------------------------------------------------
1133+ # Obtain shifts (dy, dx) for this frame
1134+ # ------------------------------------------------------------
1135+ if particle_shifts is not None :
1136+ frame_shifts = particle_shifts [frame_index ] # (N, 2)
1137+ else :
1138+ frame_shifts = self .compute_frame_particle_shifts_from_deformation (
10981139 movie_frame = movie_frame ,
10991140 deformation_field = deformation_field ,
11001141 normalized_t_value = normalized_t [frame_index ],
11011142 pixel_grid = pixel_grid ,
11021143 pixel_spacing = pixel_sizes [0 ].item (),
11031144 pos_y_center = pos_y_center ,
11041145 pos_x_center = pos_x_center ,
1105- pos_y = pos_y ,
1106- pos_x = pos_x ,
1107- extracted_box_size = self .extracted_box_size ,
11081146 gh = gh ,
11091147 gw = gw ,
1110- handle_bounds = handle_bounds ,
1111- padding_mode = padding_mode ,
1112- padding_value = padding_value ,
1113- use_reentrant = False ,
1114- )
1115- else :
1116- # Process without checkpointing (for debugging)
1117- frame_deformation_field = evaluate_deformation_field_at_t (
1118- deformation_field = deformation_field ,
1119- t = normalized_t [frame_index ],
1120- grid_shape = (10 * gh , 10 * gw ),
1121- )
1122-
1123- pixel_shifts = get_pixel_shifts (
1124- frame = movie_frame ,
1125- pixel_spacing = pixel_sizes [0 ],
1126- frame_deformation_grid = frame_deformation_field ,
1127- pixel_grid = pixel_grid ,
11281148 )
11291149
1130- y_shifts = - pixel_shifts [pos_y_center , pos_x_center , 0 ]
1131- x_shifts = - pixel_shifts [pos_y_center , pos_x_center , 1 ]
1132-
1133- cropped_images = get_cropped_image_regions (
1150+ # ------------------------------------------------------------
1151+ # Apply shifts + FFT (checkpointed)
1152+ # ------------------------------------------------------------
1153+ if use_gradient_checkpointing :
1154+ shifted_fft = checkpoint (
1155+ self ._process_single_frame_with_shifts_checkpoint ,
11341156 movie_frame ,
1157+ frame_shifts ,
11351158 pos_y ,
11361159 pos_x ,
11371160 self .extracted_box_size ,
1138- pos_reference = "top-left" ,
1161+ handle_bounds ,
1162+ padding_mode ,
1163+ padding_value ,
1164+ use_reentrant = False ,
1165+ )
1166+ else :
1167+ shifted_fft = self ._process_single_frame_with_shifts_checkpoint (
1168+ movie_frame = movie_frame ,
1169+ shifts = frame_shifts ,
1170+ pos_y = pos_y ,
1171+ pos_x = pos_x ,
1172+ extracted_box_size = self .extracted_box_size ,
11391173 handle_bounds = handle_bounds ,
11401174 padding_mode = padding_mode ,
11411175 padding_value = padding_value ,
11421176 )
11431177
1144- cropped_images_dft = torch .fft .rfftn ( # pylint: disable=not-callable
1145- cropped_images , dim = (- 2 , - 1 )
1146- )
1147-
1148- shifted_fft = fourier_shift_dft_2d (
1149- dft = cropped_images_dft ,
1150- image_shape = (box_h , box_w ),
1151- shifts = torch .stack ((y_shifts , x_shifts ), dim = - 1 ),
1152- rfft = True ,
1153- fftshifted = False ,
1154- )
1155-
11561178 # Store the shifted FFTs
11571179 aligned_particle_movies_rfft [:, frame_index ] = shifted_fft
11581180
0 commit comments