|
1 | 1 | # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
|
2 | 2 | # vi: set ft=python sts=4 ts=4 sw=4 et:
|
3 | 3 | """Utilities to handle BIDS inputs."""
|
| 4 | +from __future__ import annotations |
| 5 | + |
4 | 6 | import json
|
5 | 7 | import os
|
6 | 8 | import sys
|
7 | 9 | import warnings
|
8 |
| -from dataclasses import dataclass, field |
9 | 10 | from pathlib import Path
|
10 |
| -from typing import IO, List, Literal, Optional, Union |
| 11 | +from typing import Literal, Optional, Union |
| 12 | + |
| 13 | +from bids.layout import BIDSLayout, Query |
| 14 | + |
| 15 | +_spec: dict = { |
| 16 | + 't1w_mask': { |
| 17 | + 'datatype': 'anat', |
| 18 | + 'desc': 'brain', |
| 19 | + 'space': 'T1w', |
| 20 | + 'suffix': 'mask', |
| 21 | + }, |
| 22 | + 't1w_aseg': {'datatype': 'anat', 'desc': 'aseg', 'space': 'T1w', 'suffix': 'dseg'}, |
| 23 | + 't2w_mask': { |
| 24 | + 'datatype': 'anat', |
| 25 | + 'desc': 'brain', |
| 26 | + 'space': 'T2w', |
| 27 | + 'suffix': 'mask', |
| 28 | + }, |
| 29 | + 't2w_aseg': { |
| 30 | + 'datatype': 'anat', |
| 31 | + 'desc': 'aseg', |
| 32 | + 'space': 'T2w', |
| 33 | + 'suffix': 'dseg', |
| 34 | + }, |
| 35 | +} |
| 36 | + |
| 37 | + |
| 38 | +class Derivatives: |
| 39 | + """A container class for storing precomputed derivatives.""" |
| 40 | + |
| 41 | + def __init__(self, spec: dict | Path | str | None = None, **args): |
| 42 | + self.spec = _spec |
| 43 | + if spec is not None: |
| 44 | + if not isinstance(spec, dict): |
| 45 | + spec: dict = json.loads(Path(spec).read_text()) |
| 46 | + self.spec = spec |
| 47 | + |
| 48 | + self.names = set(self.spec.keys()) |
| 49 | + self.references = {name: None for name in self.names} |
| 50 | + for name in self.names: |
| 51 | + setattr(self, name, None) |
| 52 | + |
| 53 | + def __repr__(self): |
| 54 | + return '\n'.join([name for name in self.names if getattr(self, name)]) |
| 55 | + |
| 56 | + def __contains__(self, val: str): |
| 57 | + return val in self.names |
| 58 | + |
| 59 | + def populate( |
| 60 | + self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL |
| 61 | + ): |
| 62 | + layout = BIDSLayout(deriv_path, validate=False) |
| 63 | + for name, query in self.spec.items(): |
| 64 | + items = layout.get( |
| 65 | + subject=subject_id, |
| 66 | + session=session_id, |
| 67 | + extension=['.nii', '.nii.gz'], |
| 68 | + **query, |
| 69 | + ) |
| 70 | + if not items or len(items) > 1: |
| 71 | + continue |
| 72 | + item = items[0] |
| 73 | + |
| 74 | + # Skip if derivative does not have valid metadata |
| 75 | + metadata = item.get_metadata() |
| 76 | + if not metadata or not (reference := metadata.get('SpatialReference')): |
| 77 | + # raise warning |
| 78 | + continue |
| 79 | + if isinstance(reference, list): |
| 80 | + if len(reference) > 1: |
| 81 | + # raise warning |
| 82 | + continue |
| 83 | + reference = reference[0] |
| 84 | + |
| 85 | + setattr(self, name, Path(item.path)) |
| 86 | + self.references[name] = (Path(deriv_path) / reference).absolute() |
11 | 87 |
|
12 | 88 |
|
13 | 89 | def write_bidsignore(deriv_dir):
|
@@ -203,50 +279,6 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
|
203 | 279 | print("bids-validator does not appear to be installed", file=sys.stderr)
|
204 | 280 |
|
205 | 281 |
|
206 |
| -def collect_precomputed_derivatives(layout, subject_id, derivatives_filters=None): |
207 |
| - """ |
208 |
| - Query and collect precomputed derivatives. |
209 |
| -
|
210 |
| - This function is used to determine which workflow steps can be skipped, |
211 |
| - based on the files found. |
212 |
| - """ |
213 |
| - |
214 |
| - deriv_queries = { |
215 |
| - 'anat_mask': { |
216 |
| - 'datatype': 'anat', |
217 |
| - 'desc': 'brain', |
218 |
| - 'space': 'orig', |
219 |
| - 'suffix': 'mask', |
220 |
| - }, |
221 |
| - 'anat_aseg': { |
222 |
| - 'datatype': 'anat', |
223 |
| - 'desc': 'aseg', |
224 |
| - 'space': 'orig', |
225 |
| - 'suffix': 'dseg', |
226 |
| - }, |
227 |
| - } |
228 |
| - if derivatives_filters is not None: |
229 |
| - deriv_queries.update(derivatives_filters) |
230 |
| - |
231 |
| - derivatives = {} |
232 |
| - for deriv, query in deriv_queries.items(): |
233 |
| - res = layout.get( |
234 |
| - scope='derivatives', |
235 |
| - subject=subject_id, |
236 |
| - extension=['.nii', '.nii.gz'], |
237 |
| - return_type="filename", |
238 |
| - **query, |
239 |
| - ) |
240 |
| - if not res: |
241 |
| - continue |
242 |
| - if len(res) > 1: # Some queries may want multiple results |
243 |
| - raise Exception( |
244 |
| - f"When searching for <{deriv}>, found multiple results: {[f.path for f in res]}" |
245 |
| - ) |
246 |
| - derivatives[deriv] = res[0] |
247 |
| - return derivatives |
248 |
| - |
249 |
| - |
250 | 282 | def parse_bids_for_age_months(
|
251 | 283 | bids_root: Union[str, Path],
|
252 | 284 | subject_id: str,
|
|
0 commit comments