Skip to content

Commit 04643a4

Browse files
committed
rf: Replace explicit loop with iterables in ds_surface_masks_wf
1 parent ccc727d commit 04643a4

File tree

2 files changed

+45
-46
lines changed

2 files changed

+45
-46
lines changed

src/smriprep/workflows/anatomical.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1356,7 +1356,6 @@ def init_anat_fit_wf(
13561356
output_dir=output_dir,
13571357
mask_type='cortex',
13581358
name='ds_cortex_masks_wf',
1359-
entities={'extension': '.label.gii'},
13601359
)
13611360

13621361
workflow.connect([

src/smriprep/workflows/outputs.py

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,57 +1275,53 @@ def init_ds_surface_masks_wf(
12751275
niu.IdentityInterface(fields=['mask_files', 'source_files']),
12761276
name='inputnode',
12771277
)
1278-
outputnode = pe.Node(niu.IdentityInterface(fields=['mask_files']), name='outputnode')
1278+
outputnode = pe.JoinNode(
1279+
niu.IdentityInterface(fields=['mask_files']), name='outputnode', joinsource='ds_itersource'
1280+
)
12791281

1280-
combine_masks = pe.Node(
1281-
niu.Merge(2),
1282-
name='combine_masks',
1282+
ds_itersource = pe.Node(
1283+
niu.IdentityInterface(fields=['hemi']),
1284+
name='ds_itersource',
1285+
iterables=[('hemi', ['L', 'R'])],
12831286
)
1284-
workflow.connect([(combine_masks, outputnode, [('out', 'mask_files')])])
12851287

1286-
for i_hemi, hemi in enumerate(['L', 'R']):
1287-
select_mask = pe.Node(
1288-
niu.Select(index=i_hemi),
1289-
name=f'select_mask_{hemi}',
1290-
run_without_submitting=True,
1291-
)
1292-
workflow.connect([(inputnode, select_mask, [('mask_files', 'inlist')])])
1288+
sources = pe.Node(niu.Function(function=_bids_relative), name='sources')
1289+
sources.inputs.bids_root = output_dir
12931290

1294-
select_source = pe.Node(
1295-
niu.Select(index=i_hemi),
1296-
name=f'select_source_{hemi}',
1297-
run_without_submitting=True,
1298-
)
1299-
workflow.connect([(inputnode, select_source, [('source_files', 'inlist')])])
1300-
1301-
sources = pe.Node(
1302-
niu.Function(function=_bids_relative),
1303-
name=f'sources_{hemi}',
1304-
)
1305-
sources.inputs.bids_root = output_dir
1291+
select_files = pe.Node(
1292+
KeySelect(fields=['mask_file', 'sources'], keys=['L', 'R']),
1293+
name='select_files',
1294+
run_without_submitting=True,
1295+
)
13061296

1307-
ds_mask = pe.Node(
1308-
DerivativesDataSink(
1309-
base_directory=output_dir,
1310-
hemi=hemi,
1311-
desc=mask_type,
1312-
**entities,
1313-
),
1314-
name=f'ds_mask_{hemi}',
1315-
run_without_submitting=True,
1316-
)
1317-
if mask_type == 'brain':
1318-
ds_mask.inputs.Type = 'Brain'
1319-
else:
1320-
ds_mask.inputs.Type = 'ROI'
1297+
ds_surf_mask = pe.Node(
1298+
DerivativesDataSink(
1299+
base_directory=output_dir,
1300+
suffix='mask',
1301+
desc=mask_type,
1302+
extension='.label.gii',
1303+
Type='Brain' if mask_type == 'brain' else 'ROI',
1304+
**entities,
1305+
),
1306+
name='ds_surf_mask',
1307+
run_without_submitting=True,
1308+
)
13211309

1322-
workflow.connect([
1323-
(select_mask, ds_mask, [('out', 'in_file')]),
1324-
(select_source, sources, [('out', 'in_files')]),
1325-
(select_source, ds_mask, [('out', 'source_file')]),
1326-
(sources, ds_mask, [('out', 'Sources')]),
1327-
(ds_mask, combine_masks, [('out_file', f'in{i_hemi + 1}')]),
1328-
]) # fmt:skip
1310+
workflow.connect([
1311+
(inputnode, select_files, [
1312+
('mask_files', 'mask_file'),
1313+
('source_files', 'sources'),
1314+
]),
1315+
(select_files, sources, [('sources', 'in_files')]),
1316+
(ds_itersource, select_files, [('hemi', 'key')]),
1317+
(ds_itersource, ds_surf_mask, [('hemi', 'hemi')]),
1318+
(select_files, ds_surf_mask, [
1319+
('mask_file', 'in_file'),
1320+
(('sources', _pop), 'source_file'),
1321+
]),
1322+
(sources, ds_surf_mask, [('out', 'Sources')]),
1323+
(ds_surf_mask, outputnode, [('out_file', 'mask_files')]),
1324+
]) # fmt: skip
13291325

13301326
return workflow
13311327

@@ -1434,3 +1430,7 @@ def _read_json(in_file):
14341430
from pathlib import Path
14351431

14361432
return loads(Path(in_file).read_text())
1433+
1434+
1435+
def _pop(in_list):
1436+
return in_list[0]

0 commit comments

Comments
 (0)