Skip to content

Commit 38f115d

Browse files
authored
Merge pull request #114 from Lucaslab-Berkeley/jd_particle_shifts
Jd particle shifts
2 parents e5bcab4 + 1196c9f commit 38f115d

File tree

8 files changed

+262
-164
lines changed

8 files changed

+262
-164
lines changed

src/leopard_em/pydantic_models/config/movie_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ class MovieConfig(BaseModel2DTM):
1818
Path to the movie file.
1919
deformation_field_path: str
2020
Path to the deformation field file.
21+
particle_shifts_path: str
22+
Path to the particle shifts CSV file. If provided, takes precedence over
23+
deformation_field_path. The CSV should have columns: particle_index, frame,
24+
y_shift, x_shift.
2125
pre_exposure: float
2226
Pre-exposure time in seconds.
2327
fluence_per_frame: float
@@ -27,6 +31,7 @@ class MovieConfig(BaseModel2DTM):
2731
enabled: bool = False
2832
movie_path: str = ""
2933
deformation_field_path: str = ""
34+
particle_shifts_path: str = ""
3035
pre_exposure: float = 0.0
3136
fluence_per_frame: float = 1.0
3237

@@ -42,4 +47,7 @@ def deformation_field(self) -> torch.Tensor:
4247
"""Get the deformation field tensor."""
4348
if not self.enabled:
4449
return None
50+
if self.particle_shifts_path:
51+
# Particle shifts take precedence, so don't load deformation field
52+
return None
4553
return read_deformation_field_from_csv(self.deformation_field_path)

src/leopard_em/pydantic_models/data_structures/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
from .particle_stack import ParticleStack
55

66
__all__ = [
7-
"ParticleStack",
87
"OpticsGroup",
8+
"ParticleStack",
99
]

src/leopard_em/pydantic_models/data_structures/particle_stack.py

