Skip to content

Commit d216490

Browse files
committed
tst: add conform test, refactor existing csf_norm one
1 parent 01defc8 commit d216490

File tree

1 file changed

+64
-26
lines changed

1 file changed

+64
-26
lines changed
Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,93 @@
1-
import typing as ty
21
from pathlib import Path
32

43
import nibabel as nb
54
import numpy as np
65
import pytest
6+
from nibabel.orientations import aff2axcodes
77

8-
from nibabies.workflows.anatomical.preproc import _normalize_roi, init_csf_norm_wf
8+
from nibabies.workflows.anatomical.preproc import (
9+
_normalize_roi,
10+
init_conform_derivative_wf,
11+
init_csf_norm_wf,
12+
)
913

10-
EXPECTED_CSF_NORM = np.array([[[10, 73], [73, 29]], [[77, 80], [6, 16]]], dtype='uint8')
14+
EXPECTED_CSF_NORM = np.array([[[49, 75], [23, 75]], [[77, 80], [33, 3]]], dtype='uint8')
1115

1216

1317
@pytest.fixture
14-
def csf_norm_data(tmp_path) -> ty.Generator[tuple[Path, list[Path]], None, None]:
15-
np.random.seed(10)
16-
17-
in_file = tmp_path / 'input.nii.gz'
18-
data = np.random.randint(1, 101, size=(2, 2, 2), dtype='uint8')
18+
def anat_file(tmp_path):
19+
data = np.array([[[49, 73], [23, 73]], [[77, 80], [33, 3]]], dtype='uint8')
1920
img = nb.Nifti1Image(data, np.eye(4))
20-
img.to_filename(in_file)
21+
out = tmp_path / 'input.nii.gz'
22+
img.to_filename(out)
23+
return out
24+
2125

22-
masks = []
23-
for tpm in ('gm', 'wm', 'csf'):
26+
def test_csf_norm_wf(tmp_path, anat_file):
27+
tpms = []
28+
for tpm, thresh in (('gm', 25), ('wm', 75), ('csf', 50)):
2429
name = tmp_path / f'{tpm}.nii.gz'
25-
binmask = data > np.random.randint(10, 90)
30+
anat_img = nb.load(anat_file)
31+
anat_data = np.asanyarray(nb.load(anat_file).dataobj)
32+
33+
binmask = anat_data > thresh
2634
masked = (binmask * 1).astype('uint8')
27-
mask = nb.Nifti1Image(masked, img.affine)
35+
mask = nb.Nifti1Image(masked, anat_img.affine)
2836
mask.to_filename(name)
29-
masks.append(name)
30-
31-
yield in_file, masks
37+
tpms.append(name)
3238

33-
in_file.unlink()
34-
for m in masks:
35-
m.unlink()
36-
37-
38-
def test_csf_norm_wf(tmp_path, csf_norm_data):
39-
anat, tpms = csf_norm_data
4039
wf = init_csf_norm_wf()
4140
wf.base_dir = tmp_path
42-
43-
wf.inputs.inputnode.anat_preproc = anat
41+
wf.inputs.inputnode.anat_preproc = anat_file
4442
wf.inputs.inputnode.anat_tpms = tpms
4543

4644
# verify workflow runs
4745
wf.run()
4846

4947
# verify function works as expected
50-
outfile = _normalize_roi(anat, tpms[2])
48+
outfile = _normalize_roi(anat_file, tpms[2])
5149
assert np.array_equal(
5250
np.asanyarray(nb.load(outfile).dataobj),
5351
EXPECTED_CSF_NORM,
5452
)
5553
Path(outfile).unlink()
54+
55+
56+
@pytest.mark.parametrize(
57+
('affine_mismatch', 'ornt_mismatch'),
58+
[
59+
(False, False),
60+
(True, False),
61+
(False, True),
62+
(True, True),
63+
],
64+
)
65+
def test_conform_derivative_wf(tmp_path, anat_file, affine_mismatch, ornt_mismatch):
66+
deriv = tmp_path / 'mask.nii.gz'
67+
ref_img = nb.load(anat_file)
68+
aff = ref_img.affine.copy()
69+
if affine_mismatch:
70+
# Alter affine slightly
71+
aff[:3, :3] += 0.01
72+
assert not np.array_equal(aff, ref_img.affine)
73+
74+
img = ref_img.__class__(ref_img.dataobj, affine=aff)
75+
if ornt_mismatch:
76+
from niworkflows.interfaces.nibabel import reorient_image
77+
78+
img = reorient_image(img, target_ornt='LPI')
79+
assert aff2axcodes(img.affine) != aff2axcodes(ref_img.affine)
80+
81+
img.to_filename(deriv)
82+
wf = init_conform_derivative_wf(in_file=deriv)
83+
wf.base_dir = tmp_path
84+
wf.inputs.inputnode.ref_file = anat_file
85+
86+
wf.run()
87+
88+
output = list((tmp_path / 'conform_derivative_wf' / 'match_header').glob('*.nii.gz'))
89+
assert output
90+
out_file = output[0]
91+
out_img = nb.load(out_file)
92+
assert np.array_equal(out_img.affine, ref_img.affine)
93+
assert aff2axcodes(out_img.affine) == aff2axcodes(ref_img.affine)

0 commit comments

Comments
 (0)