1818from collections .abc import Sequence
1919from typing import Optional , Union
2020
21- import numpy as np
2221import torch
2322from torch .fft import fftn , fftshift , ifftn , ifftshift
2423
@@ -37,25 +36,38 @@ class RadialFourier3D(Transform):
3736 normalize frequency representations across datasets with different acquisition parameters.
3837
3938 Args:
40- normalize (bool) : if True, normalize the output by the number of voxels.
41- return_magnitude (bool) : if True, return magnitude of the complex result.
42- return_phase (bool) : if True, return phase of the complex result.
43- radial_bins (Optional[int]) : number of radial bins for frequency aggregation.
39+ normalize: if True, normalize the output by the number of voxels.
40+ return_magnitude: if True, return magnitude of the complex result.
41+ return_phase: if True, return phase of the complex result.
42+ radial_bins: number of radial bins for frequency aggregation.
4443 If None, returns full 3D spectrum.
45- max_frequency (float) : maximum normalized frequency to include (0.0 to 1.0).
46- spatial_dims (Union[int, Sequence[int]]) : spatial dimensions to apply transform to.
44+ max_frequency: maximum normalized frequency to include (0.0 to 1.0).
45+ spatial_dims: spatial dimensions to apply transform to.
4746 Default is last three dimensions.
4847
4948 Returns:
5049 Radial Fourier transform of input data. Shape depends on parameters:
51- - If radial_bins is None: same spatial shape as input; magnitude and phase
52- (if both requested) are concatenated along the last dimension, doubling it.
53- - If radial_bins is set: shape (..., radial_bins) or (..., 2*radial_bins) if both
54- magnitude and phase are requested, preserving leading (batch/channel) dimensions.
50+ - If radial_bins is None and only magnitude OR phase is requested:
51+ same spatial shape as input (..., D, H, W)
52+ - If radial_bins is None and both magnitude AND phase are requested:
53+ shape (..., D, H, 2*W) [magnitude and phase concatenated along last dimension]
54+ - If radial_bins is set and only magnitude OR phase is requested:
55+ shape (..., radial_bins)
56+ - If radial_bins is set and both magnitude AND phase are requested:
57+ shape (..., 2*radial_bins)
5558
5659 Raises:
5760 ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1,
5861 or both return_magnitude and return_phase are False.
62+
63+ Example:
64+ >>> transform = RadialFourier3D(radial_bins=32, return_magnitude=True)
65+ >>> image = torch.randn(1, 128, 128, 96)
66+ >>> features = transform(image) # Shape: (1, 32)
67+ >>>
68+ >>> transform = RadialFourier3D(radial_bins=None, return_magnitude=True, return_phase=True)
69+ >>> image = torch.randn(1, 128, 128, 96)
70+ >>> spectrum = transform(image) # Shape: (1, 128, 128, 192) - magnitude+phase concatenated
5971 """
6072
6173 def __init__ (
@@ -80,9 +92,9 @@ def __init__(
8092
8193 # Validate parameters
8294 if not 0.0 < max_frequency <= 1.0 :
83- raise ValueError (f "max_frequency must be in (0.0, 1.0], got { max_frequency } " )
95+ raise ValueError ("max_frequency must be in (0.0, 1.0]" )
8496 if radial_bins is not None and radial_bins < 1 :
85- raise ValueError (f "radial_bins must be >= 1, got { radial_bins } " )
97+ raise ValueError ("radial_bins must be >= 1" )
8698 if not return_magnitude and not return_phase :
8799 raise ValueError ("At least one of return_magnitude or return_phase must be True" )
88100
@@ -132,11 +144,11 @@ def _compute_radial_spectrum(self, spectrum: torch.Tensor, radial_coords: torch.
132144 result_real = torch .zeros (self .radial_bins , dtype = spectrum .real .dtype , device = spectrum .device )
133145 result_imag = torch .zeros (self .radial_bins , dtype = spectrum .imag .dtype , device = spectrum .device )
134146
135- # Bin the frequencies - spectrum and radial_coords are both 1D
147+ # Bin the frequencies using torch.bucketize
148+ bin_indices = torch .bucketize (radial_coords , bin_edges [1 :- 1 ], right = False )
136149 for i in range (self .radial_bins ):
137- mask = ( radial_coords >= bin_edges [ i ]) & ( radial_coords < bin_edges [ i + 1 ])
150+ mask = bin_indices == i
138151 if mask .any ():
139- # spectrum is 1D, so we can index it directly
140152 result_real [i ] = spectrum .real [mask ].mean ()
141153 result_imag [i ] = spectrum .imag [mask ].mean ()
142154
@@ -154,14 +166,25 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
154166 where D, H, W are spatial dimensions.
155167
156168 Returns:
157- Transformed data in radial frequency domain.
169+ Transformed data in radial frequency domain. Shape depends on parameters:
170+ - If radial_bins is None and only magnitude OR phase is requested:
171+ same spatial shape as input (..., D, H, W)
172+ - If radial_bins is None and both magnitude AND phase are requested:
173+ shape (..., D, H, 2*W) [magnitude and phase concatenated along last dimension]
174+ - If radial_bins is set and only magnitude OR phase is requested:
175+ shape (..., radial_bins)
176+ - If radial_bins is set and both magnitude AND phase are requested:
177+ shape (..., 2*radial_bins)
178+
179+ Raises:
180+ ValueError: If input does not have exactly 3 spatial dimensions.
158181 """
159182 # Convert to tensor if needed
160183 img_tensor , * _ = convert_data_type (img , torch .Tensor )
161184 # Get spatial dimensions
162185 spatial_shape = tuple (img_tensor .shape [d ] for d in self .spatial_dims )
163186 if len (spatial_shape ) != 3 :
164- raise ValueError (f "Expected 3 spatial dimensions, got { len ( spatial_shape ) } " )
187+ raise ValueError ("Expected 3 spatial dimensions" )
165188
166189 # Compute 3D FFT
167190 # Shift zero frequency to center and compute FFT
@@ -238,12 +261,18 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
238261 Inverse transform from radial frequency domain to spatial domain.
239262
240263 Args:
241- radial_data: data in radial frequency domain.
264+ radial_data: data in radial frequency domain. When both magnitude and phase
265+ are requested with radial_bins=None, they should be concatenated along
266+ the last dimension (magnitude first, then phase).
242267 original_shape: original spatial shape (D, H, W).
243268
244269 Returns:
245270 Reconstructed spatial data.
246271
272+ Raises:
273+ ValueError: If input dimensions don't match expected shape for magnitude+phase concatenation.
274+ NotImplementedError: If radial_bins is not None.
275+
247276 Note:
248277 Only exact inverse is supported (radial_bins=None). Raises NotImplementedError otherwise.
249278 """
@@ -258,9 +287,8 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
258287 last_dim = radial_tensor .shape [- 1 ]
259288 if last_dim != original_shape [- 1 ] * 2 :
260289 raise ValueError (
261- f"For inverse with magnitude+phase and radial_bins=None, "
262- f"expected last dimension to be doubled. "
263- f"Got { last_dim } , expected { original_shape [- 1 ] * 2 } "
290+ "For inverse with magnitude+phase and radial_bins=None, "
291+ "expected last dimension to be doubled."
264292 )
265293
266294 split_size = original_shape [- 1 ]
@@ -295,16 +323,26 @@ class RadialFourierFeatures3D(Transform):
295323 Args:
296324 n_bins_list: list of radial bin counts to compute.
297325 return_types: list of return types: 'magnitude', 'phase', or 'complex'.
298- 'complex' returns both magnitude and phase concatenated as real values.
326+ 'complex' returns both magnitude and phase concatenated as real values
327+ along the last dimension (when radial_bins=None) or along the feature
328+ dimension (when radial_bins is set).
299329 normalize: if True, normalize the output.
300330
301331 Returns:
302- Concatenated radial Fourier features.
332+ Concatenated radial Fourier features. Shape: (..., total_features) where
333+ total_features = sum(bins * (2 if return_type=='complex' else 1) for bins in n_bins_list).
334+
335+ Raises:
336+ ValueError: If n_bins_list or return_types is empty.
303337
304338 Example:
305339 >>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
306340 >>> image = torch.randn(1, 128, 128, 96)
307341 >>> features = transform(image) # Shape: (1, 32+64+128=224)
342+ >>>
343+ >>> transform = RadialFourierFeatures3D(n_bins_list=[16, 32], return_types=['complex'])
344+ >>> image = torch.randn(1, 128, 128, 96)
345+ >>> features = transform(image) # Shape: (1, (16+32)*2=96)
308346 """
309347
310348 def __init__ (
@@ -343,17 +381,14 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
343381 feat = transform (img )
344382 features .append (feat )
345383
346- # Convert all features to tensors if any are numpy arrays
384+ # Convert all features to tensors using convert_data_type
347385 features_tensors = []
348386 for feat in features :
349- if isinstance (feat , np .ndarray ):
350- features_tensors .append (torch .from_numpy (feat ))
351- else :
352- features_tensors .append (feat )
387+ feat_tensor , * _ = convert_data_type (feat , torch .Tensor )
388+ features_tensors .append (feat_tensor )
353389 output = torch .cat (features_tensors , dim = - 1 )
354390
355- # Convert to original type if needed
356- if isinstance (img , np .ndarray ):
357- output = output .cpu ().numpy ()
391+ # Convert back to original type
392+ output , * _ = convert_data_type (output , type (img ))
358393
359394 return output
0 commit comments