Skip to content

Commit 131f0d5

Browse files
authored
Merge pull request #305 from nipreps/docker/t1-t2-derivatives
ENH+RF: Allow precomputed derivatives in T1w or T2w space
2 parents 45d48ef + b6838ef commit 131f0d5

File tree

10 files changed

+591
-254
lines changed

10 files changed

+591
-254
lines changed

nibabies/utils/bids.py

Lines changed: 122 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,134 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Utilities to handle BIDS inputs."""
4+
from __future__ import annotations
5+
46
import json
57
import os
68
import sys
9+
import typing as ty
710
import warnings
8-
from dataclasses import dataclass, field
911
from pathlib import Path
10-
from typing import IO, List, Literal, Optional, Union
1112

13+
import nibabel as nb
14+
import numpy as np
15+
from bids.layout import BIDSLayout, Query
16+
17+
_spec: dict = {
18+
't1w_mask': {
19+
'datatype': 'anat',
20+
'desc': 'brain',
21+
'space': 'T1w',
22+
'suffix': 'mask',
23+
},
24+
't1w_aseg': {'datatype': 'anat', 'desc': 'aseg', 'space': 'T1w', 'suffix': 'dseg'},
25+
't2w_mask': {
26+
'datatype': 'anat',
27+
'desc': 'brain',
28+
'space': 'T2w',
29+
'suffix': 'mask',
30+
},
31+
't2w_aseg': {
32+
'datatype': 'anat',
33+
'desc': 'aseg',
34+
'space': 'T2w',
35+
'suffix': 'dseg',
36+
},
37+
}
38+
39+
40+
class Derivatives:
41+
"""
42+
A container class for collecting and storing derivatives.
43+
44+
A specification (either dictionary or JSON file) can be used to customize derivatives and
45+
queries.
46+
To populate this class with derivatives, the `populate()` method must first be called.
47+
"""
1248

13-
@dataclass
14-
class BOLDGrouping:
15-
"""This class is used to facilitate the grouping of BOLD series."""
49+
def __getattribute__(self, attr):
50+
"""In cases where the spec may change, avoid errors."""
51+
try:
52+
return object.__getattribute__(self, attr)
53+
except AttributeError:
54+
return None
55+
56+
def __init__(self, bids_root: Path | str, spec: dict | Path | str | None = None, **args):
57+
self.bids_root = Path(bids_root)
58+
self.spec = _spec
59+
if spec is not None:
60+
if not isinstance(spec, dict):
61+
spec: dict = json.loads(Path(spec).read_text())
62+
self.spec = spec
63+
64+
self.names = set(self.spec.keys())
65+
self.references = {name: None for name in self.names}
66+
for name in self.names:
67+
setattr(self, name, None)
68+
69+
def __repr__(self):
70+
return '\n'.join([name for name in self.names if getattr(self, name)])
71+
72+
def __contains__(self, val: str):
73+
return val in self.names
74+
75+
def __bool__(self):
76+
return any(getattr(self, name) for name in self.names)
77+
78+
def populate(
79+
self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL
80+
) -> None:
81+
"""Query a derivatives directory and populate values and references based on the spec."""
82+
layout = BIDSLayout(deriv_path, validate=False)
83+
for name, query in self.spec.items():
84+
items = layout.get(
85+
subject=subject_id,
86+
session=session_id,
87+
extension=['.nii', '.nii.gz'],
88+
**query,
89+
)
90+
if not items or len(items) > 1:
91+
warnings.warn(f"Could not find {name}")
92+
continue
93+
item = items[0]
94+
95+
# Skip if derivative does not have valid metadata
96+
metadata = item.get_metadata()
97+
if not metadata or not (reference := metadata.get('SpatialReference')):
98+
warnings.warn(f"No metadata found for {item}")
99+
continue
100+
if isinstance(reference, list):
101+
if len(reference) > 1:
102+
warnings.warn(f"Multiple reference found: {reference}")
103+
continue
104+
reference = reference[0]
105+
106+
reference = self.bids_root / reference
107+
if not self.validate(item.path, str(reference)):
108+
warnings.warn(f"Validation failed between: {item.path} and {reference}")
109+
continue
110+
111+
setattr(self, name, Path(item.path))
112+
self.references[name] = reference
16113

17-
session: Union[str, None]
18-
pe_dir: str
19-
readout: float
20-
multiecho_id: str = None
21-
files: List[IO] = field(default_factory=list)
114+
@property
115+
def mask(self) -> str | None:
116+
return self.t1w_mask or self.t2w_mask
22117

23118
@property
24-
def name(self) -> str:
25-
return f"{self.session}-{self.pe_dir}-{self.readout}-{self.multiecho_id}"
119+
def aseg(self) -> str | None:
120+
return self.t1w_aseg or self.t2w_aseg
26121

27-
def add_file(self, fl) -> None:
28-
self.files.append(fl)
122+
@staticmethod
123+
def validate(derivative: str, reference: str, atol: float = 1e-5) -> bool:
124+
anat = nb.load(reference)
125+
expected_ort = nb.aff2axcodes(anat.affine)
126+
img = nb.load(derivative)
127+
if nb.aff2axcodes(img.affine) != expected_ort:
128+
return False
129+
if img.shape != anat.shape or not np.allclose(anat.affine, img.affine, atol=atol):
130+
return False
131+
return True
29132

30133

31134
def write_bidsignore(deriv_dir):
@@ -221,55 +324,11 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
221324
print("bids-validator does not appear to be installed", file=sys.stderr)
222325

223326

224-
def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None):
225-
"""
226-
Query and collect precomputed derivatives.
227-
228-
This function is used to determine which workflow steps can be skipped,
229-
based on the files found.
230-
"""
231-
232-
deriv_queries = {
233-
'anat_mask': {
234-
'datatype': 'anat',
235-
'desc': 'brain',
236-
'space': 'orig',
237-
'suffix': 'mask',
238-
},
239-
'anat_aseg': {
240-
'datatype': 'anat',
241-
'desc': 'aseg',
242-
'space': 'orig',
243-
'suffix': 'dseg',
244-
},
245-
}
246-
if derivatives_filters is not None:
247-
deriv_queries.update(derivatives_filters)
248-
249-
derivatives = {}
250-
for deriv, query in deriv_queries.items():
251-
res = layout.get(
252-
scope='derivatives',
253-
subject=subject_id,
254-
extension=['.nii', '.nii.gz'],
255-
return_type="filename",
256-
**query,
257-
)
258-
if not res:
259-
continue
260-
if len(res) > 1: # Some queries may want multiple results
261-
raise Exception(
262-
f"When searching for <{deriv}>, found multiple results: {[f.path for f in res]}"
263-
)
264-
derivatives[deriv] = res[0]
265-
return derivatives
266-
267-
268327
def parse_bids_for_age_months(
269-
bids_root: Union[str, Path],
328+
bids_root: str | Path,
270329
subject_id: str,
271-
session_id: Optional[str] = None,
272-
) -> Optional[int]:
330+
session_id: str | None = None,
331+
) -> int | None:
273332
"""
274333
Given a BIDS root, query the BIDS metadata files for participant age, in months.
275334
@@ -295,8 +354,8 @@ def parse_bids_for_age_months(
295354

296355

297356
def _get_age_from_tsv(
298-
bids_tsv: Path, level: Literal['session', 'participant'], key: str
299-
) -> Optional[int]:
357+
bids_tsv: Path, level: ty.Literal['session', 'participant'], key: str
358+
) -> int | None:
300359
import pandas as pd
301360

