Skip to content

Commit 83eae58

Browse files
committed
RF: Create dedicated container class for derivatives
1 parent 1d90499 commit 83eae58

File tree

1 file changed

+78
-46
lines changed

1 file changed

+78
-46
lines changed

nibabies/utils/bids.py

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,89 @@
11
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
22
# vi: set ft=python sts=4 ts=4 sw=4 et:
33
"""Utilities to handle BIDS inputs."""
4+
from __future__ import annotations
5+
46
import json
57
import os
68
import sys
79
import warnings
8-
from dataclasses import dataclass, field
910
from pathlib import Path
10-
from typing import IO, List, Literal, Optional, Union
11+
from typing import Literal, Optional, Union
12+
13+
from bids.layout import BIDSLayout, Query
14+
15+
_spec: dict = {
16+
't1w_mask': {
17+
'datatype': 'anat',
18+
'desc': 'brain',
19+
'space': 'T1w',
20+
'suffix': 'mask',
21+
},
22+
't1w_aseg': {'datatype': 'anat', 'desc': 'aseg', 'space': 'T1w', 'suffix': 'dseg'},
23+
't2w_mask': {
24+
'datatype': 'anat',
25+
'desc': 'brain',
26+
'space': 'T2w',
27+
'suffix': 'mask',
28+
},
29+
't2w_aseg': {
30+
'datatype': 'anat',
31+
'desc': 'aseg',
32+
'space': 'T2w',
33+
'suffix': 'dseg',
34+
},
35+
}
36+
37+
38+
class Derivatives:
39+
"""A container class for storing precomputed derivatives."""
40+
41+
def __init__(self, spec: dict | Path | str | None = None, **args):
42+
self.spec = _spec
43+
if spec is not None:
44+
if not isinstance(spec, dict):
45+
spec: dict = json.loads(Path(spec).read_text())
46+
self.spec = spec
47+
48+
self.names = set(self.spec.keys())
49+
self.references = {name: None for name in self.names}
50+
for name in self.names:
51+
setattr(self, name, None)
52+
53+
def __repr__(self):
54+
return '\n'.join([name for name in self.names if getattr(self, name)])
55+
56+
def __contains__(self, val: str):
57+
return val in self.names
58+
59+
def populate(
60+
self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL
61+
):
62+
layout = BIDSLayout(deriv_path, validate=False)
63+
for name, query in self.spec.items():
64+
items = layout.get(
65+
subject=subject_id,
66+
session=session_id,
67+
extension=['.nii', '.nii.gz'],
68+
**query,
69+
)
70+
if not items or len(items) > 1:
71+
continue
72+
item = items[0]
73+
74+
# Skip if derivative does not have valid metadata
75+
metadata = item.get_metadata()
76+
if not metadata or not (reference := metadata.get('SpatialReference')):
77+
# raise warning
78+
continue
79+
if isinstance(reference, list):
80+
if len(reference) > 1:
81+
# raise warning
82+
continue
83+
reference = reference[0]
84+
85+
setattr(self, name, Path(item.path))
86+
self.references[name] = (Path(deriv_path) / reference).absolute()
1187

1288

1389
def write_bidsignore(deriv_dir):
@@ -203,50 +279,6 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
203279
print("bids-validator does not appear to be installed", file=sys.stderr)
204280

205281

206-
def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None):
207-
"""
208-
Query and collect precomputed derivatives.
209-
210-
This function is used to determine which workflow steps can be skipped,
211-
based on the files found.
212-
"""
213-
214-
deriv_queries = {
215-
'anat_mask': {
216-
'datatype': 'anat',
217-
'desc': 'brain',
218-
'space': 'orig',
219-
'suffix': 'mask',
220-
},
221-
'anat_aseg': {
222-
'datatype': 'anat',
223-
'desc': 'aseg',
224-
'space': 'orig',
225-
'suffix': 'dseg',
226-
},
227-
}
228-
if derivatives_filters is not None:
229-
deriv_queries.update(derivatives_filters)
230-
231-
derivatives = {}
232-
for deriv, query in deriv_queries.items():
233-
res = layout.get(
234-
scope='derivatives',
235-
subject=subject_id,
236-
extension=['.nii', '.nii.gz'],
237-
return_type="filename",
238-
**query,
239-
)
240-
if not res:
241-
continue
242-
if len(res) > 1: # Some queries may want multiple results
243-
raise Exception(
244-
f"When searching for <{deriv}>, found multiple results: {[f.path for f in res]}"
245-
)
246-
derivatives[deriv] = res[0]
247-
return derivatives
248-
249-
250282
def parse_bids_for_age_months(
251283
bids_root: Union[str, Path],
252284
subject_id: str,

0 commit comments

Comments
 (0)