Skip to content

Commit cc52a47

Browse files
authored
Merge pull request #418 from mgxd/fix/sloppy-template-res
FIX: Fetch templates during workflow construction
2 parents 4b6d38d + 9f13a86 commit cc52a47

File tree

3 files changed

+67
-43
lines changed

3 files changed

+67
-43
lines changed

smriprep/interfaces/templateflow.py

Lines changed: 49 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -108,34 +108,9 @@ def _run_interface(self, runtime):
108108
if isdefined(self.inputs.cohort):
109109
specs['cohort'] = self.inputs.cohort
110110

111-
name = self.inputs.template.strip(':').split(':', 1)
112-
if len(name) > 1:
113-
specs.update(
114-
{
115-
k: v
116-
for modifier in name[1].split(':')
117-
for k, v in [tuple(modifier.split('-'))]
118-
if k not in specs
119-
}
120-
)
121-
122-
if specs['resolution'] and not isinstance(specs['resolution'], list):
123-
specs['resolution'] = [specs['resolution']]
124-
125-
available_resolutions = tf.TF_LAYOUT.get_resolutions(template=name[0])
126-
if specs['resolution'] and not set(specs['resolution']) & set(available_resolutions):
127-
fallback_res = available_resolutions[0] if available_resolutions else None
128-
LOGGER.warning(
129-
f"Template {name[0]} does not have resolution(s): {specs['resolution']}."
130-
f"Falling back to resolution: {fallback_res}."
131-
)
132-
specs['resolution'] = fallback_res
133-
134-
self._results['t1w_file'] = tf.get(name[0], desc=None, suffix='T1w', **specs)
135-
136-
self._results['brain_mask'] = tf.get(
137-
name[0], desc='brain', suffix='mask', **specs
138-
) or tf.get(name[0], label='brain', suffix='mask', **specs)
111+
files = fetch_template_files(self.inputs.template, specs)
112+
self._results['t1w_file'] = files['t1w']
113+
self._results['brain_mask'] = files['mask']
139114
return runtime
140115

141116

@@ -186,3 +161,49 @@ def _run_interface(self, runtime):
186161
descsplit = desc.split('-')
187162
self._results['spec'][descsplit[0]] = descsplit[1]
188163
return runtime
164+
165+
166+
def fetch_template_files(
167+
template: str,
168+
specs: dict | None = None,
169+
sloppy: bool = False,
170+
) -> dict:
171+
if specs is None:
172+
specs = {}
173+
174+
name = template.strip(':').split(':', 1)
175+
if len(name) > 1:
176+
specs.update(
177+
{
178+
k: v
179+
for modifier in name[1].split(':')
180+
for k, v in [tuple(modifier.split('-'))]
181+
if k not in specs
182+
}
183+
)
184+
185+
if res := specs.pop('res', None):
186+
if res != 'native':
187+
specs['resolution'] = res
188+
189+
if not specs.get('resolution'):
190+
specs['resolution'] = 2 if sloppy else 1
191+
192+
if specs.get('resolution') and not isinstance(specs['resolution'], list):
193+
specs['resolution'] = [specs['resolution']]
194+
195+
available_resolutions = tf.TF_LAYOUT.get_resolutions(template=name[0])
196+
if specs.get('resolution') and not set(specs['resolution']) & set(available_resolutions):
197+
fallback_res = available_resolutions[0] if available_resolutions else None
198+
LOGGER.warning(
199+
f"Template {name[0]} does not have resolution(s): {specs['resolution']}."
200+
f"Falling back to resolution: {fallback_res}."
201+
)
202+
specs['resolution'] = fallback_res
203+
204+
files = {}
205+
files['t1w'] = tf.get(name[0], desc=None, suffix='T1w', **specs)
206+
files['mask'] = tf.get(name[0], desc='brain', suffix='mask', **specs) or tf.get(
207+
name[0], label='brain', suffix='mask', **specs
208+
)
209+
return files

