@@ -71,12 +71,43 @@ class DWI(BaseDataset[np.ndarray]):
7171 bzero : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
7272 """A *b=0* reference map, preferably obtained by some smart averaging."""
7373 gradients : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
74- """A 2D numpy array of the gradient table (4xN )."""
74+ """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components )."""
7575 eddy_xfms : list = attrs .field (default = None )
7676 """List of transforms to correct for estimated eddy current distortions."""
7777
78+ def __attrs_post_init__ (self ) -> None :
79+ self ._normalize_gradients ()
80+
81+ def _normalize_gradients (self ) -> None :
82+ if self .gradients is None :
83+ return
84+
85+ gradients = np .asarray (self .gradients )
86+ if gradients .ndim != 2 :
87+ raise ValueError ("Gradient table must be a 2D array" )
88+
89+ n_volumes = None
90+ if self .dataobj is not None :
91+ try :
92+ n_volumes = self .dataobj .shape [- 1 ]
93+ except Exception : # pragma: no cover - extremely defensive
94+ n_volumes = None
95+
96+ if n_volumes is not None and gradients .shape [0 ] != n_volumes :
97+ if gradients .shape [1 ] == n_volumes :
98+ gradients = gradients .T
99+ else :
100+ raise ValueError (
101+ "Gradient table shape does not match the number of diffusion volumes: "
102+ f"expected { n_volumes } rows, found { gradients .shape [0 ]} "
103+ )
104+ elif n_volumes is None and gradients .shape [1 ] > gradients .shape [0 ]:
105+ gradients = gradients .T
106+
107+ self .gradients = gradients
108+
78109 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
79- return (self .gradients [..., idx ],)
110+ return (self .gradients [idx , ...],)
80111
81112 # For the sake of the docstring
82113 def __getitem__ (
@@ -99,8 +130,8 @@ def __getitem__(
99130 motion_affine : :obj:`~numpy.ndarray` or ``None``
100131 The corresponding per-volume motion affine(s) or ``None`` if identity transform(s).
101132 gradient : :obj:`~numpy.ndarray`
102- The corresponding gradient(s), which may have shape ``(4 ,)`` if a single volume
103- or ``(4, k )`` if multiple volumes, or ``None`` if gradients are not available.
133+ The corresponding gradient(s), which may have shape ``(C ,)`` if a single volume
134+ or ``(k, C )`` if multiple volumes, or ``None`` if gradients are not available.
104135
105136 """
106137
@@ -126,11 +157,11 @@ def from_filename(cls, filename: Path | str) -> Self:
126157
127158 @property
128159 def bvals (self ):
129- return self .gradients [- 1 , ...]
160+ return self .gradients [..., - 1 ]
130161
131162 @property
132163 def bvecs (self ):
133- return self .gradients [: - 1 , ...]
164+ return self .gradients [..., : - 1 ]
134165
135166 def get_shells (
136167 self ,
@@ -160,14 +191,12 @@ def get_shells(
160191 """
161192
162193 _ , bval_groups , bval_estimated = find_shelling_scheme (
163- self .gradients [ - 1 , ...] ,
194+ self .bvals ,
164195 num_bins = num_bins ,
165196 multishell_nonempty_bin_count_thr = multishell_nonempty_bin_count_thr ,
166197 bval_cap = bval_cap ,
167198 )
168- indices = [
169- np .hstack (np .where (np .isin (self .gradients [- 1 , ...], bvals ))) for bvals in bval_groups
170- ]
199+ indices = [np .where (np .isin (self .bvals , bvals ))[0 ] for bvals in bval_groups ]
171200 return list (zip (bval_estimated , indices , strict = True ))
172201
173202 def to_filename (
@@ -232,16 +261,14 @@ def to_nifti(
232261 bvals = self .bvals
233262
234263 # Rotate b-vectors if self.motion_affines is not None
235- bvecs = (
236- np .array (
237- [
238- transform_fsl_bvec (bvec , affine , self .affine , invert = True )
239- for bvec , affine in zip (self .gradients .T , self .motion_affines , strict = True )
240- ]
241- ).T
242- if self .motion_affines is not None
243- else self .bvecs
244- )
264+ if self .motion_affines is not None :
265+ rotated = [
266+ transform_fsl_bvec (bvec , affine , self .affine , invert = True )
267+ for bvec , affine in zip (self .gradients [:, :3 ], self .motion_affines , strict = True )
268+ ]
269+ bvecs = np .asarray (rotated )
270+ else :
271+ bvecs = self .bvecs
245272
246273 # Parent's to_nifti to handle the primary NIfTI export.
247274 nii = super ().to_nifti (
@@ -266,7 +293,7 @@ def to_nifti(
266293 # If inserting a b0 volume is requested, add the corresponding null
267294 # gradient value to the bval/bvec pair
268295 bvals = np .concatenate ((np .zeros (1 ), bvals ))
269- bvecs = np .concatenate ((np .zeros (3 )[:, np . newaxis ] , bvecs ), axis = - 1 )
296+ bvecs = np .vstack ((np .zeros (( 1 , bvecs . shape [ 1 ])) , bvecs ))
270297
271298 if filename is not None :
272299 # Convert filename to a Path object.
@@ -279,9 +306,8 @@ def to_nifti(
279306 bvecs_file = out_root .with_suffix (".bvec" )
280307 bvals_file = out_root .with_suffix (".bval" )
281308
282- # Save bvecs and bvals to text files
283- # Each row of bvecs is one direction (3 rows, N columns).
284- np .savetxt (bvecs_file , bvecs , fmt = f"%.{ bvecs_dec_places } f" )
309+ # Save bvecs and bvals to text files. BIDS expects 3 rows x N columns.
310+ np .savetxt (bvecs_file , bvecs .T , fmt = f"%.{ bvecs_dec_places } f" )
285311 np .savetxt (bvals_file , bvals [np .newaxis , :], fmt = f"%.{ bvals_dec_places } f" )
286312
287313 return nii
@@ -313,10 +339,12 @@ def from_nii(
313339 motion_file : :obj:`os.pathlike`, optional
314340 A file containing head motion affine matrices (linear)
315341 gradients_file : :obj:`os.pathlike`, optional
316- A text file containing the gradients table, shape (4, N) or (N, 4).
317- If provided, it supersedes any .bvec / .bval combination.
342+ A text file containing the gradients table, shape (N, C) where the last column
343+ stores the b-values. If provided following the column-major convention(C, N),
344+ it will be transposed automatically. If provided, it supersedes any .bvec / .bval
345+ combination.
318346 bvec_file : :obj:`os.pathlike`, optional
319- A text file containing b-vectors, shape (3, N).
347+ A text file containing b-vectors, shape (N, 3) or ( 3, N).
320348 bval_file : :obj:`os.pathlike`, optional
321349 A text file containing b-values, shape (N,).
322350 b0_file : :obj:`os.pathlike`, optional
@@ -359,31 +387,48 @@ def from_nii(
359387 stacklevel = 2 ,
360388 )
361389 elif bvec_file and bval_file :
362- bvecs = np .loadtxt (bvec_file , dtype = "float32" ) # shape (3, N)
363- if bvecs .shape [0 ] != 3 and bvecs .shape [1 ] == 3 :
390+ bvecs = np .loadtxt (bvec_file , dtype = "float32" )
391+ if bvecs .ndim == 1 :
392+ bvecs = bvecs [np .newaxis , :]
393+ if bvecs .shape [1 ] != 3 and bvecs .shape [0 ] == 3 :
364394 bvecs = bvecs .T
365395
366- bvals = np .loadtxt (bval_file , dtype = "float32" ) # shape (N,)
367- # Stack to shape (4, N)
368- grad = np .vstack ((bvecs , bvals ))
396+ bvals = np .loadtxt (bval_file , dtype = "float32" )
397+ if bvals .ndim > 1 :
398+ bvals = np .squeeze (bvals )
399+ grad = np .column_stack ((bvecs , bvals ))
369400 else :
370401 raise RuntimeError (
371402 "No gradient data provided. "
372403 "Please specify either a gradients_file or (bvec_file & bval_file)."
373404 )
374405
406+ if grad .ndim == 1 :
407+ grad = grad [np .newaxis , :]
408+
409+ if grad .shape [1 ] < 2 :
410+ raise ValueError ("Gradient table must have at least two columns (direction + b-value)." )
411+
412+ if grad .shape [1 ] != 4 :
413+ if grad .shape [0 ] == 4 :
414+ grad = grad .T
415+ else :
416+ raise ValueError (
417+ "Gradient table must have four columns (3 direction components and one b-value)."
418+ )
419+
375420 # 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
376421 # as "DW volumes" if the user wants to store only the high-b volumes here
377- gradmsk = ( grad [- 1 ] if grad . shape [ 0 ] == 4 else grad [ :, - 1 ]) > b0_thres
422+ gradmsk = grad [:, - 1 ] > b0_thres
378423
379- # The shape checking is somewhat flexible: (4, N) or (N, 4)
380424 dwi_obj = DWI (
381425 dataobj = fulldata [..., gradmsk ],
382426 affine = img .affine ,
383427 # We'll assign the filtered gradients below.
384428 )
385429
386- dwi_obj .gradients = grad [:, gradmsk ] if grad .shape [0 ] == 4 else grad [gradmsk , :].T
430+ dwi_obj .gradients = grad [gradmsk , :]
431+ dwi_obj ._normalize_gradients ()
387432
388433 # 4) b=0 volume (bzero)
389434 # If the user provided a b0_file, load it
0 commit comments