@@ -111,56 +111,90 @@ def calculate_whitening_filter(
111111class PhaseRandomizationFilterConfig (BaseModel2DTM ):
112112 """Configuration for phase randomization filter.
113113
114- NOTE: Something is not working with the underlying torch_fourier_filter code
115- for phase randomization .
114+ Phase randomization is applied directly to the template volume's Fourier transform,
115+ randomizing the phases above a certain frequency cutoff while preserving amplitudes .
116116
117117 Attributes
118118 ----------
119119 enabled : bool
120- If True, apply a phase randomization filter to the input image. Default
121- is False.
122- cuton : float
123- Spatial resolution, in terms of Nyquist, above which to randomize the phase.
120+ If True, apply phase randomization to the template volume. Default is False.
121+ cuton : Optional[float]
122+ Spatial frequency cutoff, in terms of Nyquist frequency, above which to
123+ randomize the phase. Frequencies above this value will have their phases
124+ randomized while amplitudes are preserved. If None, phase randomization
125+ is applied to all frequencies. Default is None.
124126
125127 Methods
126128 -------
127- calculate_phase_randomization_filter(ref_img_rfft)
128- Helper function for the phase randomization filter based on the input reference
129- image and held configuration parameters.
129+ apply_phase_randomization_to_template(template_dft)
130+ Apply phase randomization directly to a 3D template volume's Fourier transform.
130131 """
131132
132133 enabled : bool = False
133134 cuton : Optional [Annotated [float , Field (ge = 0.0 )]] = None
134135
135- def calculate_phase_randomization_filter (
136- self , ref_img_rfft : torch .Tensor
136+ def apply_phase_randomization_to_template (
137+ self , template_dft : torch .Tensor
137138 ) -> torch .Tensor :
138- """Helper function for phase randomization filter based on the reference image.
139+ """Apply phase randomization to a 3D template volume's Fourier transform.
140+
141+ This method modifies the template DFT in-place (or returns a modified copy)
142+ by randomizing the phases while preserving amplitudes. If cuton is provided,
143+ only frequencies above the cutoff are randomized. If cuton is None, all
144+ frequencies are randomized. The template DFT should be in RFFT format as
145+ produced by `volume_to_rfft_fourier_slice()`.
139146
140147 Parameters
141148 ----------
142- ref_img_rfft : torch.Tensor
143- The image to phase randomization.
144- This should be RFFT'd and unshifted
145- (zero-frequency component at the top-left corner).
146- """
147- output_shape = ref_img_rfft .shape
149+ template_dft : torch.Tensor
150+ The 3D template volume's Fourier transform. Should have shape
151+ (d, h, w // 2 + 1) and be fftshifted in dimensions (0, 1) as
152+ produced by `volume_to_rfft_fourier_slice()`. This should be RFFT'd
153+ and fftshifted in the first two dimensions.
148154
149- # Handle case where phase randomization filter is disabled
155+ Returns
156+ -------
157+ torch.Tensor
158+ The phase-randomized template DFT. If phase randomization is disabled,
159+ returns the input tensor unchanged.
160+ """
161+ # Handle case where phase randomization is disabled
150162 if not self .enabled :
151- return torch .ones (output_shape , dtype = torch .float32 )
152-
153- # Fix for underlying shape bug in torch_fourier_filter
154- output_shape = output_shape [:- 1 ] + (2 * (output_shape [- 1 ] - 1 ),)
155-
156- return phase_randomize (
157- dft = ref_img_rfft ,
158- image_shape = output_shape ,
163+ return template_dft
164+
165+ # The template_dft is 3D with shape (d, h, w // 2 + 1)
166+ # It's fftshifted in dims (0, 1) but not in dim 2
167+ # We need to convert to real-space shape for phase_randomize
168+ d , h , w_rfft = template_dft .shape
169+ w = 2 * (w_rfft - 1 ) # Convert RFFT width to real-space width
170+ image_shape = (d , h , w )
171+
172+ # Apply phase randomization
173+ # phase_randomize expects the DFT to be RFFT'd and unshifted
174+ # (fftshift=False). But our template_dft is fftshifted in dims (0, 1),
175+ # so we need to ifftshift first
176+ # pylint: disable-next=E1102
177+ template_dft_unshifted = torch .fft .ifftshift (template_dft , dim = (0 , 1 ))
178+
179+ # Apply phase randomization
180+ # If cuton is None, set to 0 to randomize all frequencies
181+ cuton_value = self .cuton if self .cuton is not None else 0.0
182+ template_dft_randomized = phase_randomize (
183+ dft = template_dft_unshifted ,
184+ image_shape = image_shape ,
159185 rfft = True ,
160186 fftshift = False ,
161- cuton = self . cuton ,
187+ cuton = cuton_value ,
162188 )
163189
190+ # Shift back to match the original format
191+ # pylint: disable-next=E1102
192+ template_dft_randomized = torch .fft .fftshift (
193+ template_dft_randomized , dim = (0 , 1 )
194+ )
195+
196+ return template_dft_randomized
197+
164198
165199class BandpassFilterConfig (BaseModel2DTM ):
166200 """Configuration for the bandpass filter.
0 commit comments