Merged

29 commits
5f683ed
feat: Can now refine from movie + def field
jdickerson95 Nov 9, 2025
e70fd18
fix pylint errors
jdickerson95 Nov 9, 2025
066f129
feat: differentiable refine added
jdickerson95 Nov 11, 2025
522b503
differentiable refine working for motion polishing
jdickerson95 Nov 21, 2025
578a0e5
refactor and changes from code review
jdickerson95 Nov 22, 2025
b22acc6
fix linting errors
jdickerson95 Nov 22, 2025
f6e4c28
fix pylint 3.10 error
jdickerson95 Nov 22, 2025
e3d225b
fix: bug when nothing provided
jdickerson95 Dec 15, 2025
8cbe944
fix: patch error in constrained from particle stack changes
jdickerson95 Dec 28, 2025
34c3795
refactor util.py
jdickerson95 Dec 28, 2025
9bcef75
fix linting errors
jdickerson95 Dec 28, 2025
87ac772
feat: add aberrations to ctf
jdickerson95 Dec 28, 2025
0d2f613
feat: aniso mag stretches projections (not just ctf freqs)
jdickerson95 Dec 28, 2025
3a3d0b7
Update for fourier-slice api change
jdickerson95 Dec 29, 2025
3b860b2
fix linting error
jdickerson95 Dec 29, 2025
679f4c9
fix: refine runs if no values for aberrations in df
jdickerson95 Jan 3, 2026
da2dd00
bodge fix: clamp particle stack out of bounds
jdickerson95 Jan 7, 2026
12005ad
Fix for changes in torch-fourier-slice
jdickerson95 Jan 7, 2026
2f76119
Refactor from code review
Jan 9, 2026
66b5fc4
fix: prevent zeros being put in results everywhere
jdickerson95 Jan 11, 2026
e5bcab4
Merge branch 'jd_ctf_aberrations' of https://github.com/Lucaslab-Berk…
jdickerson95 Jan 11, 2026
f2d41e0
Can use particle shifts directly to get particle stacks from movie
jdickerson95 Jan 16, 2026
34007d3
Allow movie refine and global filtering
jdickerson95 Jan 18, 2026
00a641b
feat: movie refine from particle shifts directly
jdickerson95 Jan 18, 2026
315db2f
ignore pylint errors
jdickerson95 Jan 18, 2026
1196c9f
update tests
jdickerson95 Jan 18, 2026
38f115d
Merge pull request #114 from Lucaslab-Berkeley/jd_particle_shifts
jdickerson95 Jan 29, 2026
e1ddd7c
Merge pull request #110 from Lucaslab-Berkeley/jd_ctf_aberrations
jdickerson95 Jan 29, 2026
deaa831
Merge pull request #106 from Lucaslab-Berkeley/jd_differentiable_refine
jdickerson95 Jan 29, 2026
9 changes: 8 additions & 1 deletion programs/refine_template/refine_template_example_config.yaml
@@ -37,4 +37,11 @@ preprocessing_filters:
low_freq_cutoff: null
computational_config:
gpu_ids: 0
num_cpus: 1
num_cpus: 1
apply_global_filtering: true
movie_config:
enabled: false
movie_path: path/to/aligned_movie.mrc
deformation_field_path: path/to/deformation_field.csv
pre_exposure: 0.0
fluence_per_frame: 1.0
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -54,6 +54,9 @@ dependencies = [
"ttsim3d>=v0.4.0",
"lmfit",
"zenodo-get",
"torch-fourier-shift",
"torch-motion-correction>=0.0.4",
"torch-grid-utils>=v0.0.9"
]

[tool.hatch.metadata]
2 changes: 2 additions & 0 deletions src/leopard_em/pydantic_models/config/__init__.py
@@ -12,6 +12,7 @@
WhiteningFilterConfig,
)
from .defocus_search import DefocusSearchConfig
from .movie_config import MovieConfig
from .orientation_search import (
ConstrainedOrientationConfig,
MultipleOrientationConfig,
@@ -34,4 +35,5 @@
"RefineOrientationConfig",
"WhiteningFilterConfig",
"ConstrainedOrientationConfig",
"MovieConfig",
]
45 changes: 45 additions & 0 deletions src/leopard_em/pydantic_models/config/movie_config.py
@@ -0,0 +1,45 @@
"""Serialization and validation of movie parameters for 2DTM."""

import torch
from torch_motion_correction.data_io import read_deformation_field_from_csv

from leopard_em.pydantic_models.custom_types import BaseModel2DTM
from leopard_em.utils.data_io import load_mrc_volume


class MovieConfig(BaseModel2DTM):
"""Serialization and validation of movie parameters for 2DTM.

Attributes
----------
enabled: bool
Whether to enable movie configuration.
movie_path: str
Path to the movie file.
deformation_field_path: str
Path to the deformation field file.
pre_exposure: float
Pre-exposure time in seconds.
fluence_per_frame: float
Dose per frame in electrons per pixel.
"""

enabled: bool = False
movie_path: str = ""
deformation_field_path: str = ""
pre_exposure: float = 0.0
fluence_per_frame: float = 1.0

@property
def movie(self) -> torch.Tensor | None:
"""Get the movie tensor, or None if movie refinement is disabled."""
if not self.enabled:
return None
return load_mrc_volume(self.movie_path)

@property
def deformation_field(self) -> torch.Tensor | None:
"""Get the deformation field tensor, or None if movie refinement is disabled."""
if not self.enabled:
return None
return read_deformation_field_from_csv(self.deformation_field_path)
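
As a point of reference (not part of the diff), a minimal sketch of how the new MovieConfig model could be constructed and queried on its own. The field names match the class above and the example YAML earlier in this PR; the file paths are placeholders:

from leopard_em.pydantic_models.config import MovieConfig

# Disabled config: both properties return None and nothing is read from disk
movie_config = MovieConfig()
assert movie_config.movie is None

# Enabled config: the properties lazily load the movie frames and the
# deformation field control points when accessed
movie_config = MovieConfig(
    enabled=True,
    movie_path="path/to/aligned_movie.mrc",
    deformation_field_path="path/to/deformation_field.csv",
    pre_exposure=0.0,
    fluence_per_frame=1.0,
)
movie = movie_config.movie                            # frame stack via load_mrc_volume
deformation_field = movie_config.deformation_field    # spline control points from CSV

These two tensors are exactly what the new construct_image_stack_from_movie method on ParticleStack (below) consumes.
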
155 changes: 154 additions & 1 deletion src/leopard_em/pydantic_models/data_structures/particle_stack.py
@@ -7,10 +7,17 @@
import pandas as pd
import torch
from pydantic import ConfigDict
from torch_fourier_shift import fourier_shift_dft_2d
from torch_grid_utils import coordinate_grid
from torch_motion_correction.correct_motion import get_pixel_shifts
from torch_motion_correction.deformation_field_utils import (
evaluate_deformation_field_at_t,
)

from leopard_em.pydantic_models.config import PreprocessingFilters
from leopard_em.pydantic_models.custom_types import BaseModel2DTM, ExcludedTensor
from leopard_em.pydantic_models.formats import MATCH_TEMPLATE_DF_COLUMN_ORDER
from leopard_em.pydantic_models.utils import dose_weight
from leopard_em.utils.data_io import load_mrc_image

TORCH_TO_NUMPY_PADDING_MODE = {
@@ -634,7 +641,6 @@ def construct_projective_filters(
"""
# Create an empty tensor to store the filter stack
filter_stack = torch.zeros((self.num_particles, *output_shape))

# Verify that the number of images matches the number of indices
if images_dft.shape[0] != len(indices):
raise ValueError(
@@ -844,3 +850,150 @@ def get_dataframe_copy(self) -> pd.DataFrame:
A copy of the underlying DataFrame
"""
return self._df.copy()

def construct_image_stack_from_movie(
self,
movie: torch.Tensor,
deformation_field: torch.Tensor,
pos_reference: Literal["center", "top-left"] = "top-left",
handle_bounds: Literal["pad", "error"] = "pad",
padding_mode: Literal["constant", "reflect", "replicate"] = "constant",
padding_value: float = 0.0,
pre_exposure: float = 0.0,
fluence_per_frame: float = 0.0,
) -> torch.Tensor:
"""Construct a stack of images from a movie file.

Parameters
----------
movie : torch.Tensor
The movie tensor.
deformation_field : torch.Tensor
The deformation field tensor.
pos_reference : Literal["center", "top-left"], optional
The reference point for the positions, by default "top-left". If "center",
the boxes extracted are image[y - box_size // 2 : y + box_size // 2, ...].
If "top-left", the boxes will be image[y : y + box_size, ...].
handle_bounds : Literal["pad", "error"], optional
How to handle the bounds of the image, by default "pad". If "pad", the image
will be padded with the padding value based on the padding mode.
If "error", an error will be raised if any region exceeds the image bounds.
Note that clipping is not supported since the returned stack may have
inhomogeneous sizes.
padding_mode : Literal["constant", "reflect", "replicate"], optional
The padding mode to use when padding the image, by default "constant".
"constant" pads with the value `padding_value`, "reflect" pads with the
reflection of the image, and "replicate" pads with the last pixel
of the image. These match the modes available in `torch.nn.functional.pad`.
padding_value : float, optional
The value to use for padding when `padding_mode` is "constant",
by default 0.0.
pre_exposure : float, optional
The pre-exposure time in seconds, by default 0.0.
fluence_per_frame : float, optional
The dose per frame in electrons per pixel, by default 0.0.

Returns
-------
torch.Tensor
The stack of images with shape (N, H, W) where N is the number of particles
and (H, W) is the extracted box size.
"""
pixel_sizes = self.get_pixel_size()
# Determine which position columns to use (refined if available)
y_col, x_col = self._get_position_reference_columns()
# Create an empty tensor to store the image stack
h, w = self.original_template_size
box_h, box_w = self.extracted_box_size
t, img_h, img_w = movie.shape
_, _, gh, gw = deformation_field.shape
normalized_t = torch.linspace(0, 1, steps=t, device=movie.device)
pixel_grid = coordinate_grid(
image_shape=(img_h, img_w),
device=movie.device,
)
# Get the particle indexes and (y, x) positions from the DataFrame
particle_indexes = self._df.index.tolist()
pos_y = self._df.loc[particle_indexes, y_col].to_numpy()
pos_x = self._df.loc[particle_indexes, x_col].to_numpy()
# If the position reference is "center", shift (x, y) by half the original
# template width/height so the reference is the top-left corner
if pos_reference == "center":
pos_y = pos_y - h // 2
pos_x = pos_x - w // 2

pos_y_center = pos_y + h // 2
pos_x_center = pos_x + w // 2
pos_y -= (box_h - h) // 2
pos_x -= (box_w - w) // 2
pos_y = torch.tensor(pos_y)
pos_x = torch.tensor(pos_x)
pos_y_center = torch.tensor(pos_y_center)
pos_x_center = torch.tensor(pos_x_center)
aligned_particle_movies_rfft = torch.zeros(
(self.num_particles, t, box_h, box_w // 2 + 1), dtype=torch.complex64
)
# set each frame to zero mean
movie = movie - torch.mean(movie, dim=(-2, -1), keepdim=True)
for frame_index, movie_frame in enumerate(movie):
print(f"Extracting particle images for frame {frame_index} of {t}")
# If memory becomes an issue, do this in batches of particles
# Get the shift in the deformation field for the center pixel
frame_deformation_field = evaluate_deformation_field_at_t(
deformation_field=deformation_field,
t=normalized_t[frame_index],
grid_shape=(10 * gh, 10 * gw),
Member:
Why this extra value of 10? I know it's because of oversampling to then do interpolation, but it's not immediately obvious why the grid dimensions are being multiplied.

Contributor Author:
This is the same as in torch motion correction.
It is expensive to evaluate the cubic spline grid over all pixels. As an approximation, we evaluate it over a grid 10x10 larger than the control points grid, and then use bicubic interpolation to go to the full per-pixel shift.
10x was chosen because it seemed to work; more work would be needed to characterize sensitivity vs speed properly.

)
pixel_shifts = get_pixel_shifts(
frame=movie_frame,
pixel_spacing=pixel_sizes[0],
frame_deformation_grid=frame_deformation_field,
pixel_grid=pixel_grid,
) # (H, W, yx)
y_shifts = -pixel_shifts[
pos_y_center, pos_x_center, 0
] # y-component of shifts
x_shifts = -pixel_shifts[
pos_y_center, pos_x_center, 1
] # x-component of shifts
print(
f"frame {frame_index}: y_shifts: {y_shifts[0]}, x_shifts: {x_shifts[0]}"
)
Member:
Extra print statement which should not be there

# Extract particles from this frame
cropped_images = get_cropped_image_regions(
movie_frame,
pos_y,
pos_x,
self.extracted_box_size,
pos_reference="top-left",
handle_bounds=handle_bounds,
padding_mode=padding_mode,
padding_value=padding_value,
)
# Now Fourier shift the cropped images (use torch-fourier-shift)
cropped_images_dft = torch.fft.rfftn(cropped_images, dim=(-2, -1)) # pylint: disable=not-callable
shifted_fft = fourier_shift_dft_2d(
dft=cropped_images_dft,
image_shape=(box_h, box_w),
shifts=torch.stack((y_shifts, x_shifts), dim=-1), # (N, 2) shifts
rfft=True,
fftshifted=False,
)
# store them in the rfft stack of shape (N, t, box_h, box_w // 2 + 1)
aligned_particle_movies_rfft[:, frame_index] = shifted_fft
# Dose weight the aligned particle images
aligned_particle_images = torch.zeros((self.num_particles, box_h, box_w))
for particle_index in range(self.num_particles):
print(f"Dose weighting particle {particle_index} of {self.num_particles}")
particle_dft = aligned_particle_movies_rfft[particle_index]
dw_sum = dose_weight(
movie_fft=particle_dft,
pixel_size=pixel_sizes[particle_index],
pre_exposure=pre_exposure,
fluence_per_frame=fluence_per_frame,
voltage=self._df["voltage"].to_numpy()[particle_index],
) # (box_h, box_w)
aligned_particle_images[particle_index] = dw_sum

self.image_stack = aligned_particle_images
return aligned_particle_images
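
To illustrate the coarse-grid approximation discussed in the review thread above (the grid_shape=(10 * gh, 10 * gw) argument), here is a standalone sketch that is not taken from the diff or from torch-motion-correction: the deformation spline is evaluated only on a grid 10x denser than the control points, and bicubic interpolation then takes that oversampled grid to a per-pixel shift field. The function name and shapes are illustrative assumptions:

import torch
import torch.nn.functional as F

def spline_grid_to_pixel_shifts(
    oversampled_shifts: torch.Tensor,  # (2, 10*gh, 10*gw) spline evaluated on the coarse grid
    image_shape: tuple[int, int],      # (H, W) of the movie frame
) -> torch.Tensor:
    """Upsample an oversampled spline grid to a per-pixel (H, W, yx) shift field."""
    per_pixel = F.interpolate(
        oversampled_shifts[None],      # add a batch dim -> (1, 2, 10*gh, 10*gw)
        size=image_shape,
        mode="bicubic",
        align_corners=True,
    )
    return per_pixel[0].permute(1, 2, 0)  # (H, W, yx), same layout as get_pixel_shifts

Evaluating the cubic spline at every pixel directly would be far more expensive; as noted in the reply above, the factor of 10 is an empirical choice rather than a characterized speed/accuracy trade-off.
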
@@ -1,5 +1,6 @@
"""Pydantic model for running the refine template program."""

import warnings
from typing import Any, ClassVar

import numpy as np
@@ -9,6 +10,7 @@
from leopard_em.pydantic_models.config import (
ComputationalConfigRefine,
DefocusSearchConfig,
MovieConfig,
PixelSizeSearchConfig,
PreprocessingFilters,
RefineOrientationConfig,
@@ -45,6 +47,8 @@ class RefineTemplateManager(BaseModel2DTM):
Default is True.
template_volume : ExcludedTensor
The template volume tensor (excluded from serialization).
movie_config : MovieConfig
Configuration for the movie.

Methods
-------
@@ -66,7 +70,7 @@ class RefineTemplateManager(BaseModel2DTM):
orientation_refinement_config: RefineOrientationConfig
preprocessing_filters: PreprocessingFilters
computational_config: ComputationalConfigRefine

movie_config: MovieConfig
apply_global_filtering: bool = True

# Excluded tensors
@@ -109,6 +113,18 @@ def make_backend_core_function_kwargs(
# The relative pixel size values to search over
pixel_size_offsets = self.pixel_size_refinement_config.pixel_size_values

# Load movie and deformation field
movie = self.movie_config.movie
deformation_field = self.movie_config.deformation_field

if movie is not None and self.apply_global_filtering:
warnings.warn(
"Global filtering cannot be applied with movie refinement. "
"Disabling apply_global_filtering.",
stacklevel=2,
)
self.apply_global_filtering = False
Member:
It would be preferable to raise a validation error through Pydantic when the movie is not None or if self.movie_config.enabled == True and self.apply_global_filtering == True. That way there are no implicit modifications to a configuration file


# Use the common utility function to set up the backend kwargs
# pylint: disable=duplicate-code
return setup_particle_backend_kwargs(
@@ -120,6 +136,10 @@
defocus_offsets=defocus_offsets,
pixel_size_offsets=pixel_size_offsets,
apply_global_filtering=self.apply_global_filtering,
movie=movie,
deformation_field=deformation_field,
pre_exposure=self.movie_config.pre_exposure,
fluence_per_frame=self.movie_config.fluence_per_frame,
device_list=self.computational_config.gpu_devices,
)
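
As a sketch of the alternative suggested in the review comment above (not part of the diff), a Pydantic model_validator on RefineTemplateManager could reject the conflicting settings up front instead of implicitly flipping apply_global_filtering at runtime. Field names follow the class in this diff; the imports assume the module layout shown elsewhere in this PR:

from pydantic import model_validator

from leopard_em.pydantic_models.config import MovieConfig
from leopard_em.pydantic_models.custom_types import BaseModel2DTM

class RefineTemplateManager(BaseModel2DTM):
    # ... other fields as in the diff ...
    movie_config: MovieConfig
    apply_global_filtering: bool = True

    @model_validator(mode="after")
    def _check_movie_vs_global_filtering(self) -> "RefineTemplateManager":
        # Fail validation instead of silently rewriting the user's configuration
        if self.movie_config.enabled and self.apply_global_filtering:
            raise ValueError(
                "apply_global_filtering cannot be combined with movie refinement; "
                "set apply_global_filtering: false in the YAML config."
            )
        return self
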
