Skip to content

Commit d48007f

Browse files
committed
TST: Add test for Derivatives class, add some info to docstring
1 parent 16b603b commit d48007f

File tree

3 files changed

+118
-1
lines changed

3 files changed

+118
-1
lines changed

nibabies/utils/bids.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@
3838

3939

4040
class Derivatives:
41-
"""A container class for storing precomputed derivatives."""
41+
"""
42+
A container class for collecting and storing derivatives.
43+
44+
A specification (either dictionary or JSON file) can be used to customize derivatives and
45+
queries.
46+
To populate this class with derivatives, the `populate()` method must first be called.
47+
"""
4248

4349
def __getattribute__(self, attr):
4450
"""In cases where the spec may change, avoid errors."""

nibabies/utils/tests/__init__.py

Whitespace-only changes.

nibabies/utils/tests/test_bids.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import json
2+
import typing as ty
3+
from pathlib import Path
4+
5+
import pytest
6+
7+
from nibabies.utils import bids
8+
9+
10+
def _create_nifti(filename: str) -> str:
11+
import nibabel as nb
12+
import numpy as np
13+
14+
data = np.zeros((4, 4, 4), dtype='int8')
15+
nb.Nifti1Image(data, np.eye(4)).to_filename(filename)
16+
return filename
17+
18+
19+
def _create_bids_dir(root_path: Path):
20+
if not root_path.exists():
21+
root_path.mkdir()
22+
anat_dir = root_path / 'sub-01' / 'anat'
23+
anat_dir.mkdir(parents=True)
24+
_create_nifti(str(anat_dir / 'sub-01_T1w.nii.gz'))
25+
_create_nifti(str(anat_dir / 'sub-01_T2w.nii.gz'))
26+
27+
28+
def _create_bids_derivs(
29+
root_path: Path,
30+
*,
31+
t1w_mask: bool = False,
32+
t1w_aseg: bool = False,
33+
t2w_mask: bool = False,
34+
t2w_aseg: bool = False,
35+
):
36+
if not root_path.exists():
37+
root_path.mkdir()
38+
(root_path / 'dataset_description.json').write_text(
39+
json.dumps(
40+
{'Name': 'Derivatives Test', 'BIDSVersion': '1.8.0', 'DatasetType': 'derivative'}
41+
)
42+
)
43+
anat_dir = root_path / 'sub-01' / 'anat'
44+
anat_dir.mkdir(parents=True)
45+
46+
def _create_deriv(name: str, modality: ty.Literal['t1w', 't2w']):
47+
if modality == 't1w':
48+
reference = 'sub-01/anat/sub-01_T1w.nii.gz'
49+
elif modality == 't2w':
50+
reference = 'sub-01/anat/sub-01_T2w.nii.gz'
51+
52+
_create_nifti(str((anat_dir / name).with_suffix('.nii.gz')))
53+
(anat_dir / name).with_suffix('.json').write_text(
54+
json.dumps({'SpatialReference': reference})
55+
)
56+
57+
if t1w_mask:
58+
_create_deriv('sub-01_space-T1w_desc-brain_mask', 't1w')
59+
if t1w_aseg:
60+
_create_deriv('sub-01_space-T1w_desc-aseg_dseg', 't1w')
61+
if t2w_mask:
62+
_create_deriv('sub-01_space-T2w_desc-brain_mask', 't2w')
63+
if t2w_aseg:
64+
_create_deriv('sub-01_space-T2w_desc-aseg_dseg', 't2w')
65+
66+
67+
@pytest.mark.parametrize(
68+
't1w_mask,t1w_aseg,t2w_mask,t2w_aseg,mask,aseg',
69+
[
70+
(True, True, False, False, 't1w_mask', 't1w_aseg'),
71+
(True, True, True, True, 't1w_mask', 't1w_aseg'),
72+
(False, False, True, True, 't2w_mask', 't2w_aseg'),
73+
(True, False, False, True, 't1w_mask', 't2w_aseg'),
74+
(False, False, False, False, None, None),
75+
],
76+
)
77+
def test_derivatives(
78+
tmp_path: Path,
79+
t1w_mask: bool,
80+
t1w_aseg: bool,
81+
t2w_mask: bool,
82+
t2w_aseg: bool,
83+
mask: str | None,
84+
aseg: str | None,
85+
):
86+
bids_dir = tmp_path / 'bids'
87+
_create_bids_dir(bids_dir)
88+
deriv_dir = tmp_path / 'derivatives'
89+
_create_bids_derivs(
90+
deriv_dir, t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_mask=t2w_mask, t2w_aseg=t2w_aseg
91+
)
92+
93+
derivatives = bids.Derivatives(bids_dir)
94+
assert derivatives.mask is None
95+
assert derivatives.t1w_mask is None
96+
assert derivatives.t2w_mask is None
97+
assert derivatives.aseg is None
98+
assert derivatives.t1w_aseg is None
99+
assert derivatives.t2w_aseg is None
100+
101+
derivatives.populate(deriv_dir, subject_id='01')
102+
if mask:
103+
assert derivatives.mask == getattr(derivatives, mask)
104+
assert derivatives.references[mask]
105+
else:
106+
assert derivatives.mask is None
107+
if aseg:
108+
assert derivatives.aseg == getattr(derivatives, aseg)
109+
assert derivatives.references[aseg]
110+
else:
111+
assert derivatives.aseg == None

0 commit comments

Comments
 (0)