@@ -49,7 +49,7 @@ class DWI(BaseDataset):
4949 be unwarped.
5050 """
5151 gradients = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
52- """A 2D numpy array of the gradient table in RAS+B format (Nx4 )."""
52+ """A 2D numpy array of the gradient table (4xN )."""
5353 eddy_xfms = attr .ib (default = None )
5454 """List of transforms to correct for estimatted eddy current distortions."""
5555
@@ -73,12 +73,12 @@ def __getitem__(
7373 The corresponding per-volume motion affine(s) or `None` if identity transform(s).
7474 gradient : np.ndarray
7575 The corresponding gradient(s), which may have shape ``(4,)`` if a single volume
76- or ``(k, 4 )`` if multiple volumes, or None if gradients are not available.
76+ or ``(4, k )`` if multiple volumes, or None if gradients are not available.
7777
7878 """
7979
8080 data , affine = super ().__getitem__ (idx )
81- return data , affine , self .gradients [idx , ...]
81+ return data , affine , self .gradients [..., idx ]
8282
8383 def set_transform (self , index : int , affine : np .ndarray , order : int = 3 ) -> None :
8484 """
@@ -106,14 +106,14 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
106106 shape = self .dataobj .shape [:3 ], affine = self .affine
107107 )
108108 xform = Affine (matrix = affine , reference = reference )
109- bvec = self .gradients [index , : 3 ]
109+ bvec = self .gradients [: 3 , index ]
110110
111111 # invert transform transform b-vector and origin
112112 r_bvec = (~ xform ).map ([bvec , (0.0 , 0.0 , 0.0 )])
113113 # Reset b-vector's origin
114114 new_bvec = r_bvec [1 ] - r_bvec [0 ]
115115 # Normalize and update
116- self .gradients [index , : 3 ] = new_bvec / np .linalg .norm (new_bvec )
116+ self .gradients [: 3 , index ] = new_bvec / np .linalg .norm (new_bvec )
117117
118118 super ().set_transform (index , affine , order )
119119
@@ -172,7 +172,7 @@ def to_filename(
172172 with h5py .File (filename , "r+" ) as out_file :
173173 out_file .attrs ["Type" ] = "dmri"
174174
175- def to_nifti (self , filename : Path | str ) -> None :
175+ def to_nifti (self , filename : Path | str , insert_b0 : bool = False ) -> None :
176176 """
177177 Write a NIfTI 1.0 file to disk, and also write out the gradient table
178178 to sidecar text files (.bvec, .bval).
@@ -183,8 +183,15 @@ def to_nifti(self, filename: Path | str) -> None:
183183 The output NIfTI file path.
184184
185185 """
186- # First call the parent's to_nifti to handle the primary NIfTI export.
187- super ().to_nifti (filename )
186+ if not insert_b0 :
187+ # Parent's to_nifti to handle the primary NIfTI export.
188+ super ().to_nifti (filename )
189+ else :
190+ data = np .concatenate ((self .bzero [..., np .newaxis ], self .dataobj ), axis = - 1 )
191+ nii = nb .Nifti1Image (data , self .affine , self .datahdr )
192+ if self .datahdr is None :
193+ nii .header .set_xyzt_units ("mm" )
194+ nii .to_filename (filename )
188195
189196 # Convert filename to a Path object.
190197 out_root = Path (filename ).absolute ()
@@ -202,8 +209,8 @@ def to_nifti(self, filename: Path | str) -> None:
202209
203210 # Save bvecs and bvals to text files
204211 # Each row of bvecs is one direction (3 rows, N columns).
205- np .savetxt (bvecs_file , self .gradients [..., : 3 ].T , fmt = "%.6f" )
206- np .savetxt (bvals_file , self .gradients [..., - 1 ], fmt = "%.6f" )
212+ np .savetxt (bvecs_file , self .gradients [: 3 , ...].T , fmt = "%.6f" )
213+ np .savetxt (bvals_file , self .gradients [: 3 , ...], fmt = "%.6f" )
207214
208215
209216def load (
@@ -297,7 +304,7 @@ def load(
297304 # We'll assign the filtered gradients below.
298305 )
299306
300- dwi_obj .gradients = grad [:, gradmsk ] if grad .shape [0 ] == 4 else grad [gradmsk , :]
307+ dwi_obj .gradients = grad [:, gradmsk ] if grad .shape [0 ] == 4 else grad [gradmsk , :]. T
301308
302309 # 6) b=0 volume (bzero)
303310 # If the user provided a b0_file, load it
0 commit comments