3434import numpy as np
3535import numpy .typing as npt
3636from nibabel .spatialimages import SpatialImage
37+ from numpy .typing import ArrayLike
3738from typing_extensions import Self
3839
39- from nifreeze .data .base import BaseDataset , _cmp , _data_repr
40+ from nifreeze .data .base import BaseDataset , _cmp , _data_repr , _has_dim_size , _has_ndim
4041from nifreeze .utils .ndimage import get_data , load_api
4142
43+ GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None"
44+ """DWI initialization gradient absence error message."""
45+
46+ GRADIENT_OBJECT_ERROR_MSG = "DWI 'gradients' must be a numpy array."
47+ """DWI initialization gradient object error message."""
48+
49+ GRADIENT_COUNT_MISMATCH_ERROR_MSG = (
50+ "DWI gradients count ({n_gradients}) does not match dataset volumes ({data_vols})."
51+ )
52+ """DWI initialization gradient count mismatch error message."""
53+
4254DEFAULT_CLIP_PERCENTILE = 75
4355"""Upper percentile threshold for intensity clipping."""
4456
6476"""Minimum number of nonzero b-values in a DWI dataset."""
6577
6678
79+ def _check_gradient_shape (value : np .ndarray ) -> None :
80+ """Strictly validate a gradients ndarray.
81+
82+ Validates that ``value`` is a correctly-shaped NumPy array representing
83+ gradients. It performs a sequence of checks and raises :exc:`TypeError` or
84+ :exc:`ValueError` with intentionally explicit messages suitable for use by
85+ higher-level validators.
86+
87+ The following conditions raise an exception:
88+ - ``value`` is not a 2D :obj:`~numpy.ndarray`.
89+ - ``value`` does not have 4 columns.
90+
91+ Parameters
92+ ----------
93+ value : :obj:`~numpy.ndarray`
94+ The candidate gradients array.
95+
96+ Raises
97+ ------
98+ :exc:`ValueError`
99+ If ``value`` fails any of the checks described above.
100+
101+ Examples
102+ --------
103+ >>> _check_gradient_shape(np.zeros((10, 3))) # valid: does not raise
104+ >>> _check_gradient_shape(np.asarray([[1, 2, 3], [1, 2]]) # raises ValueError
105+ >>> _check_gradient_shape(np.zeros((5,))) # raises ValueError
106+ >>> _check_gradient_shape(np.zeros((2, 6))) # raises ValueError
107+ """
108+
109+ if value is None :
110+ raise ValueError (GRADIENT_ABSENCE_ERROR_MSG )
111+
112+ # Reject ragged/object-dtype arrays explicitly
113+ if value .dtype == object :
114+ raise TypeError (GRADIENT_OBJECT_ERROR_MSG )
115+
116+ if not _has_ndim (value , 2 ):
117+ raise ValueError (GRADIENT_NDIM_ERROR_MSG )
118+
119+ if not _has_dim_size (value , 4 ):
120+ raise ValueError (GRADIENT_EXPECTED_COLUMNS_ERROR_MSG )
121+
122+
123+ def _gradients_converter (value : ArrayLike ) -> np .ndarray :
124+ """Permissive gradient converter.
125+
126+ Behavior:
127+ - Converts the incoming ``value`` to a float NumPy array.
128+ - Ensures the result is 2-D and that one dimension equals 4.
129+ - If a 2-D array has ``shape[0] == 4`` and ``shape[1] != 4``, it will be
130+ transposed so the returned array has ``shape[1] == 4``.
131+ - For 1-D inputs of length 4, returns an array shaped ``(1, 4)``.
132+ - Raises exc:`TypeError` for conversion failures and exc:`ValueError` for
133+ shape violations.
134+
135+ Parameters
136+ ----------
137+ value : :obj:`ArrayLike`
138+ Input to convert to a :obj:`~numpy.ndarray` of floats.
139+
140+ Returns
141+ -------
142+ :obj:`~numpy.ndarray`
143+ A 2-D float array with ``shape[1] == 4``.
144+
145+ Raises
146+ ------
147+ exc:`TypeError`
148+ If the input cannot be converted to a float :obj:`~numpy.ndarray`.
149+ exc:`ValueError`
150+ If the converted array is not 2-D (after the 1-D -> 2-D promotion)
151+ or does not have a dimension of size 4 such that the returned array
152+ can be shaped with ``shape[1] == 4``.
153+
154+ Examples
155+ --------
156+ >>> _gradients_converter([0, 0, 0, 1]).shape
157+ (1, 4)
158+ >>> _gradients_converter(np.zeros((10, 4))).shape
159+ (10, 4)
160+ >>> _gradients_converter(np.zeros((4, 10))).shape
161+ (10, 4) # transposed so shape[1] == 4
162+ """
163+
164+ if value is None :
165+ raise ValueError (GRADIENT_ABSENCE_ERROR_MSG )
166+
167+ # Convert to ndarray
168+ if isinstance (value , np .ndarray ):
169+ arr = value .astype (float , copy = False )
170+ else :
171+ try :
172+ arr = np .asarray (value , dtype = float )
173+ except (TypeError , ValueError ) as exc :
174+ # Conversion failed (e.g. nested ragged objects, non-numeric)
175+ raise TypeError (GRADIENT_OBJECT_ERROR_MSG ) from exc
176+
177+ _check_gradient_shape (arr )
178+
179+ if arr .shape [1 ] == 4 :
180+ pass
181+ else :
182+ arr = arr .T
183+
184+ # ToDo
185+ # Call gradient normalization
186+ return arr
187+
188+
189+ def _gradients_validator (inst : DWI , attr : attrs .Attribute , value : Any ) -> None :
190+ """Strict validator for use in attribute validation (e.g. attrs / validators).
191+
192+ Enforces that ``value`` is a NumPy array and has the expected 2-D shape
193+ with 4 columns (``shape[1] == 4``).
194+
195+ This function is intended for use as an attrs-style validator.
196+
197+ Raises
198+ ------
199+ exc:`TypeError`
200+ If ``value`` is not a :obj:`~numpy.ndarray`.
201+ exc:`ValueError``
202+ If ``value`` is not 2-D or its shape does not have 4 columns.
203+
204+ Parameters
205+ ----------
206+ inst : :obj:`:obj:`~nifreeze.data.dmri.DWI`
207+ The instance being validated (unused, present for validator signature).
208+ attr : :obj:`~attrs.Attribute`
209+ The attribute being validated (unused, present for validator signature).
210+ value : :obj:`Any`
211+ The value to validate.
212+ """
213+
214+ if value is None :
215+ raise ValueError (GRADIENT_ABSENCE_ERROR_MSG )
216+
217+ if not isinstance (value , np .ndarray ):
218+ raise TypeError (GRADIENT_OBJECT_ERROR_MSG )
219+
220+ _check_gradient_shape (value )
221+
222+
67223@attrs .define (slots = True )
68224class DWI (BaseDataset [np .ndarray ]):
69225 """Data representation structure for dMRI data."""
@@ -72,41 +228,44 @@ class DWI(BaseDataset[np.ndarray]):
72228 default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
73229 )
74230 """A *b=0* reference map, preferably obtained by some smart averaging."""
75- gradients : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
231+ gradients : np .ndarray = attrs .field (
232+ default = None ,
233+ repr = _data_repr ,
234+ eq = attrs .cmp_using (eq = _cmp ),
235+ validator = _gradients_validator ,
236+ converter = _gradients_converter ,
237+ )
76238 """A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
77239 eddy_xfms : list = attrs .field (default = None )
78240 """List of transforms to correct for estimated eddy current distortions."""
79241
80242 def __attrs_post_init__ (self ) -> None :
81- self ._normalize_gradients ()
82-
83- def _normalize_gradients (self ) -> None :
84- if self .gradients is None :
85- return
243+ """Enforce basic consistency of required dMRI fields at instantiation
244+ time.
86245
87- gradients = np . asarray ( self . gradients )
88- if gradients . ndim != 2 :
89- raise ValueError ( "Gradient table must be a 2D array" )
246+ Specifically, the number of gradient directions must match the last
247+ dimension of the data (number of volumes).
248+ """
90249
250+ # If the data object exists and has a time/volume axis, ensure sizes
251+ # match.
91252 n_volumes = None
92- if self .dataobj is not None :
93- try :
94- n_volumes = self .dataobj .shape [- 1 ]
95- except Exception : # pragma: no cover - extremely defensive
96- n_volumes = None
97-
98- if n_volumes is not None and gradients .shape [0 ] != n_volumes :
99- if gradients .shape [1 ] == n_volumes :
100- gradients = gradients .T
101- else :
253+ if getattr (self , "dataobj" , None ) is not None :
254+ shape = getattr (self .dataobj , "shape" , None )
255+ if isinstance (shape , (tuple , list )) and len (shape ) >= 1 :
256+ try :
257+ n_volumes = int (shape [- 1 ])
258+ except (TypeError , ValueError ):
259+ n_volumes = None
260+
261+ if n_volumes is not None :
262+ n_gradients = self .gradients .shape [1 ]
263+ if n_gradients != n_volumes :
102264 raise ValueError (
103- "Gradient table shape does not match the number of diffusion volumes: "
104- f"expected { n_volumes } rows, found { gradients .shape [0 ]} "
265+ GRADIENT_COUNT_MISMATCH_ERROR_MSG .format (
266+ n_gradients = n_gradients , data_vols = n_volumes
267+ )
105268 )
106- elif n_volumes is None and gradients .shape [1 ] > gradients .shape [0 ]:
107- gradients = gradients .T
108-
109- self .gradients = gradients
110269
111270 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
112271 return (self .gradients [idx , ...],)
@@ -315,6 +474,20 @@ def to_nifti(
315474 return nii
316475
317476
477+ def _compose_gradients (bvec_file : Path | str , bval_file : Path | str ):
478+ bvecs = np .loadtxt (bvec_file , dtype = "float32" )
479+ if bvecs .ndim == 1 :
480+ bvecs = bvecs [np .newaxis , :]
481+ if bvecs .shape [1 ] != 3 and bvecs .shape [0 ] == 3 :
482+ bvecs = bvecs .T
483+
484+ bvals = np .loadtxt (bval_file , dtype = "float32" )
485+ if bvals .ndim > 1 :
486+ bvals = np .squeeze (bvals )
487+
488+ return np .column_stack ((bvecs , bvals ))
489+
490+
318491def from_nii (
319492 filename : Path | str ,
320493 brainmask_file : Path | str | None = None ,
@@ -389,35 +562,14 @@ def from_nii(
389562 stacklevel = 2 ,
390563 )
391564 elif bvec_file and bval_file :
392- bvecs = np .loadtxt (bvec_file , dtype = "float32" )
393- if bvecs .ndim == 1 :
394- bvecs = bvecs [np .newaxis , :]
395- if bvecs .shape [1 ] != 3 and bvecs .shape [0 ] == 3 :
396- bvecs = bvecs .T
397-
398- bvals = np .loadtxt (bval_file , dtype = "float32" )
399- if bvals .ndim > 1 :
400- bvals = np .squeeze (bvals )
401- grad = np .column_stack ((bvecs , bvals ))
565+ grad = _compose_gradients (bvec_file , bval_file )
402566 else :
403567 raise RuntimeError (
404568 "No gradient data provided. "
405569 "Please specify either a gradients_file or (bvec_file & bval_file)."
406570 )
407571
408- if grad .ndim == 1 :
409- grad = grad [np .newaxis , :]
410-
411- if grad .shape [1 ] < 2 :
412- raise ValueError ("Gradient table must have at least two columns (direction + b-value)." )
413-
414- if grad .shape [1 ] != 4 :
415- if grad .shape [0 ] == 4 :
416- grad = grad .T
417- else :
418- raise ValueError (
419- "Gradient table must have four columns (3 direction components and one b-value)."
420- )
572+ grad = _gradients_converter (grad )
421573
422574 # 3) Create the DWI instance. We'll filter out volumes where b-value > b0_thres
423575 # as "DW volumes" if the user wants to store only the high-b volumes here
@@ -426,11 +578,14 @@ def from_nii(
426578 dwi_obj = DWI (
427579 dataobj = fulldata [..., gradmsk ],
428580 affine = img .affine ,
429- # We'll assign the filtered gradients below.
581+ gradients = grad [
582+ gradmsk , :
583+ ], # ToDo Duplicate call to _gradients_converter but cannot do better I think
430584 )
431585
432- dwi_obj .gradients = grad [gradmsk , :]
433- dwi_obj ._normalize_gradients ()
586+ # removing gradients = np.asarray(self.gradients) from _normalize_gradients:
587+ # the annotation does not suggest anything other than arrays: if we want a list of lists, we should type hint that.
588+ # The converter duplicates the checks, and we could skip it in the signature, but I think it is wise to keep it
434589
435590 # 4) b=0 volume (bzero)
436591 # If the user provided a b0_file, load it
0 commit comments