Skip to content

Commit 9734973

Browse files
committed
RF: Clean up derivatives class, move validation over
1 parent 37bea41 commit 9734973

File tree

2 files changed

+45
-61
lines changed

2 files changed

+45
-61
lines changed

nibabies/utils/bids.py

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66
import json
77
import os
88
import sys
9+
import typing as ty
910
import warnings
1011
from pathlib import Path
11-
from typing import Literal, Optional, Union
1212

13+
import nibabel as nb
14+
import numpy as np
1315
from bids.layout import BIDSLayout, Query
1416

1517
_spec: dict = {
@@ -38,6 +40,13 @@
3840
class Derivatives:
3941
"""A container class for storing precomputed derivatives."""
4042

43+
def __getattribute__(self, attr):
44+
"""In cases where the spec may change, avoid errors."""
45+
try:
46+
return object.__getattribute__(self, attr)
47+
except AttributeError:
48+
return None
49+
4150
def __init__(self, spec: dict | Path | str | None = None, **args):
4251
self.spec = _spec
4352
if spec is not None:
@@ -56,9 +65,13 @@ def __repr__(self):
5665
def __contains__(self, val: str):
5766
return val in self.names
5867

68+
def __bool__(self):
69+
return any(getattr(self, name) for name in self.names)
70+
5971
def populate(
6072
self, deriv_path, subject_id: str, session_id: str | Query | None = Query.OPTIONAL
61-
):
73+
) -> None:
74+
"""Query a derivatives directory and populate values and references based on the spec."""
6275
layout = BIDSLayout(deriv_path, validate=False)
6376
for name, query in self.spec.items():
6477
items = layout.get(
@@ -82,8 +95,32 @@ def populate(
8295
continue
8396
reference = reference[0]
8497

98+
reference = (Path(deriv_path) / reference).absolute()
99+
if not self.validate(item.path, str(reference)):
100+
# raise warning
101+
continue
102+
85103
setattr(self, name, Path(item.path))
86-
self.references[name] = (Path(deriv_path) / reference).absolute()
104+
self.references[name] = reference
105+
106+
@property
107+
def mask(self) -> str | None:
108+
return self.t1w_mask or self.t2w_mask
109+
110+
@property
111+
def aseg(self) -> str | None:
112+
return self.t1w_aseg or self.t2w_aseg
113+
114+
@staticmethod
115+
def validate(derivative: str, reference: str, atol: float = 1e-5) -> bool:
116+
anat = nb.load(reference)
117+
expected_ort = nb.aff2axcodes(anat.affine)
118+
img = nb.load(derivative)
119+
if nb.aff2axcodes(img.affine) != expected_ort:
120+
return False
121+
if img.shape != anat.shape or not np.allclose(anat.affine, img.affine, atol=atol):
122+
return False
123+
return True
87124

88125

89126
def write_bidsignore(deriv_dir):
@@ -280,10 +317,10 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
280317

281318

282319
def parse_bids_for_age_months(
283-
bids_root: Union[str, Path],
320+
bids_root: str | Path,
284321
subject_id: str,
285-
session_id: Optional[str] = None,
286-
) -> Optional[int]:
322+
session_id: str | None = None,
323+
) -> int | None:
287324
"""
288325
Given a BIDS root, query the BIDS metadata files for participant age, in months.
289326
@@ -309,8 +346,8 @@ def parse_bids_for_age_months(
309346

310347

311348
def _get_age_from_tsv(
312-
bids_tsv: Path, level: Literal['session', 'participant'], key: str
313-
) -> Optional[int]:
349+
bids_tsv: Path, level: ty.Literal['session', 'participant'], key: str
350+
) -> int | None:
314351
import pandas as pd
315352

316353
df = pd.read_csv(str(bids_tsv), sep='\t')

nibabies/utils/validation.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

0 commit comments

Comments
 (0)