Skip to content

Commit 7933b75

Browse files
committed
TST: Add additional object instantiation equality checks
Add additional object instantiation equality checks: check that objects intantiated through reading NIfTI files equal objects instantiated directly.
1 parent 064413e commit 7933b75

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

src/nifreeze/data/dmri.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def format_gradients(
116116
--------
117117
Passing an already well-formed table returns the data unchanged::
118118
119+
>>> np.set_printoptions(formatter={"float_kind": lambda x: f"{int(x):4d}"})
119120
>>> format_gradients(
120121
... [
121122
... [1, 0, 0, 0],
@@ -161,6 +162,7 @@ def format_gradients(
161162
...
162163
ValueError: Gradient table must be a 2D array
163164
165+
>>> np.set_printoptions() # reset to defaults
164166
"""
165167

166168
if value is None:
@@ -175,10 +177,6 @@ def format_gradients(
175177
if formatted.ndim != 2:
176178
raise ValueError(GRADIENT_NDIM_ERROR_MSG)
177179

178-
# If the numeric values are all integers, preserve integer dtype
179-
if np.all(np.isfinite(formatted)) and np.allclose(formatted, np.round(formatted)):
180-
formatted = formatted.astype(int)
181-
182180
# Transpose if column-major
183181
return formatted.T if formatted.shape[0] == 4 and formatted.shape[1] != 4 else formatted
184182

test/test_data_dmri.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -689,20 +689,48 @@ def test_equality_operator(tmp_path, setup_random_dwi_data):
689689
tmp_path,
690690
)
691691

692-
dwi_obj = from_nii(
692+
# Read back using public API
693+
dwi_obj_from_nii = from_nii(
693694
dwi_fname,
694695
gradients_file=gradients_fname,
695696
b0_file=b0_fname,
696697
brainmask_file=brainmask_fname,
697698
)
699+
700+
# Direct instantiation with the same arrays
701+
dwi_obj_direct = DWI(
702+
dataobj=dwi_dataobj,
703+
affine=affine,
704+
brainmask=brainmask_dataobj,
705+
gradients=gradients,
706+
bzero=b0_dataobj,
707+
)
708+
709+
# Sanity checks (element-wise)
710+
assert np.allclose(dwi_obj_direct.dataobj, dwi_obj_from_nii.dataobj)
711+
assert np.allclose(dwi_obj_direct.affine, dwi_obj_from_nii.affine)
712+
if dwi_obj_direct.brainmask is None or dwi_obj_from_nii.brainmask is None:
713+
assert dwi_obj_direct.brainmask is None
714+
assert dwi_obj_from_nii.brainmask is None
715+
else:
716+
assert np.array_equal(dwi_obj_direct.brainmask, dwi_obj_from_nii.brainmask)
717+
assert np.allclose(dwi_obj_direct.gradients, dwi_obj_from_nii.gradients)
718+
# Properties derived from gradients should also match
719+
assert np.allclose(dwi_obj_direct.bvals, dwi_obj_from_nii.bvals)
720+
assert np.allclose(dwi_obj_direct.bvecs, dwi_obj_from_nii.bvecs)
721+
722+
# Test equality operator
723+
assert dwi_obj_direct == dwi_obj_from_nii
724+
725+
# Test equality operator against an instance from HDF5
698726
hdf5_filename = tmp_path / "test_dwi.h5"
699-
dwi_obj.to_filename(hdf5_filename)
727+
dwi_obj_from_nii.to_filename(hdf5_filename)
700728

701729
round_trip_dwi_obj = DWI.from_filename(hdf5_filename)
702730

703731
# Symmetric equality
704-
assert dwi_obj == round_trip_dwi_obj
705-
assert round_trip_dwi_obj == dwi_obj
732+
assert dwi_obj_from_nii == round_trip_dwi_obj
733+
assert round_trip_dwi_obj == dwi_obj_from_nii
706734

707735

708736
@pytest.mark.random_dwi_data(50, (34, 36, 24), False)

0 commit comments

Comments
 (0)