smriprep/workflows/anatomical.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def init_anat_preproc_wf(
277277
omp_nthreads=omp_nthreads,
278278
skull_strip_fixed_seed=skull_strip_fixed_seed,
279279
)
280-
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
280+
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
281281
ds_std_volumes_wf = init_ds_anat_volumes_wf(
282282
bids_root=bids_root,
283283
output_dir=output_dir,
@@ -725,6 +725,7 @@ def init_anat_fit_wf(
725725
spaces=spaces,
726726
freesurfer=freesurfer,
727727
output_dir=output_dir,
728+
sloppy=sloppy,
728729
)
729730
# fmt:off
730731
workflow.connect([

smriprep/workflows/outputs.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,15 @@
3333
from niworkflows.interfaces.utility import KeySelect
3434

3535
from ..interfaces import DerivativesDataSink
36-
from ..interfaces.templateflow import TemplateFlowSelect
36+
from ..interfaces.templateflow import TemplateFlowSelect, fetch_template_files
37+
38+
if ty.TYPE_CHECKING:
39+
from niworkflows.utils.spaces import SpatialReferences
3740

3841
BIDS_TISSUE_ORDER = ('GM', 'WM', 'CSF')
3942

4043

41-
def init_anat_reports_wf(*, spaces, freesurfer, output_dir, name='anat_reports_wf'):
44+
def init_anat_reports_wf(*, spaces, freesurfer, output_dir, sloppy=False, name='anat_reports_wf'):
4245
"""
4346
Set up a battery of datasinks to store reports in the right location.
4447
@@ -131,7 +134,7 @@ def init_anat_reports_wf(*, spaces, freesurfer, output_dir, name='anat_reports_w
131134
# fmt:on
132135

133136
if spaces._cached is not None and spaces.cached.references:
134-
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
137+
template_iterator_wf = init_template_iterator_wf(spaces=spaces, sloppy=sloppy)
135138
t1w_std = pe.Node(
136139
ApplyTransforms(
137140
dimension=3,
@@ -1112,7 +1115,9 @@ def init_anat_second_derivatives_wf(
11121115
return workflow
11131116

11141117

1115-
def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
1118+
def init_template_iterator_wf(
1119+
*, spaces: 'SpatialReferences', sloppy: bool = False, name='template_iterator_wf'
1120+
):
11161121
"""Prepare the necessary components to resample an image to a template space
11171122
11181123
This produces a workflow with an unjoined iterable named "spacesource".
@@ -1122,6 +1127,9 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
11221127
11231128
The fields in `outputnode` can be used as if they come from a single template.
11241129
"""
1130+
for template in spaces.get_spaces(nonstandard=False, dim=(3,)):
1131+
fetch_template_files(template, specs=None, sloppy=sloppy)
1132+
11251133
workflow = pe.Workflow(name=name)
11261134

11271135
inputnode = pe.Node(
@@ -1159,9 +1167,7 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
11591167
name='select_xfm',
11601168
run_without_submitting=True,
11611169
)
1162-
select_tpl = pe.Node(
1163-
TemplateFlowSelect(resolution=1), name='select_tpl', run_without_submitting=True
1164-
)
1170+
select_tpl = pe.Node(TemplateFlowSelect(), name='select_tpl', run_without_submitting=True)
11651171

11661172
# fmt:off
11671173
workflow.connect([
@@ -1177,7 +1183,7 @@ def init_template_iterator_wf(*, spaces, name='template_iterator_wf'):
11771183
(spacesource, select_tpl, [
11781184
('space', 'template'),
11791185
('cohort', 'cohort'),
1180-
(('resolution', _no_native), 'resolution'),
1186+
(('resolution', _no_native, sloppy), 'resolution'),
11811187
]),
11821188
(spacesource, outputnode, [
11831189
('space', 'space'),
@@ -1243,10 +1249,6 @@ def _pick_cohort(in_template):
12431249
return [_pick_cohort(v) for v in in_template]
12441250

12451251

1246-
def _fmt(in_template):
1247-
return in_template.replace(':', '_')
1248-
1249-
12501252
def _empty_report(in_file=None):
12511253
from pathlib import Path
12521254

@@ -1268,11 +1270,11 @@ def _is_native(value):
12681270
return value == 'native'
12691271

12701272

1271-
def _no_native(value):
1273+
def _no_native(value, sloppy=False):
12721274
try:
12731275
return int(value)
12741276
except (TypeError, ValueError):
1275-
return 1
1277+
return 2 if sloppy else 1
12761278

12771279

12781280
def _drop_path(in_path):

0 commit comments

Comments
 (0)