Skip to content

Commit 60f26d2

Browse files
authored
Merge pull request #450 from nipreps/fix/aparc-select
FIX: Select function in segmentation resampling workflow
2 parents f368a28 + 4738420 commit 60f26d2

File tree

2 files changed

+35
-22
lines changed

2 files changed

+35
-22
lines changed

smriprep/workflows/surfaces.py

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,11 +1170,11 @@ def init_hcp_morphometrics_wf(
11701170
def init_segs_to_native_wf(
11711171
*,
11721172
image_type: ty.Literal['T1w', 'T2w'] = 'T1w',
1173-
segmentation: ty.Literal['aseg', 'aparc_aseg', 'wmparc'] = 'aseg',
1173+
segmentation: ty.Literal['aseg', 'aparc_aseg', 'aparc_a2009s', 'aparc_dkt'] | str = 'aseg',
11741174
name: str = 'segs_to_native_wf',
11751175
) -> Workflow:
11761176
"""
1177-
Get a segmentation from FreeSurfer conformed space into native T1w space.
1177+
Get a segmentation from FreeSurfer conformed space into native anatomical space.
11781178
11791179
Workflow Graph
11801180
.. workflow::
@@ -1219,30 +1219,15 @@ def init_segs_to_native_wf(
12191219

12201220
lta = pe.Node(ConcatenateXFMs(out_fmt='fs'), name='lta', run_without_submitting=True)
12211221

1222-
# Resample from T1.mgz to T1w.nii.gz, applying any offset in fsnative2anat_xfm,
1222+
# Resample from Freesurfer anat to native anat, applying any offset in fsnative2anat_xfm,
12231223
# and convert to NIfTI while we're at it
12241224
resample = pe.Node(
12251225
fs.ApplyVolTransform(transformed_file='seg.nii.gz', interp='nearest'),
12261226
name='resample',
12271227
)
12281228

1229-
if segmentation.startswith('aparc'):
1230-
if segmentation == 'aparc_aseg':
1231-
1232-
def _sel(x):
1233-
return [parc for parc in x if 'aparc+' in parc][0] # noqa
1234-
1235-
elif segmentation == 'aparc_a2009s':
1236-
1237-
def _sel(x):
1238-
return [parc for parc in x if 'a2009s+' in parc][0] # noqa
1239-
1240-
elif segmentation == 'aparc_dkt':
1241-
1242-
def _sel(x):
1243-
return [parc for parc in x if 'DKTatlas+' in parc][0] # noqa
1244-
1245-
segmentation = (segmentation, _sel)
1229+
select_seg = pe.Node(niu.Function(function=_select_seg), name='select_seg')
1230+
select_seg.inputs.segmentation = segmentation
12461231

12471232
anat = 'T2' if image_type == 'T2w' else 'T1'
12481233

@@ -1254,7 +1239,8 @@ def _sel(x):
12541239
('fsnative2anat_xfm', 'in_xfms')]),
12551240
(fssource, lta, [(anat, 'moving')]),
12561241
(inputnode, resample, [('in_file', 'target_file')]),
1257-
(fssource, resample, [(segmentation, 'source_file')]),
1242+
(fssource, select_seg, [(segmentation, 'in_files')]),
1243+
(select_seg, resample, [('out', 'source_file')]),
12581244
(lta, resample, [('out_xfm', 'lta_file')]),
12591245
(resample, outputnode, [('transformed_file', 'out_file')]),
12601246
]) # fmt:skip
@@ -1678,3 +1664,17 @@ def _get_surfaces(subjects_dir: str, subject_id: str, surfaces: list[str]) -> tu
16781664

16791665
ret = tuple(all_surfs[surface] for surface in surfaces)
16801666
return ret if len(ret) > 1 else ret[0]
1667+
1668+
1669+
def _select_seg(in_files, segmentation):
1670+
if isinstance(in_files, str):
1671+
return in_files
1672+
1673+
seg_mapping = {'aparc_aseg': 'aparc+', 'aparc_a2009s': 'a2009s+', 'aparc_dkt': 'DKTatlas+'}
1674+
if segmentation in seg_mapping:
1675+
segmentation = seg_mapping[segmentation]
1676+
1677+
for fl in in_files:
1678+
if segmentation in fl:
1679+
return fl
1680+
raise FileNotFoundError(f'No segmentation containing "{segmentation}" was found.')

smriprep/workflows/tests/test_surfaces.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from smriprep.interfaces.tests.data import load as load_test_data
1111

12-
from ..surfaces import init_anat_ribbon_wf, init_gifti_surfaces_wf
12+
from ..surfaces import _select_seg, init_anat_ribbon_wf, init_gifti_surfaces_wf
1313

1414

1515
def test_ribbon_workflow(tmp_path: Path):
@@ -53,3 +53,16 @@ def test_ribbon_workflow(tmp_path: Path):
5353
assert np.allclose(ribbon.affine, expected.affine)
5454
# Mask data is binary, so we can use np.array_equal
5555
assert np.array_equal(ribbon.dataobj, expected.dataobj)
56+
57+
58+
@pytest.mark.parametrize(
59+
('in_files', 'segmentation', 'expected'),
60+
[
61+
('aparc+aseg.mgz', 'aparc_aseg', 'aparc+aseg.mgz'),
62+
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_aseg', 'aparc+aseg.mgz'),
63+
(['a2009s+aseg.mgz', 'aparc+aseg.mgz'], 'aparc_a2009s', 'a2009s+aseg.mgz'),
64+
('wmparc.mgz', 'wmparc.mgz', 'wmparc.mgz'),
65+
],
66+
)
67+
def test_select_seg(in_files, segmentation, expected):
68+
assert _select_seg(in_files, segmentation) == expected

0 commit comments

Comments
 (0)