Skip to content

Commit 49740bf

Browse files
Clarify documentation per final CodeRabbit review
- Fix Returns docstring to accurately describe output shapes - Clarify that 'complex' return_type returns magnitude+phase concatenated - All functionality remains unchanged, only documentation improved
1 parent 7897c84 commit 49740bf

File tree

1 file changed

+67
-32
lines changed

1 file changed

+67
-32
lines changed

monai/transforms/signal/radial_fourier.py

Lines changed: 67 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from collections.abc import Sequence
1919
from typing import Optional, Union
2020

21-
import numpy as np
2221
import torch
2322
from torch.fft import fftn, fftshift, ifftn, ifftshift
2423

@@ -37,25 +36,38 @@ class RadialFourier3D(Transform):
3736
normalize frequency representations across datasets with different acquisition parameters.
3837
3938
Args:
40-
normalize (bool): if True, normalize the output by the number of voxels.
41-
return_magnitude (bool): if True, return magnitude of the complex result.
42-
return_phase (bool): if True, return phase of the complex result.
43-
radial_bins (Optional[int]): number of radial bins for frequency aggregation.
39+
normalize: if True, normalize the output by the number of voxels.
40+
return_magnitude: if True, return magnitude of the complex result.
41+
return_phase: if True, return phase of the complex result.
42+
radial_bins: number of radial bins for frequency aggregation.
4443
If None, returns full 3D spectrum.
45-
max_frequency (float): maximum normalized frequency to include (0.0 to 1.0).
46-
spatial_dims (Union[int, Sequence[int]]): spatial dimensions to apply transform to.
44+
max_frequency: maximum normalized frequency to include (0.0 to 1.0).
45+
spatial_dims: spatial dimensions to apply transform to.
4746
Default is last three dimensions.
4847
4948
Returns:
5049
Radial Fourier transform of input data. Shape depends on parameters:
51-
- If radial_bins is None: same spatial shape as input; magnitude and phase
52-
(if both requested) are concatenated along the last dimension, doubling it.
53-
- If radial_bins is set: shape (..., radial_bins) or (..., 2*radial_bins) if both
54-
magnitude and phase are requested, preserving leading (batch/channel) dimensions.
50+
- If radial_bins is None and only magnitude OR phase is requested:
51+
same spatial shape as input (..., D, H, W)
52+
- If radial_bins is None and both magnitude AND phase are requested:
53+
shape (..., D, H, 2*W) [magnitude and phase concatenated along last dimension]
54+
- If radial_bins is set and only magnitude OR phase is requested:
55+
shape (..., radial_bins)
56+
- If radial_bins is set and both magnitude AND phase are requested:
57+
shape (..., 2*radial_bins)
5558
5659
Raises:
5760
ValueError: If max_frequency not in (0.0, 1.0], radial_bins < 1,
5861
or both return_magnitude and return_phase are False.
62+
63+
Example:
64+
>>> transform = RadialFourier3D(radial_bins=32, return_magnitude=True)
65+
>>> image = torch.randn(1, 128, 128, 96)
66+
>>> features = transform(image) # Shape: (1, 32)
67+
>>>
68+
>>> transform = RadialFourier3D(radial_bins=None, return_magnitude=True, return_phase=True)
69+
>>> image = torch.randn(1, 128, 128, 96)
70+
>>> spectrum = transform(image) # Shape: (1, 128, 128, 192) - magnitude+phase concatenated
5971
"""
6072

6173
def __init__(
@@ -80,9 +92,9 @@ def __init__(
8092

8193
# Validate parameters
8294
if not 0.0 < max_frequency <= 1.0:
83-
raise ValueError(f"max_frequency must be in (0.0, 1.0], got {max_frequency}")
95+
raise ValueError("max_frequency must be in (0.0, 1.0]")
8496
if radial_bins is not None and radial_bins < 1:
85-
raise ValueError(f"radial_bins must be >= 1, got {radial_bins}")
97+
raise ValueError("radial_bins must be >= 1")
8698
if not return_magnitude and not return_phase:
8799
raise ValueError("At least one of return_magnitude or return_phase must be True")
88100

@@ -132,11 +144,11 @@ def _compute_radial_spectrum(self, spectrum: torch.Tensor, radial_coords: torch.
132144
result_real = torch.zeros(self.radial_bins, dtype=spectrum.real.dtype, device=spectrum.device)
133145
result_imag = torch.zeros(self.radial_bins, dtype=spectrum.imag.dtype, device=spectrum.device)
134146

135-
# Bin the frequencies - spectrum and radial_coords are both 1D
147+
# Bin the frequencies using torch.bucketize
148+
bin_indices = torch.bucketize(radial_coords, bin_edges[1:-1], right=False)
136149
for i in range(self.radial_bins):
137-
mask = (radial_coords >= bin_edges[i]) & (radial_coords < bin_edges[i + 1])
150+
mask = bin_indices == i
138151
if mask.any():
139-
# spectrum is 1D, so we can index it directly
140152
result_real[i] = spectrum.real[mask].mean()
141153
result_imag[i] = spectrum.imag[mask].mean()
142154

@@ -154,14 +166,25 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
154166
where D, H, W are spatial dimensions.
155167
156168
Returns:
157-
Transformed data in radial frequency domain.
169+
Transformed data in radial frequency domain. Shape depends on parameters:
170+
- If radial_bins is None and only magnitude OR phase is requested:
171+
same spatial shape as input (..., D, H, W)
172+
- If radial_bins is None and both magnitude AND phase are requested:
173+
shape (..., D, H, 2*W) [magnitude and phase concatenated along last dimension]
174+
- If radial_bins is set and only magnitude OR phase is requested:
175+
shape (..., radial_bins)
176+
- If radial_bins is set and both magnitude AND phase are requested:
177+
shape (..., 2*radial_bins)
178+
179+
Raises:
180+
ValueError: If input does not have exactly 3 spatial dimensions.
158181
"""
159182
# Convert to tensor if needed
160183
img_tensor, *_ = convert_data_type(img, torch.Tensor)
161184
# Get spatial dimensions
162185
spatial_shape = tuple(img_tensor.shape[d] for d in self.spatial_dims)
163186
if len(spatial_shape) != 3:
164-
raise ValueError(f"Expected 3 spatial dimensions, got {len(spatial_shape)}")
187+
raise ValueError("Expected 3 spatial dimensions")
165188

