Skip to content

Commit 7e3a9da

Browse files
author
Asma TANABENE
committed
Fixes #258: Enable NUFFT operator to support batched sensitivity maps for batched inputs
1 parent 7213460 commit 7e3a9da

File tree

5 files changed

+293
-52
lines changed

5 files changed

+293
-52
lines changed

src/mrinufft/operators/autodiff.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,94 @@ def samples(self, value):
152152
def __getattr__(self, name):
153153
"""Forward all other attributes to the nufft_op."""
154154
return getattr(self.nufft_op, name)
155+
156+
157+
class BatchedNufftAutoGrad(MRINufftAutoGrad):
158+
"""
159+
A batched wrapper for NUFFT operator with support for autodifferentiation
160+
and varying sensitivity maps (smaps) across batches.
161+
162+
Parameters
163+
----------
164+
nufft : object
165+
An instance of a standard NUFFT operator.
166+
batch_size : int
167+
Number of batches to process simultaneously.
168+
**kwargs : dict, optional
169+
Additional arguments.
170+
171+
Notes
172+
-----
173+
#TODO Future improvements may include support for varying trajectories across batches.
174+
"""
175+
176+
def __init__(self, nufft_op, wrt_data=True, wrt_traj=False, batch_size=1, **kwargs):
177+
super().__init__(
178+
nufft_op=nufft_op, wrt_data=wrt_data, wrt_traj=wrt_traj
179+
)
180+
self.batch_size = batch_size
181+
182+
def op(self, batched_smaps, batched_imgs):
183+
"""Compute the forward batched_imgs -> batched_kspace."""
184+
self._check_input_shape(imgs=batched_imgs)
185+
self._check_input_shape(imgs=batched_smaps)
186+
batched_kspace = []
187+
for i in range(self.batch_size):
188+
try:
189+
# update smaps for proper backward computation
190+
self.nufft_op.smaps = batched_smaps[i]
191+
batched_kspace.append(
192+
_NUFFT_OP.apply(batched_imgs[i], self.samples, self.nufft_op)
193+
)
194+
except Exception as e:
195+
raise RuntimeError(
196+
f"Failed at batch index {i+1}: {e}"
197+
) # For an easier debugging
198+
return torch.stack(batched_kspace, dim=0)
199+
200+
def adj_op(self, batched_smaps, batched_kspace):
201+
"""Compute the adjoint batched_kspace -> batched_imgs."""
202+
self._check_input_shape(ksps=batched_kspace)
203+
self._check_input_shape(imgs=batched_smaps)
204+
batched_imgs = []
205+
for i in range(self.batch_size):
206+
try:
207+
self.nufft_op.smaps = batched_smaps[i]
208+
batched_imgs.append(
209+
_NUFFT_ADJOP.apply(batched_kspace[i], self.samples, self.nufft_op)
210+
)
211+
except Exception as e:
212+
raise RuntimeError(f"Failed at batch index {i+1}: {e}")
213+
return torch.stack(batched_imgs, dim=0)
214+
215+
def _check_input_shape(self, *, imgs=None, ksps=None):
216+
"""
217+
Validates the batch size of either image or k-space input against the expected batch size.
218+
219+
Parameters
220+
----------
221+
imgs : np.ndarray, optional
222+
Image data array. If provided, its batch dimension will be validated.
223+
224+
ksps : np.ndarray or object, optional
225+
K-space data array or compatible object. If provided, its batch dimension will be validated.
226+
227+
Raises
228+
------
229+
ValueError
230+
If the batch size of the image or k-space input does not match the expected batch size.
231+
"""
232+
if imgs is not None:
233+
if imgs.shape[0] != self.batch_size:
234+
raise ValueError(
235+
f"Image batch size mismatch: got {imgs.shape[0]}, expected {self.batch_size}. "
236+
f"Image shape: {imgs.shape}"
237+
)
238+
if ksps is not None:
239+
if ksps.shape[0] != self.batch_size:
240+
raise ValueError(
241+
f"K-space batch size mismatch: got {ksps.shape[0]}, expected {self.batch_size}. "
242+
f"K-space shape: {ksps.shape}"
243+
)
244+
if imgs is None and ksps is None:
245+
raise ValueError("Provide either `imgs` or `ksps` as input.")

src/mrinufft/operators/base.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,13 @@ def list_backends(available_only=False):
5050

5151

