-
Notifications
You must be signed in to change notification settings - Fork 8
Add ability to refine_template from a movie #104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
5f683ed
e70fd18
066f129
522b503
578a0e5
b22acc6
f6e4c28
e3d225b
8cbe944
34c3795
9bcef75
87ac772
0d2f613
3a3d0b7
3b860b2
679f4c9
da2dd00
12005ad
2f76119
66b5fc4
e5bcab4
f2d41e0
34007d3
00a641b
315db2f
1196c9f
38f115d
e1ddd7c
deaa831
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,45 @@ | ||
| """Serialization and validation of movie parameters for 2DTM.""" | ||
|
|
||
| import torch | ||
| from torch_motion_correction.data_io import read_deformation_field_from_csv | ||
|
|
||
| from leopard_em.pydantic_models.custom_types import BaseModel2DTM | ||
| from leopard_em.utils.data_io import load_mrc_volume | ||
|
|
||
|
|
||
class MovieConfig(BaseModel2DTM):
    """Serialization and validation of movie parameters for 2DTM.

    Attributes
    ----------
    enabled: bool
        Whether to enable movie configuration. When False, the ``movie`` and
        ``deformation_field`` properties return None instead of loading data.
    movie_path: str
        Path to the movie file.
    deformation_field_path: str
        Path to the deformation field file.
    pre_exposure: float
        Pre-exposure time in seconds.
    fluence_per_frame: float
        Dose per frame in electrons per pixel.
    """

    enabled: bool = False
    movie_path: str = ""
    deformation_field_path: str = ""
    pre_exposure: float = 0.0
    fluence_per_frame: float = 1.0

    @property
    def movie(self) -> Optional[torch.Tensor]:
        """Get the movie tensor, or None when movie processing is disabled."""
        # Fixed annotation: the original declared `-> torch.Tensor` but this
        # branch deliberately returns None when the config is disabled.
        if not self.enabled:
            return None
        return load_mrc_volume(self.movie_path)

    @property
    def deformation_field(self) -> Optional[torch.Tensor]:
        """Get the deformation field tensor, or None when disabled."""
        if not self.enabled:
            return None
        return read_deformation_field_from_csv(self.deformation_field_path)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,10 +7,17 @@ | |
