1
1
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2
2
# vi: set ft=python sts=4 ts=4 sw=4 et:
3
3
"""Utilities to handle BIDS inputs."""
4
+ from __future__ import annotations
5
+
4
6
import json
5
7
import os
6
8
import sys
9
+ import typing as ty
7
10
import warnings
8
- from dataclasses import dataclass , field
9
11
from pathlib import Path
10
- from typing import IO , List , Literal , Optional , Union
11
12
13
+ import nibabel as nb
14
+ import numpy as np
15
+ from bids .layout import BIDSLayout , Query
16
+
17
+ _spec : dict = {
18
+ 't1w_mask' : {
19
+ 'datatype' : 'anat' ,
20
+ 'desc' : 'brain' ,
21
+ 'space' : 'T1w' ,
22
+ 'suffix' : 'mask' ,
23
+ },
24
+ 't1w_aseg' : {'datatype' : 'anat' , 'desc' : 'aseg' , 'space' : 'T1w' , 'suffix' : 'dseg' },
25
+ 't2w_mask' : {
26
+ 'datatype' : 'anat' ,
27
+ 'desc' : 'brain' ,
28
+ 'space' : 'T2w' ,
29
+ 'suffix' : 'mask' ,
30
+ },
31
+ 't2w_aseg' : {
32
+ 'datatype' : 'anat' ,
33
+ 'desc' : 'aseg' ,
34
+ 'space' : 'T2w' ,
35
+ 'suffix' : 'dseg' ,
36
+ },
37
+ }
38
+
39
+
40
+ class Derivatives :
41
+ """
42
+ A container class for collecting and storing derivatives.
43
+
44
+ A specification (either dictionary or JSON file) can be used to customize derivatives and
45
+ queries.
46
+ To populate this class with derivatives, the `populate()` method must first be called.
47
+ """
12
48
13
- @dataclass
14
- class BOLDGrouping :
15
- """This class is used to facilitate the grouping of BOLD series."""
49
+ def __getattribute__ (self , attr ):
50
+ """In cases where the spec may change, avoid errors."""
51
+ try :
52
+ return object .__getattribute__ (self , attr )
53
+ except AttributeError :
54
+ return None
55
+
56
+ def __init__ (self , bids_root : Path | str , spec : dict | Path | str | None = None , ** args ):
57
+ self .bids_root = Path (bids_root )
58
+ self .spec = _spec
59
+ if spec is not None :
60
+ if not isinstance (spec , dict ):
61
+ spec : dict = json .loads (Path (spec ).read_text ())
62
+ self .spec = spec
63
+
64
+ self .names = set (self .spec .keys ())
65
+ self .references = {name : None for name in self .names }
66
+ for name in self .names :
67
+ setattr (self , name , None )
68
+
69
+ def __repr__ (self ):
70
+ return '\n ' .join ([name for name in self .names if getattr (self , name )])
71
+
72
+ def __contains__ (self , val : str ):
73
+ return val in self .names
74
+
75
+ def __bool__ (self ):
76
+ return any (getattr (self , name ) for name in self .names )
77
+
78
+ def populate (
79
+ self , deriv_path , subject_id : str , session_id : str | Query | None = Query .OPTIONAL
80
+ ) -> None :
81
+ """Query a derivatives directory and populate values and references based on the spec."""
82
+ layout = BIDSLayout (deriv_path , validate = False )
83
+ for name , query in self .spec .items ():
84
+ items = layout .get (
85
+ subject = subject_id ,
86
+ session = session_id ,
87
+ extension = ['.nii' , '.nii.gz' ],
88
+ ** query ,
89
+ )
90
+ if not items or len (items ) > 1 :
91
+ warnings .warn (f"Could not find { name } " )
92
+ continue
93
+ item = items [0 ]
94
+
95
+ # Skip if derivative does not have valid metadata
96
+ metadata = item .get_metadata ()
97
+ if not metadata or not (reference := metadata .get ('SpatialReference' )):
98
+ warnings .warn (f"No metadata found for { item } " )
99
+ continue
100
+ if isinstance (reference , list ):
101
+ if len (reference ) > 1 :
102
+ warnings .warn (f"Multiple reference found: { reference } " )
103
+ continue
104
+ reference = reference [0 ]
105
+
106
+ reference = self .bids_root / reference
107
+ if not self .validate (item .path , str (reference )):
108
+ warnings .warn (f"Validation failed between: { item .path } and { reference } " )
109
+ continue
110
+
111
+ setattr (self , name , Path (item .path ))
112
+ self .references [name ] = reference
16
113
17
- session : Union [str , None ]
18
- pe_dir : str
19
- readout : float
20
- multiecho_id : str = None
21
- files : List [IO ] = field (default_factory = list )
114
+ @property
115
+ def mask (self ) -> str | None :
116
+ return self .t1w_mask or self .t2w_mask
22
117
23
118
@property
24
- def name (self ) -> str :
25
- return f" { self .session } - { self .pe_dir } - { self . readout } - { self . multiecho_id } "
119
+ def aseg (self ) -> str | None :
120
+ return self .t1w_aseg or self .t2w_aseg
26
121
27
- def add_file (self , fl ) -> None :
28
- self .files .append (fl )
122
+ @staticmethod
123
+ def validate (derivative : str , reference : str , atol : float = 1e-5 ) -> bool :
124
+ anat = nb .load (reference )
125
+ expected_ort = nb .aff2axcodes (anat .affine )
126
+ img = nb .load (derivative )
127
+ if nb .aff2axcodes (img .affine ) != expected_ort :
128
+ return False
129
+ if img .shape != anat .shape or not np .allclose (anat .affine , img .affine , atol = atol ):
130
+ return False
131
+ return True
29
132
30
133
31
134
def write_bidsignore (deriv_dir ):
@@ -221,55 +324,11 @@ def validate_input_dir(exec_env, bids_dir, participant_label):
221
324
print ("bids-validator does not appear to be installed" , file = sys .stderr )
222
325
223
326
224
- def collect_precomputed_derivatives (layout , subject_id , derivatives_filters = None ):
225
- """
226
- Query and collect precomputed derivatives.
227
-
228
- This function is used to determine which workflow steps can be skipped,
229
- based on the files found.
230
- """
231
-
232
- deriv_queries = {
233
- 'anat_mask' : {
234
- 'datatype' : 'anat' ,
235
- 'desc' : 'brain' ,
236
- 'space' : 'orig' ,
237
- 'suffix' : 'mask' ,
238
- },
239
- 'anat_aseg' : {
240
- 'datatype' : 'anat' ,
241
- 'desc' : 'aseg' ,
242
- 'space' : 'orig' ,
243
- 'suffix' : 'dseg' ,
244
- },
245
- }
246
- if derivatives_filters is not None :
247
- deriv_queries .update (derivatives_filters )
248
-
249
- derivatives = {}
250
- for deriv , query in deriv_queries .items ():
251
- res = layout .get (
252
- scope = 'derivatives' ,
253
- subject = subject_id ,
254
- extension = ['.nii' , '.nii.gz' ],
255
- return_type = "filename" ,
256
- ** query ,
257
- )
258
- if not res :
259
- continue
260
- if len (res ) > 1 : # Some queries may want multiple results
261
- raise Exception (
262
- f"When searching for <{ deriv } >, found multiple results: { [f .path for f in res ]} "
263
- )
264
- derivatives [deriv ] = res [0 ]
265
- return derivatives
266
-
267
-
268
327
def parse_bids_for_age_months (
269
- bids_root : Union [ str , Path ] ,
328
+ bids_root : str | Path ,
270
329
subject_id : str ,
271
- session_id : Optional [ str ] = None ,
272
- ) -> Optional [ int ] :
330
+ session_id : str | None = None ,
331
+ ) -> int | None :
273
332
"""
274
333
Given a BIDS root, query the BIDS metadata files for participant age, in months.
275
334
@@ -295,8 +354,8 @@ def parse_bids_for_age_months(
295
354
296
355
297
356
def _get_age_from_tsv (
298
- bids_tsv : Path , level : Literal ['session' , 'participant' ], key : str
299
- ) -> Optional [ int ] :
357
+ bids_tsv : Path , level : ty . Literal ['session' , 'participant' ], key : str
358
+ ) -> int | None :
300
359
import pandas as pd
301
360
302
361
df = pd .read_csv (str (bids_tsv ), sep = '\t ' )
0 commit comments