Skip to content

Commit 9976407

Browse files
committed
feat: Add proper template phase randomization
1 parent 00b73c8 commit 9976407

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

src/leopard_em/pydantic_models/config/correlation_filters.py

Lines changed: 62 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -111,56 +111,90 @@ def calculate_whitening_filter(
111111
class PhaseRandomizationFilterConfig(BaseModel2DTM):
112112
"""Configuration for phase randomization filter.
113113
114-
NOTE: Something is not working with the underlying torch_fourier_filter code
115-
for phase randomization.
114+
Phase randomization is applied directly to the template volume's Fourier transform,
115+
randomizing the phases above a certain frequency cutoff while preserving amplitudes.
116116
117117
Attributes
118118
----------
119119
enabled : bool
120-
If True, apply a phase randomization filter to the input image. Default
121-
is False.
122-
cuton : float
123-
Spatial resolution, in terms of Nyquist, above which to randomize the phase.
120+
If True, apply phase randomization to the template volume. Default is False.
121+
cuton : Optional[float]
122+
Spatial frequency cutoff, in terms of Nyquist frequency, above which to
123+
randomize the phase. Frequencies above this value will have their phases
124+
randomized while amplitudes are preserved. If None, phase randomization
125+
is applied to all frequencies. Default is None.
124126
125127
Methods
126128
-------
127-
calculate_phase_randomization_filter(ref_img_rfft)
128-
Helper function for the phase randomization filter based on the input reference
129-
image and held configuration parameters.
129+
apply_phase_randomization_to_template(template_dft)
130+
Apply phase randomization directly to a 3D template volume's Fourier transform.
130131
"""
131132

132133
enabled: bool = False
133134
cuton: Optional[Annotated[float, Field(ge=0.0)]] = None
134135

135-
def calculate_phase_randomization_filter(
136-
self, ref_img_rfft: torch.Tensor
136+
def apply_phase_randomization_to_template(
137+
self, template_dft: torch.Tensor
137138
) -> torch.Tensor:
138-
"""Helper function for phase randomization filter based on the reference image.
139+
"""Apply phase randomization to a 3D template volume's Fourier transform.
140+
141+
This method modifies the template DFT in-place (or returns a modified copy)
142+
by randomizing the phases while preserving amplitudes. If cuton is provided,
143+
only frequencies above the cutoff are randomized. If cuton is None, all
144+
frequencies are randomized. The template DFT should be in RFFT format as
145+
produced by `volume_to_rfft_fourier_slice()`.
139146
140147
Parameters
141148
----------
142-
ref_img_rfft : torch.Tensor
143-
The image to phase randomization.
144-
This should be RFFT'd and unshifted
145-
(zero-frequency component at the top-left corner).
146-
"""
147-
output_shape = ref_img_rfft.shape
149+
template_dft : torch.Tensor
150+
The 3D template volume's Fourier transform. Should have shape
151+
(d, h, w // 2 + 1) and be fftshifted in dimensions (0, 1) as
152+
produced by `volume_to_rfft_fourier_slice()`. This should be RFFT'd
153+
and fftshifted in the first two dimensions.
148154
149-
# Handle case where phase randomization filter is disabled
155+
Returns
156+
-------
157+
torch.Tensor
158+
The phase-randomized template DFT. If phase randomization is disabled,
159+
returns the input tensor unchanged.
160+
"""
161+
# Handle case where phase randomization is disabled
150162
if not self.enabled:
151-
return torch.ones(output_shape, dtype=torch.float32)
152-
153-
# Fix for underlying shape bug in torch_fourier_filter
154-
output_shape = output_shape[:-1] + (2 * (output_shape[-1] - 1),)
155-
156-
return phase_randomize(
157-
dft=ref_img_rfft,
158-
image_shape=output_shape,
163+
return template_dft
164+
165+
# The template_dft is 3D with shape (d, h, w // 2 + 1)
166+
# It's fftshifted in dims (0, 1) but not in dim 2
167+
# We need to convert to real-space shape for phase_randomize
168+
d, h, w_rfft = template_dft.shape
169+
w = 2 * (w_rfft - 1) # Convert RFFT width to real-space width
170+
image_shape = (d, h, w)
171+
172+
# Apply phase randomization
173+
# phase_randomize expects the DFT to be RFFT'd and unshifted
174+
# (fftshift=False). But our template_dft is fftshifted in dims (0, 1),
175+
# so we need to ifftshift first
176+
# pylint: disable-next=E1102
177+
template_dft_unshifted = torch.fft.ifftshift(template_dft, dim=(0, 1))
178+
179+
# Apply phase randomization
180+
# If cuton is None, set to 0 to randomize all frequencies
181+
cuton_value = self.cuton if self.cuton is not None else 0.0
182+
template_dft_randomized = phase_randomize(
183+
dft=template_dft_unshifted,
184+
image_shape=image_shape,
159185
rfft=True,
160186
fftshift=False,
161-
cuton=self.cuton,
187+
cuton=cuton_value,
162188
)
163189

190+
# Shift back to match the original format
191+
# pylint: disable-next=E1102
192+
template_dft_randomized = torch.fft.fftshift(
193+
template_dft_randomized, dim=(0, 1)
194+
)
195+
196+
return template_dft_randomized
197+
164198

165199
class BandpassFilterConfig(BaseModel2DTM):
166200
"""Configuration for the bandpass filter.

src/leopard_em/pydantic_models/managers/match_template_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,12 @@ def make_backend_core_function_kwargs(self) -> dict[str, Any]:
205205

206206
template_dft = volume_to_rfft_fourier_slice(template)
207207

208+
# Apply phase randomization to template if enabled
209+
phase_rand_filter = self.preprocessing_filters.phase_randomization_filter
210+
template_dft = phase_rand_filter.apply_phase_randomization_to_template(
211+
template_dft
212+
)
213+
208214
return {
209215
"image_dft": image_preprocessed_dft,
210216
"template_dft": template_dft,

src/leopard_em/utils/backend_setup.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ def _process_particle_images_for_filters(
8383

8484
template_dft = volume_to_rfft_fourier_slice(template)
8585

86+
# Apply phase randomization to template if enabled
87+
phase_rand_filter = preprocessing_filters.phase_randomization_filter
88+
template_dft = phase_rand_filter.apply_phase_randomization_to_template(template_dft)
89+
8690
return (
8791
particle_images_dft,
8892
template_dft,

0 commit comments

Comments
 (0)