@@ -50,7 +50,13 @@ def list_backends(available_only=False):
5050
5151
5252def get_operator (
53- backend_name : str , wrt_data : bool = False , wrt_traj : bool = False , * args , ** kwargs
53+ backend_name : str ,
54+ wrt_data : bool = False ,
55+ wrt_traj : bool = False ,
56+ use_batched_mode : bool = False ,
57+ batch_size : int = 1 ,
58+ * args ,
59+ ** kwargs ,
5460):
5561 """Return an MRI Fourier operator interface using the correct backend.
5662
@@ -62,6 +68,10 @@ def get_operator(
6268 if set gradients wrt to data and images will be available.
6369 wrt_traj: bool, default False
6470 if set gradients wrt to trajectory will be available.
71+ use_batched_mode : bool, optional
72+ If True, uses a batched version of the NUFFT operator that supports varying data/smaps pairs.
73+ batch_size : int, optional
74+ Batch size to be used in batched mode. Only relevant if `use_batched_mode=True`. Default is 1.
6575 *args, **kwargs:
6676 Arguments to pass to the operator constructor.
6777
@@ -97,10 +107,14 @@ class or instance of class if args or kwargs are given.
97107 # if autograd:
98108 if wrt_data or wrt_traj :
99109 if isinstance (operator , FourierOperatorBase ):
100- operator = operator .make_autograd (wrt_data , wrt_traj )
110+ operator = operator .make_autograd (
111+ wrt_data , wrt_traj , use_batched_mode , batch_size
112+ )
101113 else :
102114 # instance will be created later
103- operator = partial (operator .with_autograd , wrt_data , wrt_traj )
115+ operator = partial (
116+ operator .with_autograd , wrt_data , wrt_traj , use_batched_mode , batch_size
117+ )
104118
105119 return operator
106120
@@ -257,7 +271,9 @@ def compute_smaps(self, method: NDArray | Callable | str | dict | None = None):
257271 ** kwargs ,
258272 )
259273
260- def make_autograd (self , wrt_data = True , wrt_traj = False ):
274+ def make_autograd (
275+ self , wrt_data = True , wrt_traj = False , use_batched_mode = False , batch_size = 1
276+ ):
261277 """Make a new Operator with autodiff support.
262278
263279 Parameters
@@ -271,6 +287,12 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
271287 wrt_traj : bool, optional
272288 If the gradient with respect to the trajectory is computed, default is false
273289
290+ use_batched_mode : bool, optional
291+ If True, uses a batched version of the NUFFT operator that supports varying smaps
292+
293+ batch_size : int, optional
294+ Batch size to be used in batched mode. Only relevant if `use_batched_mode=True`. Default is 1.
295+
274296 Returns
275297 -------
276298 torch.nn.module
@@ -286,9 +308,20 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
286308 if not self .autograd_available :
287309 raise ValueError ("Backend does not support auto-differentiation." )
288310
289- from mrinufft .operators .autodiff import MRINufftAutoGrad
311+ if use_batched_mode :
312+ if batch_size < 1 :
313+ raise ValueError (
314+ "Provide a valid batch size." f"Batch size : { batch_size } "
315+ )
316+ from mrinufft .operators .autodiff import BatchedNufftAutoGrad
317+
318+ return BatchedNufftAutoGrad (
319+ self , wrt_data = wrt_data , wrt_traj = wrt_traj , batch_size = batch_size
320+ )
321+ else :
322+ from mrinufft .operators .autodiff import MRINufftAutoGrad
290323
291- return MRINufftAutoGrad (self , wrt_data = wrt_data , wrt_traj = wrt_traj )
324+ return MRINufftAutoGrad (self , wrt_data = wrt_data , wrt_traj = wrt_traj )
292325
293326 def compute_density (self , method = None ):
294327 """Compute the density compensation weights and set it.
@@ -476,9 +509,19 @@ def __repr__(self):
476509 )
477510
478511 @classmethod
479- def with_autograd (cls , wrt_data = True , wrt_traj = False , * args , ** kwargs ):
512+ def with_autograd (
513+ cls ,
514+ wrt_data = True ,
515+ wrt_traj = False ,
516+ use_batched_mode = False ,
517+ batch_size = 1 ,
518+ * args ,
519+ ** kwargs ,
520+ ):
480521 """Return a Fourier operator with autograd capabilities."""
481- return cls (* args , ** kwargs ).make_autograd (wrt_data , wrt_traj )
522+ return cls (* args , ** kwargs ).make_autograd (
523+ wrt_data , wrt_traj , use_batched_mode , batch_size
524+ )
482525
483526
484527class FourierOperatorCPU (FourierOperatorBase ):
0 commit comments