Skip to content

Commit 3f530ff

Browse files
committed
♻️ Retry BBR on failure
1 parent e9359c2 commit 3f530ff

File tree

3 files changed

+87
-54
lines changed

3 files changed

+87
-54
lines changed

CPAC/pipeline/nipype_pipeline_engine/utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,22 @@
1818

1919

2020
def connect_from_spec(wf, spec, original_spec, exclude=None):
21-
"""Function to connect all original inputs to a new spec"""
21+
"""Function to connect all original inputs to a new spec
22+
23+
Parameters
24+
----------
25+
wf : Workflow
26+
27+
spec : dict
28+
29+
original_spec : dict
30+
31+
exclude : list, tuple, or dict, optional
32+
33+
Returns
34+
-------
35+
Workflow
36+
"""
2237
for _item, _value in original_spec.items():
2338
if isinstance(exclude, (list, tuple)):
2439
if _item not in exclude:

CPAC/registration/guardrails.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,36 +120,46 @@ def registration_guardrail_node(name=None):
120120
function=registration_guardrail), name=name)
121121

122122

123-
def registration_guardrail_workflow(registration_node, retry=True):
123+
def registration_guardrail_workflow(registration_node, retry=True, spec=None):
124124
"""A workflow to handle hitting a registration guardrail
125125
126126
Parameters
127127
----------
128128
name : str
129129
130-
registration_node : Node
130+
registration_node : Node or Workflow
131131
132132
retry : bool, optional
133133
134+
spec : dict, required for guardrailing function nodes
135+
Resource pool keys for reference and registered resources, in
136+
the format ``{'reference': str, 'registered': str}``
137+
134138
Returns
135139
-------
136140
Workflow
141+
142+
See Also
143+
--------
144+
spec_key
137145
"""
138146
name = f'{registration_node.name}_guardrail'
139147
wf = Workflow(name=f'{name}_wf')
140148
outputspec = deepcopy(registration_node.outputs)
141149
guardrail = registration_guardrail_node(name)
142-
outkey = spec_key(registration_node, 'registered')
150+
if spec is None:
151+
spec = {key: spec_key(registration_node, key) for
152+
key in ['reference', 'registered']}
143153
wf.connect([
144-
(registration_node, guardrail, [
145-
(spec_key(registration_node, 'reference'), 'reference')]),
146-
(registration_node, guardrail, [(outkey, 'registered')])])
154+
(registration_node, guardrail, [(spec['reference'], 'reference')]),
155+
(registration_node, guardrail, [(spec['registered'], 'registered')])])
147156
if retry:
148157
wf = retry_registration(wf, registration_node,
149158
guardrail.outputs.registered)
150159
else:
151-
wf.connect(guardrail, 'registered', outputspec, outkey)
152-
wf = connect_from_spec(wf, outputspec, registration_node, outkey)
160+
wf.connect(guardrail, 'registered', outputspec, spec['registered'])
161+
wf = connect_from_spec(wf, outputspec,
162+
registration_node, spec['registered'])
153163
return wf
154164

155165