| import pandas as pd | ||
| import torch | ||
| from pydantic import ConfigDict | ||
| from torch_fourier_shift import fourier_shift_dft_2d | ||
| from torch_grid_utils import coordinate_grid | ||
| from torch_motion_correction.correct_motion import get_pixel_shifts | ||
| from torch_motion_correction.deformation_field_utils import ( | ||
| evaluate_deformation_field_at_t, | ||
| ) | ||
|
|
||
| from leopard_em.pydantic_models.config import PreprocessingFilters | ||
| from leopard_em.pydantic_models.custom_types import BaseModel2DTM, ExcludedTensor | ||
| from leopard_em.pydantic_models.formats import MATCH_TEMPLATE_DF_COLUMN_ORDER | ||
| from leopard_em.pydantic_models.utils import dose_weight | ||
| from leopard_em.utils.data_io import load_mrc_image | ||
|
|
||
| TORCH_TO_NUMPY_PADDING_MODE = { | ||
|
|
@@ -634,7 +641,6 @@ def construct_projective_filters( | |
| """ | ||
| # Create an empty tensor to store the filter stack | ||
| filter_stack = torch.zeros((self.num_particles, *output_shape)) | ||
|
|
||
| # Verify that the number of images matches the number of indices | ||
| if images_dft.shape[0] != len(indices): | ||
| raise ValueError( | ||
|
|
@@ -844,3 +850,150 @@ def get_dataframe_copy(self) -> pd.DataFrame: | |
| A copy of the underlying DataFrame | ||
| """ | ||
| return self._df.copy() | ||
|
|
||
| def construct_image_stack_from_movie( | ||
| self, | ||
| movie: torch.Tensor, | ||
| deformation_field: torch.Tensor, | ||
| pos_reference: Literal["center", "top-left"] = "top-left", | ||
| handle_bounds: Literal["pad", "error"] = "pad", | ||
| padding_mode: Literal["constant", "reflect", "replicate"] = "constant", | ||
| padding_value: float = 0.0, | ||
| pre_exposure: float = 0.0, | ||
| fluence_per_frame: float = 0.0, | ||
| ) -> torch.Tensor: | ||
| """Construct a stack of images from a movie file. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| movie : torch.Tensor | ||
| The movie tensor. | ||
| deformation_field : torch.Tensor | ||
| The deformation field tensor. | ||
| pos_reference : Literal["center", "top-left"], optional | ||
| The reference point for the positions, by default "top-left". If "center", | ||
| the boxes extracted are image[y - box_size // 2 : y + box_size // 2, ...]. | ||
| If "top-left", the boxes will be image[y : y + box_size, ...]. | ||
| handle_bounds : Literal["pad", "error"], optional | ||
| How to handle the bounds of the image, by default "pad". If "pad", the image | ||
| will be padded with the padding value based on the padding mode. | ||
| If "error", an error will be raised if any region exceeds the image bounds. | ||
| Note clipping is not supported | ||
| since returned stack may have inhomogeneous sizes. | ||
| padding_mode : Literal["constant", "reflect", "replicate"], optional | ||
| The padding mode to use when padding the image, by default "constant". | ||
| "constant" pads with the value `padding_value`, "reflect" pads with the | ||
| reflection of the image, and "replicate" pads with the last pixel | ||
| of the image. These match the modes available in `torch.nn.functional.pad`. | ||
| padding_value : float, optional | ||
| The value to use for padding when `padding_mode` is "constant", | ||
| by default 0.0. | ||
| pre_exposure : float, optional | ||
| The pre-exposure time in seconds, by default 0.0. | ||
| fluence_per_frame : float, optional | ||
| The dose per frame in electrons per pixel, by default 0.0. | ||
|
|
||
| Returns | ||
| ------- | ||
| torch.Tensor | ||
| The stack of images with shape (N, H, W) where N is the number of particles | ||
| and (H, W) is the extracted box size. | ||
| """ | ||
| pixel_sizes = self.get_pixel_size() | ||
| # Determine which position columns to use (refined if available) | ||
| y_col, x_col = self._get_position_reference_columns() | ||
| # Create an empty tensor to store the image stack | ||
| h, w = self.original_template_size | ||
| box_h, box_w = self.extracted_box_size | ||
| t, img_h, img_w = movie.shape | ||
| _, _, gh, gw = deformation_field.shape | ||
| normalized_t = torch.linspace(0, 1, steps=t, device=movie.device) | ||
| pixel_grid = coordinate_grid( | ||
| image_shape=(img_h, img_w), | ||
| device=movie.device, | ||
| ) | ||
| # Find the indexes in the DataFrame that correspond to each unique image | ||
| paticle_indexes = self._df.index.tolist() | ||
| pos_y = self._df.loc[paticle_indexes, y_col].to_numpy() | ||
| pos_x = self._df.loc[paticle_indexes, x_col].to_numpy() | ||
| # If the position reference is "top-left", shift (x, y) by half the original | ||
| # template width/height so reference is now in the center | ||
| if pos_reference == "center": | ||
| pos_y = pos_y - h // 2 | ||
| pos_x = pos_x - w // 2 | ||
|
|
||
| pos_y_center = pos_y + h // 2 | ||
| pos_x_center = pos_x + w // 2 | ||
| pos_y -= (box_h - h) // 2 | ||
| pos_x -= (box_w - w) // 2 | ||
| pos_y = torch.tensor(pos_y) | ||
| pos_x = torch.tensor(pos_x) | ||
| pos_y_center = torch.tensor(pos_y_center) | ||
| pos_x_center = torch.tensor(pos_x_center) | ||
| aligned_particle_movies_rfft = torch.zeros( | ||
| (self.num_particles, t, box_h, box_w // 2 + 1), dtype=torch.complex64 | ||
| ) | ||
| # set frames mean zero | ||
| movie = movie - torch.mean(movie, dim=(-2, -1), keepdim=True) | ||
| for frame_index, movie_frame in enumerate(movie): | ||
| print(f"Extracting particle images for frame {frame_index} of {t}") | ||
| # If memory becomes an issue, do this in batches of particles | ||
| # Get the shift in the deformation field for the center pixel | ||
| frame_deformation_field = evaluate_deformation_field_at_t( | ||
| deformation_field=deformation_field, | ||
| t=normalized_t[frame_index], | ||
| grid_shape=(10 * gh, 10 * gw), | ||
| ) | ||
| pixel_shifts = get_pixel_shifts( | ||
| frame=movie_frame, | ||
| pixel_spacing=pixel_sizes[0], | ||
| frame_deformation_grid=frame_deformation_field, | ||
| pixel_grid=pixel_grid, | ||
| ) # (H, W, yx) | ||
| y_shifts = -pixel_shifts[ | ||
| pos_y_center, pos_x_center, 0 | ||
| ] # y-component of shifts | ||
| x_shifts = -pixel_shifts[ | ||
| pos_y_center, pos_x_center, 1 | ||
| ] # x-component of shifts | ||
| print( | ||
| f"frame {frame_index}: y_shifts: {y_shifts[0]}, x_shifts: {x_shifts[0]}" | ||
| ) | ||
|
||
| # Extract particles from this frame | ||
| cropped_images = get_cropped_image_regions( | ||
| movie_frame, | ||
| pos_y, | ||
| pos_x, | ||
| self.extracted_box_size, | ||
| pos_reference="top-left", | ||
| handle_bounds=handle_bounds, | ||
| padding_mode=padding_mode, | ||
| padding_value=padding_value, | ||
| ) | ||
| # Now Fourier shift the cropped images (use torch-fourier-shift) | ||
| cropped_images_dft = torch.fft.rfftn(cropped_images, dim=(-2, -1)) # pylint: disable=not-callable | ||
| shifted_fft = fourier_shift_dft_2d( | ||
| dft=cropped_images_dft, | ||
| image_shape=(box_h, box_w), | ||
| shifts=torch.stack((y_shifts, x_shifts), dim=-1), # (N, 2) shifts | ||
| rfft=True, | ||
| fftshifted=False, | ||
| ) | ||
| # store them in a tensor shape (N, t, box_h, box_w) | ||
| aligned_particle_movies_rfft[:, frame_index] = shifted_fft | ||
| # Dose weight the aligned particle images | ||
| aligned_particle_images = torch.zeros((self.num_particles, box_h, box_w)) | ||
| for particle_index in range(self.num_particles): | ||
| print(f"Dose weighting particle {particle_index} of {self.num_particles}") | ||
| particle_dft = aligned_particle_movies_rfft[particle_index] | ||
| dw_sum = dose_weight( | ||
| movie_fft=particle_dft, | ||
| pixel_size=pixel_sizes[particle_index], | ||
| pre_exposure=pre_exposure, | ||
| fluence_per_frame=fluence_per_frame, | ||
| voltage=self._df["voltage"].to_numpy()[particle_index], | ||
| ) # (box_h, box_w) | ||
| aligned_particle_images[particle_index] = dw_sum | ||
|
|
||
| self.image_stack = aligned_particle_images | ||
| return aligned_particle_images | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,6 @@ | ||
| """Pydantic model for running the refine template program.""" | ||
|
|
||
| import warnings | ||
| from typing import Any, ClassVar | ||
|
|
||
| import numpy as np | ||
|
|
@@ -9,6 +10,7 @@ | |
| from leopard_em.pydantic_models.config import ( | ||
| ComputationalConfigRefine, | ||
| DefocusSearchConfig, | ||
| MovieConfig, | ||
| PixelSizeSearchConfig, | ||
| PreprocessingFilters, | ||
| RefineOrientationConfig, | ||
|
|
@@ -45,6 +47,8 @@ class RefineTemplateManager(BaseModel2DTM): | |
| Default is True. | ||
| template_volume : ExcludedTensor | ||
| The template volume tensor (excluded from serialization). | ||
| movie_config : MovieConfig | ||
| Configuration for the movie. | ||
|
|
||
| Methods | ||
| ------- | ||
|
|
@@ -66,7 +70,7 @@ class RefineTemplateManager(BaseModel2DTM): | |
| orientation_refinement_config: RefineOrientationConfig | ||
| preprocessing_filters: PreprocessingFilters | ||
| computational_config: ComputationalConfigRefine | ||
|
|
||
| movie_config: MovieConfig | ||
| apply_global_filtering: bool = True | ||
|
|
||
| # Excluded tensors | ||
|
|
@@ -109,6 +113,18 @@ def make_backend_core_function_kwargs( | |
| # The relative pixel size values to search over | ||
| pixel_size_offsets = self.pixel_size_refinement_config.pixel_size_values | ||
|
|
||
| # Load movie and deformation field | ||
| movie = self.movie_config.movie | ||
| deformation_field = self.movie_config.deformation_field | ||
|
|
||
| if movie is not None and self.apply_global_filtering: | ||
| warnings.warn( | ||
| "Global filtering cannot be applied with movie refinement. " | ||
| "Disabling apply_global_filtering.", | ||
| stacklevel=2, | ||
| ) | ||
| self.apply_global_filtering = False | ||
|
||
|
|
||
| # Use the common utility function to set up the backend kwargs | ||
| # pylint: disable=duplicate-code | ||
| return setup_particle_backend_kwargs( | ||
|
|
@@ -120,6 +136,10 @@ def make_backend_core_function_kwargs( | |
| defocus_offsets=defocus_offsets, | ||
| pixel_size_offsets=pixel_size_offsets, | ||
| apply_global_filtering=self.apply_global_filtering, | ||
| movie=movie, | ||
| deformation_field=deformation_field, | ||
| pre_exposure=self.movie_config.pre_exposure, | ||
| fluence_per_frame=self.movie_config.fluence_per_frame, | ||
| device_list=self.computational_config.gpu_devices, | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this extra value of 10? I know it's because of oversampling to then do interpolation, but it's not immediately obvious why the grid dimensions are being multiplied.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the same as in torch motion correction.
It is expensive to evaluate the cubic spline grid over all pixels. As an approximation, we evaluate it over a grid 10x10 larger than the control points grid, and then use bicubic interpolation to go to the full per-pixel shift.
10x was chosen because it seemed to work; more work would be needed to characterize sensitivity vs speed properly.