6
6
import json
7
7
import os
8
8
import sys
9
+ import typing as ty
9
10
import warnings
10
11
from pathlib import Path
11
- from typing import Literal , Optional , Union
12
12
13
+ import nibabel as nb
14
+ import numpy as np
13
15
from bids .layout import BIDSLayout , Query
14
16
15
17
_spec : dict = {
38
40
class Derivatives :
39
41
"""A container class for storing precomputed derivatives."""
40
42
43
+ def __getattribute__ (self , attr ):
44
+ """In cases where the spec may change, avoid errors."""
45
+ try :
46
+ return object .__getattribute__ (self , attr )
47
+ except AttributeError :
48
+ return None
49
+
41
50
def __init__ (self , spec : dict | Path | str | None = None , ** args ):
42
51
self .spec = _spec
43
52
if spec is not None :
@@ -56,9 +65,13 @@ def __repr__(self):
56
65
def __contains__ (self , val : str ):
57
66
return val in self .names
58
67
68
+ def __bool__ (self ):
69
+ return any (getattr (self , name ) for name in self .names )
70
+
59
71
def populate (
60
72
self , deriv_path , subject_id : str , session_id : str | Query | None = Query .OPTIONAL
61
- ):
73
+ ) -> None :
74
+ """Query a derivatives directory and populate values and references based on the spec."""
62
75
layout = BIDSLayout (deriv_path , validate = False )
63
76
for name , query in self .spec .items ():
64
77
items = layout .get (
@@ -82,8 +95,32 @@ def populate(
82
95
continue
83
96
reference = reference [0 ]
84
97
98
+ reference = (Path (deriv_path ) / reference ).absolute ()
99
+ if not self .validate (item .path , str (reference )):
100
+ # raise warning
101
+ continue
102
+
85
103
setattr (self , name , Path (item .path ))
86
- self .references [name ] = (Path (deriv_path ) / reference ).absolute ()
104
+ self .references [name ] = reference
105
+
106
+ @property
107
+ def mask (self ) -> str | None :
108
+ return self .t1w_mask or self .t2w_mask
109
+
110
+ @property
111
+ def aseg (self ) -> str | None :
112
+ return self .t1w_aseg or self .t2w_aseg
113
+
114
+ @staticmethod
115
+ def validate (derivative : str , reference : str , atol : float = 1e-5 ) -> bool :
116
+ anat = nb .load (reference )
117
+ expected_ort = nb .aff2axcodes (anat .affine )
118
+ img = nb .load (derivative )
119
+ if nb .aff2axcodes (img .affine ) != expected_ort :
120
+ return False
121
+ if img .shape != anat .shape or not np .allclose (anat .affine , img .affine , atol = atol ):
122
+ return False
123
+ return True
87
124
88
125
89
126
def write_bidsignore (deriv_dir ):
@@ -280,10 +317,10 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
280
317
281
318
282
319
def parse_bids_for_age_months (
283
- bids_root : Union [ str , Path ] ,
320
+ bids_root : str | Path ,
284
321
subject_id : str ,
285
- session_id : Optional [ str ] = None ,
286
- ) -> Optional [ int ] :
322
+ session_id : str | None = None ,
323
+ ) -> int | None :
287
324
"""
288
325
Given a BIDS root, query the BIDS metadata files for participant age, in months.
289
326
@@ -309,8 +346,8 @@ def parse_bids_for_age_months(
309
346
310
347
311
348
def _get_age_from_tsv (
312
- bids_tsv : Path , level : Literal ['session' , 'participant' ], key : str
313
- ) -> Optional [ int ] :
349
+ bids_tsv : Path , level : ty . Literal ['session' , 'participant' ], key : str
350
+ ) -> int | None :
314
351
import pandas as pd
315
352
316
353
df = pd .read_csv (str (bids_tsv ), sep = '\t ' )
0 commit comments