CPAC/registration/registration.py

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
2626
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
27-
from CPAC.registration.guardrails import registration_guardrail_node
27+
from CPAC.registration.guardrails import registration_guardrail_node, \
28+
registration_guardrail_workflow
2829
from CPAC.registration.utils import seperate_warps_list, \
2930
check_transforms, \
3031
generate_inverse_transform_flags, \
@@ -3686,14 +3687,23 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
36863687
"space-template_desc-brain_bold",
36873688
"space-template_desc-bold_mask"]}
36883689
""" # noqa: 501
3690+
subwf = pe.Workflow('single_step_resample_timeseries_to_T1template_'
3691+
f'{pipe_num}')
3692+
guardrail_preproc = registration_guardrail_workflow(
3693+
subwf,
3694+
spec={'reference': f'convert_bbr2itk_{pipe_num}.reference_file',
3695+
'registered': f'merge_func_to_standard_{pipe_num}.merged_file'})
3696+
guardrail_brain = registration_guardrail_workflow(
3697+
subwf,
3698+
spec={'reference': f'applyxfm_func_to_standard_{pipe_num}.'
3699+
'reference_image',
3700+
'registered': f'get_func_brain_to_standard_{pipe_num}.out_file'})
36893701
bbr2itk = pe.Node(util.Function(input_names=['reference_file',
36903702
'source_file',
36913703
'transform_file'],
36923704
output_names=['itk_transform'],
36933705
function=run_c3d),
36943706
name=f'convert_bbr2itk_{pipe_num}')
3695-
guardrail_preproc = registration_guardrail_node(
3696-
'single-step-resampling-preproc_guardrail')
36973707
if cfg.registration_workflows['functional_registration'][
36983708
'coregistration']['boundary_based_registration'][
36993709
'reference'] == 'whole-head':
@@ -3702,23 +3712,22 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
37023712
'coregistration']['boundary_based_registration'][
37033713
'reference'] == 'brain':
37043714
node, out = strat_pool.get_data('desc-brain_T1w')
3705-
wf.connect(node, out, bbr2itk, 'reference_file')
3706-
wf.connect(node, out, guardrail_preproc, 'reference')
3715+
subwf.connect(node, out, bbr2itk, 'reference_file')
37073716

37083717
node, out = strat_pool.get_data(['desc-reginput_bold', 'desc-mean_bold'])
3709-
wf.connect(node, out, bbr2itk, 'source_file')
3718+
subwf.connect(node, out, bbr2itk, 'source_file')
37103719

37113720
node, out = strat_pool.get_data('from-bold_to-T1w_mode-image_desc-linear_'
37123721
'xfm')
3713-
wf.connect(node, out, bbr2itk, 'transform_file')
3722+
subwf.connect(node, out, bbr2itk, 'transform_file')
37143723

37153724
split_func = pe.Node(interface=fsl.Split(),
37163725
name=f'split_func_{pipe_num}')
37173726

37183727
split_func.inputs.dimension = 't'
37193728

37203729
node, out = strat_pool.get_data('desc-stc_bold')
3721-
wf.connect(node, out, split_func, 'in_file')
3730+
subwf.connect(node, out, split_func, 'in_file')
37223731

37233732
### Loop starts! ###
37243733
motionxfm2itk = pe.MapNode(util.Function(
@@ -3731,54 +3740,53 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
37313740
iterfield=['transform_file'])
37323741

37333742
node, out = strat_pool.get_data('motion-basefile')
3734-
wf.connect(node, out, motionxfm2itk, 'reference_file')
3735-
wf.connect(node, out, motionxfm2itk, 'source_file')
3743+
subwf.connect(node, out, motionxfm2itk, 'reference_file')
3744+
subwf.connect(node, out, motionxfm2itk, 'source_file')
37363745

37373746
node, out = strat_pool.get_data('coordinate-transformation')
37383747
motion_correct_tool = check_prov_for_motion_tool(
37393748
strat_pool.get_cpac_provenance('coordinate-transformation'))
37403749
if motion_correct_tool == 'mcflirt':
3741-
wf.connect(node, out, motionxfm2itk, 'transform_file')
3750+
subwf.connect(node, out, motionxfm2itk, 'transform_file')
37423751
elif motion_correct_tool == '3dvolreg':
37433752
convert_transform = pe.Node(util.Function(
37443753
input_names=['one_d_filename'],
37453754
output_names=['transform_directory'],
37463755
function=one_d_to_mat,
37473756
imports=['import os', 'import numpy as np']),
37483757
name=f'convert_transform_{pipe_num}')
3749-
wf.connect(node, out, convert_transform, 'one_d_filename')
3750-
wf.connect(convert_transform, 'transform_directory',
3751-
motionxfm2itk, 'transform_file')
3758+
subwf.connect(node, out, convert_transform, 'one_d_filename')
3759+
subwf.connect(convert_transform, 'transform_directory',
3760+
motionxfm2itk, 'transform_file')
37523761

37533762
collectxfm = pe.MapNode(util.Merge(4),
37543763
name=f'collectxfm_func_to_standard_{pipe_num}',
37553764
iterfield=['in4'])
37563765

37573766
node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm')
3758-
wf.connect(node, out, collectxfm, 'in1')
3759-
wf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2')
3767+
subwf.connect(node, out, collectxfm, 'in1')
3768+
subwf.connect(bbr2itk, 'itk_transform', collectxfm, 'in2')
37603769

37613770
collectxfm.inputs.in3 = 'identity'
37623771

3763-
wf.connect(motionxfm2itk, 'itk_transform',
3764-
collectxfm, 'in4')
3772+
subwf.connect(motionxfm2itk, 'itk_transform',
3773+
collectxfm, 'in4')
37653774

37663775
applyxfm_func_to_standard = pe.MapNode(interface=ants.ApplyTransforms(),
3767-
name=f'applyxfm_func_to_standard_{pipe_num}',
3768-
iterfield=['input_image', 'transforms'])
3776+
name='applyxfm_func_to_standard_'
3777+
f'{pipe_num}',
3778+
iterfield=['input_image',
3779+
'transforms'])
37693780

37703781
applyxfm_func_to_standard.inputs.float = True
37713782
applyxfm_func_to_standard.inputs.interpolation = 'LanczosWindowedSinc'
3772-
guardrail_brain = registration_guardrail_node(
3773-
'single-step-resampling-brain_guardrail')
37743783

3775-
wf.connect(split_func, 'out_files',
3776-
applyxfm_func_to_standard, 'input_image')
3784+
subwf.connect(split_func, 'out_files',
3785+
applyxfm_func_to_standard, 'input_image')
37773786

37783787
node, out = strat_pool.get_data('T1w-brain-template-funcreg')
3779-
wf.connect(node, out, applyxfm_func_to_standard, 'reference_image')
3780-
wf.connect(node, out, guardrail_brain, 'reference')
3781-
wf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms')
3788+
subwf.connect(node, out, applyxfm_func_to_standard, 'reference_image')
3789+
subwf.connect(collectxfm, 'out', applyxfm_func_to_standard, 'transforms')
37823790

37833791
### Loop ends! ###
37843792

@@ -3787,8 +3795,8 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
37873795

37883796
merge_func_to_standard.inputs.dimension = 't'
37893797

3790-
wf.connect(applyxfm_func_to_standard, 'output_image',
3791-
merge_func_to_standard, 'in_files')
3798+
subwf.connect(applyxfm_func_to_standard, 'output_image',
3799+
merge_func_to_standard, 'in_files')
37923800

37933801
applyxfm_func_mask_to_standard = pe.Node(interface=ants.ApplyTransforms(),
37943802
name='applyxfm_func_mask_to_'
@@ -3797,34 +3805,34 @@ def single_step_resample_timeseries_to_T1template(wf, cfg, strat_pool,
37973805
applyxfm_func_mask_to_standard.inputs.interpolation = 'MultiLabel'
37983806

37993807
node, out = strat_pool.get_data('space-bold_desc-brain_mask')
3800-
wf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image')
3808+
subwf.connect(node, out, applyxfm_func_mask_to_standard, 'input_image')
38013809

38023810
node, out = strat_pool.get_data('T1w-brain-template-funcreg')
3803-
wf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image')
3811+
subwf.connect(node, out, applyxfm_func_mask_to_standard, 'reference_image')
38043812

38053813
collectxfm_mask = pe.Node(util.Merge(2),
3806-
name=f'collectxfm_func_mask_to_standard_{pipe_num}')
3814+
name='collectxfm_func_mask_to_standard_'
3815+
f'{pipe_num}')
38073816

38083817
node, out = strat_pool.get_data('from-T1w_to-template_mode-image_xfm')
3809-
wf.connect(node, out, collectxfm_mask, 'in1')
3810-
wf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2')
3811-
wf.connect(collectxfm_mask, 'out',
3812-
applyxfm_func_mask_to_standard, 'transforms')
3818+
subwf.connect(node, out, collectxfm_mask, 'in1')
3819+
subwf.connect(bbr2itk, 'itk_transform', collectxfm_mask, 'in2')
3820+
subwf.connect(collectxfm_mask, 'out',
3821+
applyxfm_func_mask_to_standard, 'transforms')
38133822

38143823
apply_mask = pe.Node(interface=fsl.maths.ApplyMask(),
38153824
name=f'get_func_brain_to_standard_{pipe_num}')
38163825

3817-
wf.connect(merge_func_to_standard, 'merged_file',
3818-
apply_mask, 'in_file')
3819-
wf.connect(applyxfm_func_mask_to_standard, 'output_image',
3820-
apply_mask, 'mask_file')
3821-
wf.connect(merge_func_to_standard, 'merged_file',
3822-
guardrail_preproc, 'registered')
3823-
wf.connect(apply_mask, 'out_file', guardrail_brain, 'registered')
3826+
subwf.connect(merge_func_to_standard, 'merged_file',
3827+
apply_mask, 'in_file')
3828+
subwf.connect(applyxfm_func_mask_to_standard, 'output_image',
3829+
apply_mask, 'mask_file')
38243830

38253831
outputs = {
3826-
'space-template_desc-preproc_bold': (guardrail_preproc, 'registered'),
3827-
'space-template_desc-brain_bold': (guardrail_brain, 'registered'),
3832+
'space-template_desc-preproc_bold': (guardrail_preproc,
3833+
'outputspec.merged_file'),
3834+
'space-template_desc-brain_bold': (guardrail_brain,
3835+
'outputspec.out_file'),
38283836
'space-template_desc-bold_mask': (applyxfm_func_mask_to_standard,
38293837
'output_image'),
38303838
}

0 commit comments

Comments
 (0)