166189
# Compute 3D FFT
167190
# Shift zero frequency to center and compute FFT
@@ -238,12 +261,18 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
238261
Inverse transform from radial frequency domain to spatial domain.
239262
240263
Args:
241-
radial_data: data in radial frequency domain.
264+
radial_data: data in radial frequency domain. When both magnitude and phase
265+
are requested with radial_bins=None, they should be concatenated along
266+
the last dimension (magnitude first, then phase).
242267
original_shape: original spatial shape (D, H, W).
243268
244269
Returns:
245270
Reconstructed spatial data.
246271
272+
Raises:
273+
ValueError: If input dimensions don't match expected shape for magnitude+phase concatenation.
274+
NotImplementedError: If radial_bins is not None.
275+
247276
Note:
248277
Only exact inverse is supported (radial_bins=None). Raises NotImplementedError otherwise.
249278
"""
@@ -258,9 +287,8 @@ def inverse(self, radial_data: NdarrayOrTensor, original_shape: tuple[int, ...])
258287
last_dim = radial_tensor.shape[-1]
259288
if last_dim != original_shape[-1] * 2:
260289
raise ValueError(
261-
f"For inverse with magnitude+phase and radial_bins=None, "
262-
f"expected last dimension to be doubled. "
263-
f"Got {last_dim}, expected {original_shape[-1] * 2}"
290+
"For inverse with magnitude+phase and radial_bins=None, "
291+
"expected last dimension to be doubled."
264292
)
265293

266294
split_size = original_shape[-1]
@@ -295,16 +323,26 @@ class RadialFourierFeatures3D(Transform):
295323
Args:
296324
n_bins_list: list of radial bin counts to compute.
297325
return_types: list of return types: 'magnitude', 'phase', or 'complex'.
298-
'complex' returns both magnitude and phase concatenated as real values.
326+
'complex' returns both magnitude and phase concatenated as real values
327+
along the last dimension (when radial_bins=None) or along the feature
328+
dimension (when radial_bins is set).
299329
normalize: if True, normalize the output.
300330
301331
Returns:
302-
Concatenated radial Fourier features.
332+
Concatenated radial Fourier features. Shape: (..., total_features) where
333+
total_features = sum(bins * (2 if return_type=='complex' else 1) for bins in n_bins_list).
334+
335+
Raises:
336+
ValueError: If n_bins_list or return_types is empty.
303337
304338
Example:
305339
>>> transform = RadialFourierFeatures3D(n_bins_list=[32, 64, 128])
306340
>>> image = torch.randn(1, 128, 128, 96)
307341
>>> features = transform(image) # Shape: (1, 32+64+128=224)
342+
>>>
343+
>>> transform = RadialFourierFeatures3D(n_bins_list=[16, 32], return_types=['complex'])
344+
>>> image = torch.randn(1, 128, 128, 96)
345+
>>> features = transform(image) # Shape: (1, (16+32)*2=96)
308346
"""
309347

310348
def __init__(
@@ -343,17 +381,14 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
343381
feat = transform(img)
344382
features.append(feat)
345383

346-
# Convert all features to tensors if any are numpy arrays
384+
# Convert all features to tensors using convert_data_type
347385
features_tensors = []
348386
for feat in features:
349-
if isinstance(feat, np.ndarray):
350-
features_tensors.append(torch.from_numpy(feat))
351-
else:
352-
features_tensors.append(feat)
387+
feat_tensor, *_ = convert_data_type(feat, torch.Tensor)
388+
features_tensors.append(feat_tensor)
353389
output = torch.cat(features_tensors, dim=-1)
354390

355-
# Convert to original type if needed
356-
if isinstance(img, np.ndarray):
357-
output = output.cpu().numpy()
391+
# Convert back to original type
392+
output, *_ = convert_data_type(output, type(img))
358393

359394
return output

0 commit comments

Comments
 (0)