6363DTI_MIN_ORIENTATIONS = 6
6464"""Minimum number of nonzero b-values in a DWI dataset."""
6565
66- GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
66+ GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR = """\
67+ Gradient table shape does not match the number of diffusion volumes: \
68+ expected {n_volumes} rows, found {n_gradients}."""
6769"""dMRI volume count vs. gradient count mismatch error message."""
6870
69- GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
71+ GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = """\
72+ Both a gradients table file and b-vec/val files are defined; \
73+ ignoring b-vec/val files in favor of the gradients_file."""
7074""""dMRI gradient file priority warning message."""
7175
7276GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
7377"""dMRI gradient dimensionality error message."""
7478
75- GRADIENT_DATA_MISSING_ERROR = (
76- "No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77- )
79+ GRADIENT_DATA_MISSING_ERROR = "No gradient data provided."
7880"""dMRI missing gradient data error message."""
7981
8082GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
8385"""dMRI gradient expected columns error message."""
8486
8587
88+ def format_gradients (
89+ value : npt .ArrayLike | None ,
90+ ) -> np .ndarray | None :
91+ """
92+ Validate and orient gradient tables to a consistent shape.
93+
94+ Examples
95+ --------
96+ Passing an already well-formed table returns the data unchanged::
97+
98+ >>> format_gradients(
99+ ... [
100+ ... [1, 0, 0, 0],
101+ ... [0, 1, 0, 1000],
102+ ... [0, 0, 1, 2000],
103+ ... [0, 0, 0, 0],
104+ ... [0, 0, 0, 1000],
105+ ... ]
106+ ... )
107+ array([[ 1, 0, 0, 0],
108+ [ 0, 1, 0, 1000],
109+ [ 0, 0, 1, 2000],
110+ [ 0, 0, 0, 0],
111+ [ 0, 0, 0, 1000]])
112+
113+ Column-major inputs are automatically transposed when an expected
114+ number of diffusion volumes is provided::
115+
116+ >>> format_gradients(
117+ ... [[1, 0], [0, 1], [0, 0], [1000, 2000]],
118+ ... )
119+ array([[ 1, 0, 0, 1000],
120+ [ 0, 1, 0, 2000]])
121+
122+ Gradient tables must always have two dimensions::
123+
124+ >>> format_gradients([0, 1, 0, 1000])
125+ Traceback (most recent call last):
126+ ...
127+ ValueError: Gradient table must be a 2D array
128+
129+ """
130+
131+ formatted = np .asarray (value )
132+ if formatted .ndim != 2 :
133+ raise ValueError (GRADIENT_NDIM_ERROR_MSG )
134+
135+ # Transpose if column-major
136+ return formatted .T if formatted .shape [0 ] == 4 and formatted .shape [1 ] != 4 else formatted
137+
138+
139+ def validate_gradients (
140+ instance : DWI ,
141+ attribute : attrs .Attribute ,
142+ value : npt .NDArray [np .floating ],
143+ ) -> None :
144+ """Ensure row-major convention for gradient table."""
145+ if value .shape [1 ] != 4 :
146+ raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
147+
148+
86149@attrs .define (slots = True )
87150class DWI (BaseDataset [np .ndarray ]):
88151 """Data representation structure for dMRI data."""
89152
153+ gradients : np .ndarray = attrs .field (
154+ default = None ,
155+ repr = _data_repr ,
156+ eq = attrs .cmp_using (eq = _cmp ),
157+ converter = format_gradients ,
158+ validator = validate_gradients ,
159+ )
160+ """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
90161 bzero : np .ndarray | None = attrs .field (
91162 default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
92163 )
93- """A *b=0* reference map, preferably obtained by some smart averaging."""
94- gradients : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
95- """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
164+ """A *b=0* reference map, computed automatically when low-b frames are present."""
96165 eddy_xfms : list = attrs .field (default = None )
97166 """List of transforms to correct for estimated eddy current distortions."""
98167
99168 def __attrs_post_init__ (self ) -> None :
100- self ._normalize_gradients ()
101-
102- def _normalize_gradients (self ) -> None :
103169 if self .gradients is None :
104- return
105-
106- gradients = np .asarray (self .gradients )
107- if gradients .ndim != 2 :
108- raise ValueError (GRADIENT_NDIM_ERROR_MSG )
109- if gradients .shape [1 ] == 4 :
110- pass
111- elif gradients .shape [0 ] == 4 :
112- gradients = gradients .T
113- else :
114- raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
115-
116- n_volumes = None
117- if self .dataobj is not None :
118- try :
119- n_volumes = self .dataobj .shape [- 1 ]
120- except Exception : # pragma: no cover - extremely defensive
121- n_volumes = None
170+ raise ValueError (GRADIENT_DATA_MISSING_ERROR )
122171
123- if n_volumes is not None and gradients . shape [0 ] != n_volumes :
172+ if self . dataobj . shape [- 1 ] != self . gradients . shape [ 0 ] :
124173 raise ValueError (
125174 GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR .format (
126- n_volumes = n_volumes , n_gradients = gradients .shape [0 ]
175+ n_volumes = self .dataobj .shape [- 1 ],
176+ n_gradients = self .gradients .shape [0 ],
127177 )
128178 )
129179
130- self .gradients = gradients
180+ b0_mask = self .gradients [:, - 1 ] <= DEFAULT_LOWB_THRESHOLD
181+ b0_num = np .sum (b0_mask )
182+
183+ if b0_num > 0 and self .bzero is None :
184+ bzeros = self .dataobj [..., b0_mask ]
185+ self .bzero = bzeros if bzeros .ndim == 3 else np .median (bzeros , axis = - 1 )
186+
187+ if b0_num > 0 :
188+ # Remove b0 volumes from dataobj and gradients
189+ self .gradients = self .gradients [~ b0_mask , :]
190+ self .dataobj = self .dataobj [..., ~ b0_mask ]
191+
192+ if self .gradients .shape [0 ] < DTI_MIN_ORIENTATIONS :
193+ raise ValueError (
194+ f"DWI datasets must have at least { DTI_MIN_ORIENTATIONS } diffusion-weighted "
195+ f"orientations; found { self .dataobj .shape [- 1 ]} ."
196+ )
131197
132198 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
133199 return (self .gradients [idx , ...],)
@@ -339,12 +405,10 @@ def to_nifti(
339405def from_nii (
340406 filename : Path | str ,
341407 brainmask_file : Path | str | None = None ,
342- motion_file : Path | str | None = None ,
343408 gradients_file : Path | str | None = None ,
344409 bvec_file : Path | str | None = None ,
345410 bval_file : Path | str | None = None ,
346411 b0_file : Path | str | None = None ,
347- b0_thres : float = DEFAULT_LOWB_THRESHOLD ,
348412) -> DWI :
349413 """
350414 Load DWI data from NIfTI and construct a DWI object.
@@ -359,8 +423,6 @@ def from_nii(
359423 brainmask_file : :obj:`os.pathlike`, optional
360424 A brainmask NIfTI file. If provided, will be loaded and
361425 stored in the returned dataset.
362- motion_file : :obj:`os.pathlike`, optional
363- A file containing head motion affine matrices (linear)
364426 gradients_file : :obj:`os.pathlike`, optional
365427 A text file containing the gradients table, shape (N, C) where the last column
366428 stores the b-values. If provided following the column-major convention(C, N),
@@ -373,9 +435,6 @@ def from_nii(
373435 b0_file : :obj:`os.pathlike`, optional
374436 A NIfTI file containing a b=0 volume (possibly averaged or reference).
375437 If not provided, and the data contains at least one b=0 volume, one will be computed.
376- b0_thres : float, optional
377- Threshold for determining which volumes are considered DWI vs. b=0
378- if you combine them in the same file.
379438
380439 Returns
381440 -------
@@ -390,10 +449,6 @@ def from_nii(
390449 ``bvec_file`` + ``bval_file``).
391450
392451 """
393-
394- if motion_file :
395- raise NotImplementedError
396-
397452 filename = Path (filename )
398453
399454 # 1) Load a NIfTI
@@ -405,18 +460,8 @@ def from_nii(
405460 grad = np .loadtxt (gradients_file , dtype = "float32" )
406461 if bvec_file and bval_file :
407462 warn (GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG , stacklevel = 2 )
408- if grad .ndim != 2 :
409- raise ValueError (GRADIENT_NDIM_ERROR_MSG )
410- if grad .shape [1 ] == 4 :
411- pass
412- elif grad .shape [0 ] == 4 :
413- grad = grad .T
414- else :
415- raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
416463 elif bvec_file and bval_file :
417464 bvecs = np .loadtxt (bvec_file , dtype = "float32" )
418- if bvecs .ndim == 1 :
419- bvecs = bvecs [np .newaxis , :]
420465 if bvecs .shape [1 ] != 3 and bvecs .shape [0 ] == 3 :
421466 bvecs = bvecs .T
422467
@@ -427,40 +472,26 @@ def from_nii(
427472 else :
428473 raise RuntimeError (GRADIENT_DATA_MISSING_ERROR )
429474
430- # 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
431- # as "DW volumes" if the user wants to store only the high-b volumes here
432- gradmsk = grad [:, - 1 ] > b0_thres
433-
434- dwi_obj = DWI (
435- dataobj = fulldata [..., gradmsk ],
436- affine = img .affine ,
437- # We'll assign the filtered gradients below.
438- )
439-
440- dwi_obj .gradients = grad [gradmsk , :]
441- dwi_obj ._normalize_gradients ()
442-
443- # 4) b=0 volume (bzero)
444- # If the user provided a b0_file, load it
475+ # 3) Read b-zero volume if provided
476+ b0_data = None
445477 if b0_file :
446478 b0img = load_api (b0_file , SpatialImage )
447- b0vol = np .asanyarray (b0img .dataobj )
448- # We'll assume your DWI class has a bzero: np.ndarray | None attribute
449- dwi_obj .bzero = b0vol
450- # Otherwise, if any volumes remain outside gradmsk, compute a median B0:
451- elif np .any (~ gradmsk ):
452- # The b=0 volumes are those that did NOT pass b0_thres
453- b0_volumes = fulldata [..., ~ gradmsk ]
454- # A simple approach is to take the median across that last dimension
455- # Note that axis=3 is valid only if your data is 4D (x, y, z, volumes).
456- dwi_obj .bzero = np .median (b0_volumes , axis = 3 )
457-
458- # 5) If a brainmask_file was provided, load it
479+ b0_data = np .asanyarray (b0img .dataobj )
480+
481+ # 4) If a brainmask_file was provided, load it
482+ brainmask_data = None
459483 if brainmask_file :
460484 mask_img = load_api (brainmask_file , SpatialImage )
461- dwi_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
485+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
462486
463- return dwi_obj
487+ # 5) Create and return the DWI instance.
488+ return DWI (
489+ dataobj = fulldata ,
490+ affine = img .affine ,
491+ gradients = grad ,
492+ bzero = b0_data ,
493+ brainmask = brainmask_data ,
494+ )
464495
465496
466497def find_shelling_scheme (
0 commit comments