302361
df = pd.read_csv(str(bids_tsv), sep='\t')

nibabies/utils/tests/__init__.py

Whitespace-only changes.

nibabies/utils/tests/test_bids.py

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
from __future__ import annotations
2+
3+
import json
4+
import typing as ty
5+
from pathlib import Path
6+
7+
import pytest
8+
9+
from nibabies.utils import bids
10+
11+
12+
def _create_nifti(filename: str) -> str:
13+
import nibabel as nb
14+
import numpy as np
15+
16+
data = np.zeros((4, 4, 4), dtype='int8')
17+
nb.Nifti1Image(data, np.eye(4)).to_filename(filename)
18+
return filename
19+
20+
21+
def _create_bids_dir(root_path: Path):
22+
if not root_path.exists():
23+
root_path.mkdir()
24+
anat_dir = root_path / 'sub-01' / 'anat'
25+
anat_dir.mkdir(parents=True)
26+
_create_nifti(str(anat_dir / 'sub-01_T1w.nii.gz'))
27+
_create_nifti(str(anat_dir / 'sub-01_T2w.nii.gz'))
28+
29+
30+
def _create_bids_derivs(
31+
root_path: Path,
32+
*,
33+
t1w_mask: bool = False,
34+
t1w_aseg: bool = False,
35+
t2w_mask: bool = False,
36+
t2w_aseg: bool = False,
37+
):
38+
if not root_path.exists():
39+
root_path.mkdir()
40+
(root_path / 'dataset_description.json').write_text(
41+
json.dumps(
42+
{'Name': 'Derivatives Test', 'BIDSVersion': '1.8.0', 'DatasetType': 'derivative'}
43+
)
44+
)
45+
anat_dir = root_path / 'sub-01' / 'anat'
46+
anat_dir.mkdir(parents=True)
47+
48+
def _create_deriv(name: str, modality: ty.Literal['t1w', 't2w']):
49+
if modality == 't1w':
50+
reference = 'sub-01/anat/sub-01_T1w.nii.gz'
51+
elif modality == 't2w':
52+
reference = 'sub-01/anat/sub-01_T2w.nii.gz'
53+
54+
_create_nifti(str((anat_dir / name).with_suffix('.nii.gz')))
55+
(anat_dir / name).with_suffix('.json').write_text(
56+
json.dumps({'SpatialReference': reference})
57+
)
58+
59+
if t1w_mask:
60+
_create_deriv('sub-01_space-T1w_desc-brain_mask', 't1w')
61+
if t1w_aseg:
62+
_create_deriv('sub-01_space-T1w_desc-aseg_dseg', 't1w')
63+
if t2w_mask:
64+
_create_deriv('sub-01_space-T2w_desc-brain_mask', 't2w')
65+
if t2w_aseg:
66+
_create_deriv('sub-01_space-T2w_desc-aseg_dseg', 't2w')
67+
68+
69+
@pytest.mark.parametrize(
70+
't1w_mask,t1w_aseg,t2w_mask,t2w_aseg,mask,aseg',
71+
[
72+
(True, True, False, False, 't1w_mask', 't1w_aseg'),
73+
(True, True, True, True, 't1w_mask', 't1w_aseg'),
74+
(False, False, True, True, 't2w_mask', 't2w_aseg'),
75+
(True, False, False, True, 't1w_mask', 't2w_aseg'),
76+
(False, False, False, False, None, None),
77+
],
78+
)
79+
def test_derivatives(
80+
tmp_path: Path,
81+
t1w_mask: bool,
82+
t1w_aseg: bool,
83+
t2w_mask: bool,
84+
t2w_aseg: bool,
85+
mask: str | None,
86+
aseg: str | None,
87+
):
88+
bids_dir = tmp_path / 'bids'
89+
_create_bids_dir(bids_dir)
90+
deriv_dir = tmp_path / 'derivatives'
91+
_create_bids_derivs(
92+
deriv_dir, t1w_mask=t1w_mask, t1w_aseg=t1w_aseg, t2w_mask=t2w_mask, t2w_aseg=t2w_aseg
93+
)
94+
95+
derivatives = bids.Derivatives(bids_dir)
96+
assert derivatives.mask is None
97+
assert derivatives.t1w_mask is None
98+
assert derivatives.t2w_mask is None
99+
assert derivatives.aseg is None
100+
assert derivatives.t1w_aseg is None
101+
assert derivatives.t2w_aseg is None
102+
103+
derivatives.populate(deriv_dir, subject_id='01')
104+
if mask:
105+
assert derivatives.mask == getattr(derivatives, mask)
106+
assert derivatives.references[mask]
107+
else:
108+
assert derivatives.mask is None
109+
if aseg:
110+
assert derivatives.aseg == getattr(derivatives, aseg)
111+
assert derivatives.references[aseg]
112+
else:
113+
assert derivatives.aseg == None

nibabies/utils/validation.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)