Skip to content

Commit 66e181a

Browse files
committed
Simplify base workflow.
1 parent 1e88f27 commit 66e181a

File tree

3 files changed

+183
-83
lines changed

3 files changed

+183
-83
lines changed

src/smriprep/workflows/anatomical.py

Lines changed: 22 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
init_ds_fs_segs_wf,
7070
init_ds_grayord_metrics_wf,
7171
init_ds_mask_wf,
72+
init_ds_surface_masks_wf,
7273
init_ds_surface_metrics_wf,
7374
init_ds_surfaces_wf,
7475
init_ds_template_registration_wf,
@@ -78,7 +79,7 @@
7879
)
7980
from .surfaces import (
8081
init_anat_ribbon_wf,
81-
init_cortex_mask_wf,
82+
init_cortex_masks_wf,
8283
init_fsLR_reg_wf,
8384
init_gifti_morphometrics_wf,
8485
init_gifti_surfaces_wf,
@@ -1350,57 +1351,27 @@ def init_anat_fit_wf(
13501351
if len(precomputed.get('cortex_mask', [])) < 2:
13511352
LOGGER.info('ANAT Stage 11: Creating cortical surface mask')
13521353

1353-
# Merge outputs into a single list
1354-
merge_cortex_masks = pe.Node(
1355-
niu.Merge(2),
1356-
name='merge_cortex_masks',
1357-
)
1358-
workflow.connect([(merge_cortex_masks, outputnode, [('out', 'cortex_mask')])])
1359-
1360-
for i_hemi, hemi in enumerate(['L', 'R']):
1361-
select_midthickness = pe.Node(
1362-
niu.Select(index=i_hemi),
1363-
name=f'select_midthickness_{hemi}',
1364-
)
1365-
select_thickness = pe.Node(
1366-
niu.Select(index=i_hemi),
1367-
name=f'select_thickness_{hemi}',
1368-
)
1369-
cortex_mask_wf = init_cortex_mask_wf(name=f'cortex_mask_wf_{hemi}')
1370-
cortex_mask_wf.inputs.inputnode.hemi = hemi
1371-
1372-
workflow.connect([
1373-
(surfaces_buffer, select_midthickness, [('midthickness', 'inlist')]),
1374-
(surfaces_buffer, select_thickness, [('thickness', 'inlist')]),
1375-
(select_midthickness, cortex_mask_wf, [('out', 'inputnode.midthickness')]),
1376-
(select_thickness, cortex_mask_wf, [('out', 'inputnode.thickness')]),
1377-
]) # fmt:skip
1378-
1379-
# Combine the inputs into a list
1380-
combine_inputs = pe.Node(
1381-
niu.Merge(2),
1382-
name=f'combine_inputs_{hemi}',
1383-
)
1384-
workflow.connect([
1385-
(select_midthickness, combine_inputs, [('out', 'in1')]),
1386-
(select_thickness, combine_inputs, [('out', 'in2')]),
1387-
]) # fmt:skip
1388-
1389-
ds_cortex_mask_wf = init_ds_mask_wf(
1390-
bids_root=bids_root,
1391-
output_dir=output_dir,
1392-
mask_type='cortex',
1393-
name=f'ds_cortex_mask_wf_{hemi}',
1394-
extra_entities={'extension': '.label.gii', 'hemi': hemi},
1395-
)
1354+
cortex_masks_wf = init_cortex_masks_wf()
1355+
workflow.connect([
1356+
(surfaces_buffer, cortex_masks_wf, [
1357+
('midthickness', 'inputnode.midthickness'),
1358+
('thickness', 'inputnode.thickness'),
1359+
]),
1360+
]) # fmt:skip
13961361

1397-
workflow.connect([
1398-
(cortex_mask_wf, ds_cortex_mask_wf, [('outputnode.roi', 'inputnode.mask_file')]),
1399-
(combine_inputs, ds_cortex_mask_wf, [('out', 'inputnode.source_files')]),
1400-
(ds_cortex_mask_wf, merge_cortex_masks, [
1401-
('outputnode.mask_file', f'in{i_hemi + 1}'),
1402-
]),
1403-
]) # fmt:skip
1362+
ds_cortex_masks_wf = init_ds_surface_masks_wf(
1363+
output_dir=output_dir,
1364+
mask_type='cortex',
1365+
name='ds_cortex_masks_wf',
1366+
entities={'extension': '.label.gii'},
1367+
)
1368+
workflow.connect([
1369+
(cortex_masks_wf, ds_cortex_masks_wf, [
1370+
('outputnode.cortex_masks', 'inputnode.mask_files'),
1371+
('outputnode.source_files', 'inputnode.source_files'),
1372+
]),
1373+
(ds_cortex_masks_wf, outputnode, [('outputnode.mask_files', 'cortex_mask')]),
1374+
]) # fmt:skip
14041375
else:
14051376
LOGGER.info('ANAT Stage 11: Found pre-computed cortical surface mask')
14061377
outputnode.inputs.cortex_mask = sorted(precomputed['cortex_mask'])

src/smriprep/workflows/outputs.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,85 @@ def init_template_iterator_wf(
12311231
return workflow
12321232

12331233

1234+
def init_ds_surface_masks_wf(
1235+
*,
1236+
output_dir: str,
1237+
mask_type: ty.Literal['cortex', 'roi', 'ribbon', 'brain'],
1238+
entities: dict[str, str] | None = None,
1239+
name='ds_surface_masks_wf',
1240+
) -> Workflow:
1241+
"""Save GIFTI surface masks.
1242+
1243+
Parameters
1244+
----------
1245+
output_dir : :class:`str`
1246+
Directory in which to save derivatives
1247+
mask_type : :class:`str`
1248+
Type of mask to save
1249+
entities : :class:`dict` of :class:`str`
1250+
Entities to include in outputs
1251+
name : :class:`str`
1252+
Workflow name (default: ds_surface_masks_wf)
1253+
1254+
Inputs
1255+
------
1256+
source_files : list of lists of str
1257+
List of lists of source files.
1258+
Left hemisphere sources first, then right hemisphere sources.
1259+
mask_files : list of str
1260+
List of input mask files.
1261+
Left hemisphere mask first, then right hemisphere mask.
1262+
1263+
Outputs
1264+
-------
1265+
mask_files : list of str
1266+
List of output mask files.
1267+
Left hemisphere mask first, then right hemisphere mask.
1268+
"""
1269+
workflow = Workflow(name=name)
1270+
1271+
if entities is None:
1272+
entities = {}
1273+
1274+
inputnode = pe.Node(
1275+
niu.IdentityInterface(fields=['mask_files', 'source_files']),
1276+
name='inputnode',
1277+
)
1278+
outputnode = pe.Node(niu.IdentityInterface(fields=['mask_files']), name='outputnode')
1279+
1280+
sources = pe.MapNode(
1281+
niu.Function(function=_bids_relative),
1282+
name='sources',
1283+
iterfield='in_files',
1284+
)
1285+
sources.inputs.bids_root = output_dir
1286+
1287+
ds_mask = pe.MapNode(
1288+
DerivativesDataSink(
1289+
base_directory=output_dir,
1290+
hemi=['L', 'R'],
1291+
desc=mask_type,
1292+
**entities,
1293+
),
1294+
iterfield=('in_file', 'hemi', 'source_file'),
1295+
name='ds_mask',
1296+
run_without_submitting=True,
1297+
)
1298+
if mask_type == 'brain':
1299+
ds_mask.inputs.Type = 'Brain'
1300+
else:
1301+
ds_mask.inputs.Type = 'ROI'
1302+
1303+
workflow.connect([
1304+
(inputnode, ds_mask, [('mask_files', 'in_file')]),
1305+
(inputnode, sources, [('source_files', 'in_files')]),
1306+
(sources, ds_mask, [('out', 'source_file')]),
1307+
(ds_mask, outputnode, [('out_file', 'mask_files')]),
1308+
]) # fmt:skip
1309+
1310+
return workflow
1311+
1312+
12341313
def _bids_relative(in_files, bids_root):
12351314
from pathlib import Path
12361315

src/smriprep/workflows/surfaces.py

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,9 +1156,9 @@ def init_hcp_morphometrics_wf(
11561156
return workflow
11571157

11581158

1159-
def init_cortex_mask_wf(
1159+
def init_cortex_masks_wf(
11601160
*,
1161-
name: str = 'cortex_mask_wf',
1161+
name: str = 'cortex_masks_wf',
11621162
):
11631163
"""Create a cortical surface mask from a surface file.
11641164
@@ -1167,55 +1167,100 @@ def init_cortex_mask_wf(
11671167
:graph2use: orig
11681168
:simple_form: yes
11691169
1170-
from smriprep.workflows.surfaces import init_cortex_mask_wf
1171-
wf = init_cortex_mask_wf()
1170+
from smriprep.workflows.surfaces import init_cortex_masks_wf
1171+
wf = init_cortex_masks_wf()
11721172
11731173
Inputs
11741174
------
1175-
midthickness : str
1176-
One hemisphere's FreeSurfer midthickness surface file in GIFTI format
1177-
thickness : str
1178-
One hemisphere's FreeSurfer thickness file in GIFTI format
1179-
hemi : {'L', 'R'}
1180-
Hemisphere indicator
1175+
midthickness : list of str
1176+
Each hemisphere's FreeSurfer midthickness surface file in GIFTI format
1177+
thickness : list of str
1178+
Each hemisphere's FreeSurfer thickness file in GIFTI format
11811179
11821180
Outputs
11831181
-------
1184-
roi : str
1185-
Cortical surface mask in GIFTI format
1182+
cortex_masks : list of str
1183+
Cortical surface mask in GIFTI format for each hemisphere
11861184
"""
11871185
DEFAULT_MEMORY_MIN_GB = 0.01
11881186

11891187
workflow = Workflow(name=name)
11901188

11911189
inputnode = pe.Node(
1192-
niu.IdentityInterface(fields=['midthickness', 'thickness', 'hemi']),
1190+
niu.IdentityInterface(fields=['midthickness', 'thickness']),
11931191
name='inputnode',
11941192
)
1195-
outputnode = pe.Node(niu.IdentityInterface(fields=['roi']), name='outputnode')
1196-
1197-
# Thickness is presumably already positive, but HCP uses abs(-thickness)
1198-
abs_thickness = pe.Node(MetricMath(metric='thickness', operation='abs'), name='abs_thickness')
1199-
1200-
# Native ROI is thickness > 0, with holes and islands filled
1201-
initial_roi = pe.Node(MetricMath(metric='roi', operation='bin'), name='initial_roi')
1202-
fill_holes = pe.Node(MetricFillHoles(), name='fill_holes', mem_gb=DEFAULT_MEMORY_MIN_GB)
1203-
native_roi = pe.Node(MetricRemoveIslands(), name='native_roi', mem_gb=DEFAULT_MEMORY_MIN_GB)
1193+
outputnode = pe.Node(niu.IdentityInterface(fields=['cortex_masks']), name='outputnode')
12041194

1195+
# Combine the inputs into a list
1196+
combine_sources = pe.Node(
1197+
niu.Merge(2, no_flatten=True),
1198+
name='combine_sources',
1199+
)
12051200
workflow.connect([
1206-
(inputnode, abs_thickness, [
1207-
('hemi', 'hemisphere'),
1208-
('thickness', 'metric_file'),
1201+
(inputnode, combine_sources, [
1202+
('midthickness', 'in1'),
1203+
('thickness', 'in2'),
12091204
]),
1210-
(inputnode, initial_roi, [('hemi', 'hemisphere')]),
1211-
(abs_thickness, initial_roi, [('metric_file', 'metric_file')]),
1212-
(inputnode, fill_holes, [('midthickness', 'surface_file')]),
1213-
(inputnode, native_roi, [('midthickness', 'surface_file')]),
1214-
(initial_roi, fill_holes, [('metric_file', 'metric_file')]),
1215-
(fill_holes, native_roi, [('out_file', 'metric_file')]),
1216-
(native_roi, outputnode, [('out_file', 'roi')]),
1205+
(combine_sources, outputnode, [(('out', _transpose_lol), 'source_files')]),
12171206
]) # fmt:skip
12181207

1208+
combine_masks = pe.Node(
1209+
niu.Merge(2),
1210+
name='combine_masks',
1211+
)
1212+
workflow.connect([(combine_masks, outputnode, [('out', 'cortex_masks')])])
1213+
1214+
for i_hemi, hemi in enumerate(['L', 'R']):
1215+
select_midthickness = pe.Node(
1216+
niu.Select(index=i_hemi),
1217+
name=f'select_midthickness_{hemi}',
1218+
)
1219+
select_thickness = pe.Node(
1220+
niu.Select(index=i_hemi),
1221+
name=f'select_thickness_{hemi}',
1222+
)
1223+
workflow.connect([
1224+
(inputnode, select_midthickness, [('midthickness', 'inlist')]),
1225+
(inputnode, select_thickness, [('thickness', 'inlist')]),
1226+
]) # fmt:skip
1227+
1228+
# Thickness is presumably already positive, but HCP uses abs(-thickness)
1229+
abs_thickness = pe.Node(
1230+
MetricMath(metric='thickness', operation='abs'),
1231+
name=f'abs_thickness_{hemi}',
1232+
)
1233+
1234+
# Native ROI is thickness > 0, with holes and islands filled
1235+
initial_roi = pe.Node(
1236+
MetricMath(metric='roi', operation='bin'),
1237+
name=f'initial_roi_{hemi}',
1238+
)
1239+
fill_holes = pe.Node(
1240+
MetricFillHoles(),
1241+
name=f'fill_holes_{hemi}',
1242+
mem_gb=DEFAULT_MEMORY_MIN_GB,
1243+
)
1244+
native_roi = pe.Node(
1245+
MetricRemoveIslands(),
1246+
name=f'native_roi_{hemi}',
1247+
mem_gb=DEFAULT_MEMORY_MIN_GB,
1248+
)
1249+
1250+
workflow.connect([
1251+
(inputnode, abs_thickness, [
1252+
('hemi', 'hemisphere'),
1253+
('thickness', 'metric_file'),
1254+
]),
1255+
(inputnode, initial_roi, [('hemi', 'hemisphere')]),
1256+
(abs_thickness, initial_roi, [('metric_file', 'metric_file')]),
1257+
(inputnode, fill_holes, [('midthickness', 'surface_file')]),
1258+
(inputnode, native_roi, [('midthickness', 'surface_file')]),
1259+
(initial_roi, fill_holes, [('metric_file', 'metric_file')]),
1260+
(fill_holes, native_roi, [('out_file', 'metric_file')]),
1261+
(native_roi, combine_masks, [('out_file', f'in{i_hemi + 1}')]),
1262+
]) # fmt:skip
1263+
12191264
return workflow
12201265

12211266

@@ -1758,3 +1803,8 @@ def _select_seg(in_files, segmentation):
17581803

17591804
def _repeat(seq: list, count: int) -> list:
17601805
return seq * count
1806+
1807+
1808+
def _transpose_lol(inlist):
1809+
"""Transpose a list of lists."""
1810+
return list(map(list, zip(*inlist, strict=False)))

0 commit comments

Comments
 (0)