5252
def get_operator(
53-
backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs
53+
backend_name: str,
54+
wrt_data: bool = False,
55+
wrt_traj: bool = False,
56+
use_batched_mode: bool = False,
57+
batch_size: int = 1,
58+
*args,
59+
**kwargs,
5460
):
5561
"""Return an MRI Fourier operator interface using the correct backend.
5662
@@ -62,6 +68,10 @@ def get_operator(
6268
if set gradients wrt to data and images will be available.
6369
wrt_traj: bool, default False
6470
if set gradients wrt to trajectory will be available.
71+
use_batched_mode : bool, optional
72+
If True, uses a batched version of the NUFFT operator that supports varying data/smaps pairs.
73+
batch_size : int, optional
74+
Batch size to be used in batched mode. Only relevant if `use_batched_mode=True`. Default is 1.
6575
*args, **kwargs:
6676
Arguments to pass to the operator constructor.
6777
@@ -97,10 +107,14 @@ class or instance of class if args or kwargs are given.
97107
# if autograd:
98108
if wrt_data or wrt_traj:
99109
if isinstance(operator, FourierOperatorBase):
100-
operator = operator.make_autograd(wrt_data, wrt_traj)
110+
operator = operator.make_autograd(
111+
wrt_data, wrt_traj, use_batched_mode, batch_size
112+
)
101113
else:
102114
# instance will be created later
103-
operator = partial(operator.with_autograd, wrt_data, wrt_traj)
115+
operator = partial(
116+
operator.with_autograd, wrt_data, wrt_traj, use_batched_mode, batch_size
117+
)
104118

105119
return operator
106120

@@ -257,7 +271,9 @@ def compute_smaps(self, method: NDArray | Callable | str | dict | None = None):
257271
**kwargs,
258272
)
259273

260-
def make_autograd(self, wrt_data=True, wrt_traj=False):
274+
def make_autograd(
275+
self, wrt_data=True, wrt_traj=False, use_batched_mode=False, batch_size=1
276+
):
261277
"""Make a new Operator with autodiff support.
262278
263279
Parameters
@@ -271,6 +287,12 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
271287
wrt_traj : bool, optional
272288
If the gradient with respect to the trajectory is computed, default is false
273289
290+
use_batched_mode : bool, optional
291+
If True, uses a batched version of the NUFFT operator that supports varying smaps
292+
293+
batch_size : int, optional
294+
Batch size to be used in batched mode. Only relevant if `use_batched_mode=True`. Default is 1.
295+
274296
Returns
275297
-------
276298
torch.nn.module
@@ -286,9 +308,20 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
286308
if not self.autograd_available:
287309
raise ValueError("Backend does not support auto-differentiation.")
288310

289-
from mrinufft.operators.autodiff import MRINufftAutoGrad
311+
if use_batched_mode:
312+
if batch_size < 1:
313+
raise ValueError(
314+
"Provide a valid batch size." f"Batch size : {batch_size}"
315+
)
316+
from mrinufft.operators.autodiff import BatchedNufftAutoGrad
317+
318+
return BatchedNufftAutoGrad(
319+
self, wrt_data=wrt_data, wrt_traj=wrt_traj, batch_size=batch_size
320+
)
321+
else:
322+
from mrinufft.operators.autodiff import MRINufftAutoGrad
290323

291-
return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)
324+
return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)
292325

293326
def compute_density(self, method=None):
294327
"""Compute the density compensation weights and set it.
@@ -476,9 +509,19 @@ def __repr__(self):
476509
)
477510

478511
@classmethod
479-
def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs):
512+
def with_autograd(
513+
cls,
514+
wrt_data=True,
515+
wrt_traj=False,
516+
use_batched_mode=False,
517+
batch_size=1,
518+
*args,
519+
**kwargs,
520+
):
480521
"""Return a Fourier operator with autograd capabilities."""
481-
return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj)
522+
return cls(*args, **kwargs).make_autograd(
523+
wrt_data, wrt_traj, use_batched_mode, batch_size
524+
)
482525

483526

484527
class FourierOperatorCPU(FourierOperatorBase):

tests/helpers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .factories import (
55
kspace_from_op,
66
image_from_op,
7+
batchedSmpas_from_op,
78
to_interface,
89
from_interface,
910
CUPY_AVAILABLE,
@@ -21,4 +22,5 @@
2122
"CUPY_AVAILABLE",
2223
"TORCH_AVAILABLE",
2324
"param_array_interface",
25+
"batchedSmpas_from_op",
2426
]

tests/helpers/factories.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,42 @@
1919

2020
def image_from_op(operator):
2121
"""Generate a random image."""
22+
batch_dim = (operator.batch_size,) if hasattr(operator, "batch_size") else ()
2223
if operator.smaps is None:
23-
img = np.random.randn(operator.n_coils, *operator.shape).astype(
24+
img = np.random.randn(*batch_dim, operator.n_coils, *operator.shape).astype(
2425
operator.cpx_dtype
2526
)
2627
elif operator.smaps is not None and operator.n_coils > 1:
27-
img = np.random.randn(*operator.shape).astype(operator.cpx_dtype)
28+
img = np.random.randn(*batch_dim, *operator.shape).astype(operator.cpx_dtype)
2829

2930
img += 1j * np.random.randn(*img.shape).astype(operator.cpx_dtype)
3031
return img
3132

3233

3334
def kspace_from_op(operator):
3435
"""Generate a random kspace data."""
35-
kspace = (1j * np.random.randn(operator.n_coils, operator.n_samples)).astype(
36-
operator.cpx_dtype
37-
)
38-
kspace += np.random.randn(operator.n_coils, operator.n_samples).astype(
36+
batch_dim = (operator.batch_size,) if hasattr(operator, "batch_size") else ()
37+
kspace = (
38+
1j * np.random.randn(*batch_dim, operator.n_coils, operator.n_samples)
39+
).astype(operator.cpx_dtype)
40+
kspace += np.random.randn(*batch_dim, operator.n_coils, operator.n_samples).astype(
3941
operator.cpx_dtype
4042
)
4143
return kspace
4244

4345

46+
def batchedSmpas_from_op(operator):
47+
"""Generate random batched smaps."""
48+
smaps = 1j * np.random.randn(
49+
operator.batch_size, operator.n_coils, *operator.shape
50+
).astype(np.complex64)
51+
smaps += np.random.randn(
52+
operator.batch_size, operator.n_coils, *operator.shape
53+
).astype(np.complex64)
54+
smaps /= np.linalg.norm(smaps, axis=1, keepdims=True)
55+
return smaps
56+
57+
4458
def to_interface(data, interface):
4559
"""Make DATA an array from INTERFACE."""
4660
if interface == "cupy":

0 commit comments

Comments
 (0)