6363DTI_MIN_ORIENTATIONS = 6
6464"""Minimum number of nonzero b-values in a DWI dataset."""
6565
66- GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = "Gradient table shape does not match the number of diffusion volumes: expected {n_volumes} rows, found {n_gradients}."
66+ GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
67+ Gradient table shape does not match the number of diffusion volumes: \
68+ expected {n_volumes} rows, found {n_gradients}."""
6769"""dMRI volume count vs. gradient count mismatch error message."""
6870
69- GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = "Both a gradients table file and b-vec/val files are defined; ignoring b-vec/val files in favor of the gradients_file."
71+ GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG = """\
72+ Both a gradients table file and b-vec/val files are defined; \
73+ ignoring b-vec/val files in favor of the gradients_file."""
7074""""dMRI gradient file priority warning message."""
7175
7276GRADIENT_NDIM_ERROR_MSG = "Gradient table must be a 2D array"
7377"""dMRI gradient dimensionality error message."""
7478
75- GRADIENT_DATA_MISSING_ERROR = (
76- "No gradient data provided. Please specify either a gradients_file or (bvec_file & bval_file)."
77- )
79+ GRADIENT_DATA_MISSING_ERROR = "No gradient data provided."
7880"""dMRI missing gradient data error message."""
7981
8082GRADIENT_EXPECTED_COLUMNS_ERROR_MSG = (
8385"""dMRI gradient expected columns error message."""
8486
8587
88+ def format_gradients (
89+ value : npt .ArrayLike | None ,
90+ ) -> np .ndarray | None :
91+ """
92+ Validate and orient gradient tables to row-major convention.
93+
94+ Parameters
95+ ----------
96+ value : :obj:`ArrayLike`
97+ The value to format.
98+
99+ Returns
100+ -------
101+ :obj:`~numpy.ndarray`
102+ Row-major convention gradient table.
103+
104+ Raises
105+ ------
106+ exc:`ValueError`
107+ If ``value`` is not a 2D :obj:`~numpy.ndarray` (``value.ndim != 2``).
108+
109+ Examples
110+ --------
111+ Passing an already well-formed table returns the data unchanged::
112+
113+ >>> format_gradients(
114+ ... [
115+ ... [1, 0, 0, 0],
116+ ... [0, 1, 0, 1000],
117+ ... [0, 0, 1, 2000],
118+ ... [0, 0, 0, 0],
119+ ... [0, 0, 0, 1000],
120+ ... ]
121+ ... )
122+ array([[ 1, 0, 0, 0],
123+ [ 0, 1, 0, 1000],
124+ [ 0, 0, 1, 2000],
125+ [ 0, 0, 0, 0],
126+ [ 0, 0, 0, 1000]])
127+
128+ Column-major inputs are automatically transposed when an expected
129+ number of diffusion volumes is provided::
130+
131+ >>> format_gradients(
132+ ... [[1, 0], [0, 1], [0, 0], [1000, 2000]],
133+ ... )
134+ array([[ 1, 0, 0, 1000],
135+ [ 0, 1, 0, 2000]])
136+
137+ Gradient tables must always have two dimensions::
138+
139+ >>> format_gradients([0, 1, 0, 1000])
140+ Traceback (most recent call last):
141+ ...
142+ ValueError: Gradient table must be a 2D array
143+
144+ """
145+
146+ formatted = np .asarray (value )
147+ if formatted .ndim != 2 :
148+ raise ValueError (GRADIENT_NDIM_ERROR_MSG )
149+
150+ # Transpose if column-major
151+ return formatted .T if formatted .shape [0 ] == 4 and formatted .shape [1 ] != 4 else formatted
152+
153+
154+ def validate_gradients (
155+ inst : DWI ,
156+ attr : attrs .Attribute ,
157+ value : npt .NDArray [np .floating ],
158+ ) -> None :
159+ """Strict validator for use in attribute validation (e.g. attrs / validators).
160+
161+ Ensures row-major convention for gradient table.
162+
163+ This function is intended for use as an attrs-style validator.
164+
165+ Parameters
166+ ----------
167+ inst : :obj:`~nifreeze.data.dmri.DWI`
168+ The instance being validated (unused; present for validator signature).
169+ attr : :obj:`~attrs.Attribute`
170+ The attribute being validated; attr.name is used in the error message.
171+ value : :obj:`~npt.NDArray`
172+ The value to validate.
173+ """
174+ if value .shape [1 ] != 4 :
175+ raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
176+
177+
86178@attrs .define (slots = True )
87179class DWI (BaseDataset [np .ndarray ]):
88180 """Data representation structure for dMRI data."""
89181
182+ gradients : np .ndarray = attrs .field (
183+ default = None ,
184+ repr = _data_repr ,
185+ eq = attrs .cmp_using (eq = _cmp ),
186+ converter = format_gradients ,
187+ validator = validate_gradients ,
188+ )
189+ """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
90190 bzero : np .ndarray | None = attrs .field (
91191 default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
92192 )
93- """A *b=0* reference map, preferably obtained by some smart averaging."""
94- gradients : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
95- """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
193+ """A *b=0* reference map, computed automatically when low-b frames are present."""
96194 eddy_xfms : list = attrs .field (default = None )
97195 """List of transforms to correct for estimated eddy current distortions."""
98196
99197 def __attrs_post_init__ (self ) -> None :
100- self ._normalize_gradients ()
101-
102- def _normalize_gradients (self ) -> None :
103198 if self .gradients is None :
104- return
105-
106- gradients = np .asarray (self .gradients )
107- if gradients .ndim != 2 :
108- raise ValueError (GRADIENT_NDIM_ERROR_MSG )
109- if gradients .shape [1 ] == 4 :
110- pass
111- elif gradients .shape [0 ] == 4 :
112- gradients = gradients .T
113- else :
114- raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
115-
116- n_volumes = None
117- if self .dataobj is not None :
118- try :
119- n_volumes = self .dataobj .shape [- 1 ]
120- except Exception : # pragma: no cover - extremely defensive
121- n_volumes = None
199+ raise ValueError (GRADIENT_DATA_MISSING_ERROR )
122200
123- if n_volumes is not None and gradients . shape [0 ] != n_volumes :
201+ if self . dataobj . shape [- 1 ] != self . gradients . shape [ 0 ] :
124202 raise ValueError (
125- GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_MISSING_ERROR .format (
126- n_volumes = n_volumes , n_gradients = gradients .shape [0 ]
203+ GRADIENT_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
204+ n_volumes = self .dataobj .shape [- 1 ],
205+ n_gradients = self .gradients .shape [0 ],
127206 )
128207 )
129208
130- self .gradients = gradients
209+ b0_mask = self .gradients [:, - 1 ] <= DEFAULT_LOWB_THRESHOLD
210+ b0_num = np .sum (b0_mask )
211+
212+ if b0_num > 0 and self .bzero is None :
213+ bzeros = self .dataobj [..., b0_mask ]
214+ self .bzero = bzeros if bzeros .ndim == 3 else np .median (bzeros , axis = - 1 )
215+
216+ if b0_num > 0 :
217+ # Remove b0 volumes from dataobj and gradients
218+ self .gradients = self .gradients [~ b0_mask , :]
219+ self .dataobj = self .dataobj [..., ~ b0_mask ]
220+
221+ if self .gradients .shape [0 ] < DTI_MIN_ORIENTATIONS :
222+ raise ValueError (
223+ f"DWI datasets must have at least { DTI_MIN_ORIENTATIONS } diffusion-weighted "
224+ f"orientations; found { self .dataobj .shape [- 1 ]} ."
225+ )
131226
132227 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
133228 return (self .gradients [idx , ...],)
@@ -339,12 +434,10 @@ def to_nifti(
339434def from_nii (
340435 filename : Path | str ,
341436 brainmask_file : Path | str | None = None ,
342- motion_file : Path | str | None = None ,
343437 gradients_file : Path | str | None = None ,
344438 bvec_file : Path | str | None = None ,
345439 bval_file : Path | str | None = None ,
346440 b0_file : Path | str | None = None ,
347- b0_thres : float = DEFAULT_LOWB_THRESHOLD ,
348441) -> DWI :
349442 """
350443 Load DWI data from NIfTI and construct a DWI object.
@@ -359,8 +452,6 @@ def from_nii(
359452 brainmask_file : :obj:`os.pathlike`, optional
360453 A brainmask NIfTI file. If provided, will be loaded and
361454 stored in the returned dataset.
362- motion_file : :obj:`os.pathlike`, optional
363- A file containing head motion affine matrices (linear)
364455 gradients_file : :obj:`os.pathlike`, optional
365456 A text file containing the gradients table, shape (N, C) where the last column
366457 stores the b-values. If provided following the column-major convention(C, N),
@@ -373,9 +464,6 @@ def from_nii(
373464 b0_file : :obj:`os.pathlike`, optional
374465 A NIfTI file containing a b=0 volume (possibly averaged or reference).
375466 If not provided, and the data contains at least one b=0 volume, one will be computed.
376- b0_thres : float, optional
377- Threshold for determining which volumes are considered DWI vs. b=0
378- if you combine them in the same file.
379467
380468 Returns
381469 -------
@@ -390,10 +478,6 @@ def from_nii(
390478 ``bvec_file`` + ``bval_file``).
391479
392480 """
393-
394- if motion_file :
395- raise NotImplementedError
396-
397481 filename = Path (filename )
398482
399483 # 1) Load a NIfTI
@@ -405,18 +489,8 @@ def from_nii(
405489 grad = np .loadtxt (gradients_file , dtype = "float32" )
406490 if bvec_file and bval_file :
407491 warn (GRADIENT_BVAL_BVEC_PRIORITY_WARN_MSG , stacklevel = 2 )
408- if grad .ndim != 2 :
409- raise ValueError (GRADIENT_NDIM_ERROR_MSG )
410- if grad .shape [1 ] == 4 :
411- pass
412- elif grad .shape [0 ] == 4 :
413- grad = grad .T
414- else :
415- raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
416492 elif bvec_file and bval_file :
417493 bvecs = np .loadtxt (bvec_file , dtype = "float32" )
418- if bvecs .ndim == 1 :
419- bvecs = bvecs [np .newaxis , :]
420494 if bvecs .shape [1 ] != 3 and bvecs .shape [0 ] == 3 :
421495 bvecs = bvecs .T
422496
@@ -427,40 +501,26 @@ def from_nii(
427501 else :
428502 raise RuntimeError (GRADIENT_DATA_MISSING_ERROR )
429503
430- # 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
431- # as "DW volumes" if the user wants to store only the high-b volumes here
432- gradmsk = grad [:, - 1 ] > b0_thres
433-
434- dwi_obj = DWI (
435- dataobj = fulldata [..., gradmsk ],
436- affine = img .affine ,
437- # We'll assign the filtered gradients below.
438- )
439-
440- dwi_obj .gradients = grad [gradmsk , :]
441- dwi_obj ._normalize_gradients ()
442-
443- # 4) b=0 volume (bzero)
444- # If the user provided a b0_file, load it
504+ # 3) Read b-zero volume if provided
505+ b0_data = None
445506 if b0_file :
446507 b0img = load_api (b0_file , SpatialImage )
447- b0vol = np .asanyarray (b0img .dataobj )
448- # We'll assume your DWI class has a bzero: np.ndarray | None attribute
449- dwi_obj .bzero = b0vol
450- # Otherwise, if any volumes remain outside gradmsk, compute a median B0:
451- elif np .any (~ gradmsk ):
452- # The b=0 volumes are those that did NOT pass b0_thres
453- b0_volumes = fulldata [..., ~ gradmsk ]
454- # A simple approach is to take the median across that last dimension
455- # Note that axis=3 is valid only if your data is 4D (x, y, z, volumes).
456- dwi_obj .bzero = np .median (b0_volumes , axis = 3 )
457-
458- # 5) If a brainmask_file was provided, load it
508+ b0_data = np .asanyarray (b0img .dataobj )
509+
510+ # 4) If a brainmask_file was provided, load it
511+ brainmask_data = None
459512 if brainmask_file :
460513 mask_img = load_api (brainmask_file , SpatialImage )
461- dwi_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
514+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
462515
463- return dwi_obj
516+ # 5) Create and return the DWI instance.
517+ return DWI (
518+ dataobj = fulldata ,
519+ affine = img .affine ,
520+ gradients = grad ,
521+ bzero = b0_data ,
522+ brainmask = brainmask_data ,
523+ )
464524
465525
466526def find_shelling_scheme (
0 commit comments