Lines changed: 132 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -251,17 +251,25 @@ def _get_cropped_image_regions_torch(
251251

252252
regions = []
253253
for y, x in zip(pos_y, pos_x):
254-
# Check bounds and clamp to edges if out of bounds
255-
# Convert to Python ints for comparison and clamping
254+
# Convert to Python ints for comparison
256255
y = int(y.item() if hasattr(y, "item") else y)
257256
x = int(x.item() if hasattr(x, "item") else x)
258257
original_y, original_x = y, x
258+
259+
# Check bounds
259260
if (
260261
y < 0
261262
or x < 0
262263
or y + box_size[0] > image.shape[0]
263264
or x + box_size[1] > image.shape[1]
264265
):
266+
if handle_bounds == "error":
267+
raise IndexError(
268+
f"Region bounds [{original_y}:{original_y + box_size[0]}, "
269+
f"{original_x}:{original_x + box_size[1]}] exceed "
270+
f"image dimensions {image.shape}"
271+
)
272+
# For "pad" mode, warn and clamp coordinates
265273
warnings.warn(
266274
f"Region bounds [{original_y}:{original_y + box_size[0]}, "
267275
f"{original_x}:{original_x + box_size[1]}] exceed "
@@ -874,120 +882,138 @@ def get_dataframe_copy(self) -> pd.DataFrame:
874882
@staticmethod
875883
# pylint: disable=too-many-arguments
876884
# pylint: disable=too-many-positional-arguments
877-
def _process_single_frame_for_checkpoint(
885+
def _process_single_frame_with_shifts_checkpoint(
886+
movie_frame: torch.Tensor,
887+
shifts: torch.Tensor, # (N, 2) -> (dy, dx)
888+
pos_y: torch.Tensor,
889+
pos_x: torch.Tensor,
890+
extracted_box_size: tuple[int, int],
891+
handle_bounds: Literal["pad", "error"],
892+
padding_mode: Literal["constant", "reflect", "replicate"],
893+
padding_value: float,
894+
) -> torch.Tensor:
895+
"""
896+
Process a single frame using *precomputed particle shifts*.
897+
898+
This function is safe for gradient checkpointing and contains no
899+
deformation-field evaluation.
900+
901+
Parameters
902+
----------
903+
movie_frame : torch.Tensor
904+
Single movie frame (H, W)
905+
shifts : torch.Tensor
906+
Per-particle shifts with shape (N, 2) as (dy, dx)
907+
pos_y, pos_x : torch.Tensor
908+
Top-left extraction positions
909+
extracted_box_size : tuple[int, int]
910+
(box_h, box_w)
911+
handle_bounds, padding_mode, padding_value
912+
Passed through to cropping
913+
914+
Returns
915+
-------
916+
torch.Tensor
917+
Shifted FFTs with shape (N, box_h, box_w//2 + 1)
918+
"""
919+
box_h, box_w = extracted_box_size
920+
921+
# Extract particle images
922+
cropped_images = get_cropped_image_regions(
923+
movie_frame,
924+
pos_y,
925+
pos_x,
926+
extracted_box_size,
927+
pos_reference="top-left",
928+
handle_bounds=handle_bounds,
929+
padding_mode=padding_mode,
930+
padding_value=padding_value,
931+
)
932+
933+
# FFT
934+
cropped_images_dft = torch.fft.rfftn( # pylint: disable=not-callable
935+
cropped_images, dim=(-2, -1)
936+
)
937+
938+
# Fourier shift
939+
shifted_fft = fourier_shift_dft_2d(
940+
dft=cropped_images_dft,
941+
image_shape=(box_h, box_w),
942+
shifts=shifts,
943+
rfft=True,
944+
fftshifted=False,
945+
)
946+
947+
return shifted_fft
948+
949+
def compute_frame_particle_shifts_from_deformation(
950+
self,
878951
movie_frame: torch.Tensor,
879952
deformation_field: CubicCatmullRomGrid3d,
880953
normalized_t_value: torch.Tensor,
881954
pixel_grid: torch.Tensor,
882955
pixel_spacing: float,
883956
pos_y_center: torch.Tensor,
884957
pos_x_center: torch.Tensor,
885-
pos_y: torch.Tensor,
886-
pos_x: torch.Tensor,
887-
extracted_box_size: tuple[int, int],
888958
gh: int,
889959
gw: int,
890-
handle_bounds: Literal["pad", "error"],
891-
padding_mode: Literal["constant", "reflect", "replicate"],
892-
padding_value: float,
893960
) -> torch.Tensor:
894-
"""Process a single frame with gradient checkpointing.
895-
896-
This function extracts particles from a single movie frame, computes
897-
the deformation-based shifts, and returns the Fourier-shifted FFTs.
961+
"""
962+
Compute per-particle shifts for a single frame from a deformation field.
898963
899964
Parameters
900965
----------
901966
movie_frame : torch.Tensor
902-
Single frame from the movie.
967+
Single movie frame (H, W)
903968
deformation_field : CubicCatmullRomGrid3d
904969
The deformation field grid.
905970
normalized_t_value : torch.Tensor
906-
The normalized time value for this frame (0 to 1).
971+
The normalized time value for the frame.
907972
pixel_grid : torch.Tensor
908-
Coordinate grid for the full image.
973+
The pixel grid tensor.
909974
pixel_spacing : float
910-
Pixel size for this particle.
975+
The pixel spacing.
911976
pos_y_center : torch.Tensor
912-
Y positions of particle centers.
977+
The center y position.
913978
pos_x_center : torch.Tensor
914-
X positions of particle centers.
915-
pos_y : torch.Tensor
916-
Y positions for extraction (top-left).
917-
pos_x : torch.Tensor
918-
X positions for extraction (top-left).
919-
extracted_box_size : tuple[int, int]
920-
Size of boxes to extract (height, width).
979+
The center x position.
921980
gh : int
922-
Grid height of deformation field.
981+
The height of the deformation field grid.
923982
gw : int
924-
Grid width of deformation field.
925-
handle_bounds : Literal["pad", "error"]
926-
How to handle image bounds.
927-
padding_mode : Literal["constant", "reflect", "replicate"]
928-
Padding mode for extraction.
929-
padding_value : float
930-
Value for constant padding.
983+
The width of the deformation field grid.
931984
932985
Returns
933986
-------
934987
torch.Tensor
935-
Shifted FFTs for all particles in this frame.
988+
Shifts with shape (N, 2) as (dy, dx)
936989
"""
937-
box_h, box_w = extracted_box_size
938-
939-
# Evaluate deformation field at this time point
940990
frame_deformation_field = evaluate_deformation_field_at_t(
941991
deformation_field=deformation_field,
942-
t=normalized_t_value.item(), # Convert to float
992+
t=normalized_t_value.item(),
943993
grid_shape=(10 * gh, 10 * gw),
944994
)
945995

946-
# Compute pixel shifts
947996
pixel_shifts = get_pixel_shifts(
948997
frame=movie_frame,
949998
pixel_spacing=pixel_spacing,
950999
frame_deformation_grid=frame_deformation_field,
9511000
pixel_grid=pixel_grid,
9521001
)
9531002

954-
# Extract shifts at particle centers
9551003
y_shifts = -pixel_shifts[pos_y_center, pos_x_center, 0]
9561004
x_shifts = -pixel_shifts[pos_y_center, pos_x_center, 1]
9571005

958-
# Extract particles from this frame
959-
cropped_images = get_cropped_image_regions(
960-
movie_frame,
961-
pos_y,
962-
pos_x,
963-
extracted_box_size,
964-
pos_reference="top-left",
965-
handle_bounds=handle_bounds,
966-
padding_mode=padding_mode,
967-
padding_value=padding_value,
968-
)
969-
970-
# Compute FFT of cropped images
971-
cropped_images_dft = torch.fft.rfftn(cropped_images, dim=(-2, -1)) # pylint: disable=not-callable
972-
973-
# Apply Fourier shift
974-
shifted_fft = fourier_shift_dft_2d(
975-
dft=cropped_images_dft,
976-
image_shape=(box_h, box_w),
977-
shifts=torch.stack((y_shifts, x_shifts), dim=-1),
978-
rfft=True,
979-
fftshifted=False,
980-
)
981-
982-
return shifted_fft
1006+
return torch.stack((y_shifts, x_shifts), dim=-1)
9831007

9841008
# pylint: disable=too-many-arguments
9851009
# pylint: disable=too-many-positional-arguments
9861010
# pylint: disable=too-many-statements
1011+
# pylint: disable=too-many-branches
9871012
def construct_image_stack_from_movie(
9881013
self,
9891014
movie: torch.Tensor,
990-
deformation_field: CubicCatmullRomGrid3d,
1015+
deformation_field: CubicCatmullRomGrid3d | None = None,
1016+
particle_shifts: torch.Tensor | None = None,
9911017
pos_reference: Literal["center", "top-left"] = "top-left",
9921018
handle_bounds: Literal["pad", "error"] = "pad",
9931019
padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
@@ -1003,8 +1029,13 @@ def construct_image_stack_from_movie(
10031029
----------
10041030
movie : torch.Tensor
10051031
The movie tensor.
1006-
deformation_field : CubicCatmullRomGrid3d
1032+
deformation_field : CubicCatmullRomGrid3d | None, optional
10071033
The deformation field grid.
1034+
particle_shifts : torch.Tensor | None, optional
1035+
The particle shifts to apply to the movie. If None, the particle shifts
1036+
are computed from the deformation field. If provided, the particle shifts
1037+
are used to shift the movie. One must be provided.
1038+
Shape is (T, N, 2) where T = number of frames, N = number of particles,
10081039
pos_reference : Literal["center", "top-left"], optional
10091040
The reference point for the positions, by default "top-left". If "center",
10101041
the boxes extracted are image[y - box_size // 2 : y + box_size // 2, ...].
@@ -1042,14 +1073,21 @@ def construct_image_stack_from_movie(
10421073
The stack of images with shape (N, H, W) where N is the number of particles
10431074
and (H, W) is the extracted box size.
10441075
"""
1076+
if (deformation_field is None) == (particle_shifts is None):
1077+
raise ValueError(
1078+
"One of `deformation_field` or `particle_shifts` must be provided."
1079+
)
10451080
pixel_sizes = self.get_pixel_size()
10461081
# Determine which position columns to use (refined if available)
10471082
y_col, x_col = self._get_position_reference_columns()
10481083
# Create an empty tensor to store the image stack
10491084
h, w = self.original_template_size
10501085
box_h, box_w = self.extracted_box_size
10511086
t, img_h, img_w = movie.shape
1052-
_, _, gh, gw = deformation_field.data.shape
1087+
if deformation_field is not None:
1088+
_, _, gh, gw = deformation_field.data.shape
1089+
else:
1090+
gh = gw = 0
10531091
normalized_t = torch.linspace(0, 1, steps=t, device=movie.device)
10541092
pixel_grid = coordinate_grid(
10551093
image_shape=(img_h, img_w),
@@ -1091,68 +1129,52 @@ def construct_image_stack_from_movie(
10911129
movie = movie - torch.mean(movie, dim=(-2, -1), keepdim=True)
10921130

10931131
for frame_index, movie_frame in enumerate(movie):
1094-
if use_gradient_checkpointing:
1095-
# Use gradient checkpointing to save memory
1096-
shifted_fft = checkpoint(
1097-
self._process_single_frame_for_checkpoint,
1132+
# ------------------------------------------------------------
1133+
# Obtain shifts (dy, dx) for this frame
1134+
# ------------------------------------------------------------
1135+
if particle_shifts is not None:
1136+
frame_shifts = particle_shifts[frame_index] # (N, 2)
1137+
else:
1138+
frame_shifts = self.compute_frame_particle_shifts_from_deformation(
10981139
movie_frame=movie_frame,
10991140
deformation_field=deformation_field,
11001141
normalized_t_value=normalized_t[frame_index],
11011142
pixel_grid=pixel_grid,
11021143
pixel_spacing=pixel_sizes[0].item(),
11031144
pos_y_center=pos_y_center,
11041145
pos_x_center=pos_x_center,
1105-
pos_y=pos_y,
1106-
pos_x=pos_x,
1107-
extracted_box_size=self.extracted_box_size,
11081146
gh=gh,
11091147
gw=gw,
1110-
handle_bounds=handle_bounds,
1111-
padding_mode=padding_mode,
1112-
padding_value=padding_value,
1113-
use_reentrant=False,
1114-
)
1115-
else:
1116-
# Process without checkpointing (for debugging)
1117-
frame_deformation_field = evaluate_deformation_field_at_t(
1118-
deformation_field=deformation_field,
1119-
t=normalized_t[frame_index],
1120-
grid_shape=(10 * gh, 10 * gw),
1121-
)
1122-
1123-
pixel_shifts = get_pixel_shifts(
1124-
frame=movie_frame,
1125-
pixel_spacing=pixel_sizes[0],
1126-
frame_deformation_grid=frame_deformation_field,
1127-
pixel_grid=pixel_grid,
11281148
)
11291149

1130-
y_shifts = -pixel_shifts[pos_y_center, pos_x_center, 0]
1131-
x_shifts = -pixel_shifts[pos_y_center, pos_x_center, 1]
1132-
1133-
cropped_images = get_cropped_image_regions(
1150+
# ------------------------------------------------------------
1151+
# Apply shifts + FFT (checkpointed)
1152+
# ------------------------------------------------------------
1153+
if use_gradient_checkpointing:
1154+
shifted_fft = checkpoint(
1155+
self._process_single_frame_with_shifts_checkpoint,
11341156
movie_frame,
1157+
frame_shifts,
11351158
pos_y,
11361159
pos_x,
11371160
self.extracted_box_size,
1138-
pos_reference="top-left",
1161+
handle_bounds,
1162+
padding_mode,
1163+
padding_value,
1164+
use_reentrant=False,
1165+
)
1166+
else:
1167+
shifted_fft = self._process_single_frame_with_shifts_checkpoint(
1168+
movie_frame=movie_frame,
1169+
shifts=frame_shifts,
1170+
pos_y=pos_y,
1171+
pos_x=pos_x,
1172+
extracted_box_size=self.extracted_box_size,
11391173
handle_bounds=handle_bounds,
11401174
padding_mode=padding_mode,
11411175
padding_value=padding_value,
11421176
)
11431177

1144-
cropped_images_dft = torch.fft.rfftn( # pylint: disable=not-callable
1145-
cropped_images, dim=(-2, -1)
1146-
)
1147-
1148-
shifted_fft = fourier_shift_dft_2d(
1149-
dft=cropped_images_dft,
1150-
image_shape=(box_h, box_w),
1151-
shifts=torch.stack((y_shifts, x_shifts), dim=-1),
1152-
rfft=True,
1153-
fftshifted=False,
1154-
)
1155-
11561178
# Store the shifted FFTs
11571179
aligned_particle_movies_rfft[:, frame_index] = shifted_fft
11581180

0 commit comments

Comments
 (0)