Skip to content

Commit 64d875e

Browse files
committed
♻️ Attach guardrails directly to Nodes
1 parent 81c2471 commit 64d875e

File tree

2 files changed

+36
-54
lines changed

2 files changed

+36
-54
lines changed

CPAC/pipeline/nipype_pipeline_engine/engine.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -469,8 +469,7 @@ class GuardrailedNode:
469469
``node = wf.guardrailed_node(node, reference, registered, pipe_num)``
470470
to automatically build guardrails
471471
"""
472-
def __init__(self, wf, node, reference, registered, pipe_num,
473-
retry=True):
472+
def __init__(self, wf, node, reference, registered, pipe_num, retry):
474473
'''A Node with guardrails
475474
476475
Parameters
@@ -493,39 +492,34 @@ def __init__(self, wf, node, reference, registered, pipe_num,
493492
retry : bool
494493
retry if run is so configured
495494
'''
496-
self.guardrails = [registration_guardrail_node(
497-
f'{node.name}_guardrail_{pipe_num}')]
498495
self.node = node
496+
self.node.guardrail = registration_guardrail_node(
497+
f'{node.name}_guardrail_{pipe_num}')
499498
self.reference = reference
500499
self.registered = registered
501-
self.retries = []
500+
self.tries = [node]
502501
self.wf = wf
503-
self.wf.connect(self.node, registered,
504-
self.guardrails[0], 'registered')
505502
if retry and self.wf.num_tries > 1:
506503
if registration_guardrails.retry_on_first_failure:
507-
self.guardrails.append(registration_guardrail_node(
508-
f'{node.name}_guardrail'))
509-
self.retries.append(retry_clone(self.node))
510-
self.retries[0].interface.inputs.add_trait(
504+
self.tries.append(retry_clone(self.node))
505+
self.tries[1].interface.inputs.add_trait(
511506
'previous_failure', traits.Bool())
512-
self.guardrails.append(registration_guardrail_node(
513-
f'{self.retries[0].name}_guardrail',
514-
raise_on_failure=True))
515-
self.wf.connect(self.guardrails[0], 'failed_qc',
516-
self.retries[0], 'previous_failure')
507+
self.tries[1].guardrail = registration_guardrail_node(
508+
f'{self.tries[1].name}_guardrail',
509+
raise_on_failure=True)
510+
self.wf.connect(self.tries[0].guardrail, 'failed_qc',
511+
self.tries[1], 'previous_failure')
517512
else:
518513
num_retries = self.wf.num_tries - 1
519514
for i in range(num_retries):
520-
self.retries.append(retry_clone(self.node, i + 2))
521-
self.guardrails.append(registration_guardrail_node(
522-
f'{self.retries[-1].name}_guardrail',
523-
raise_on_failure=(i + 1 == num_retries)))
524-
for i, _retry in enumerate(self.retries):
525-
self.wf.connect(_retry, registered,
526-
self.guardrails[i + 1], 'registered')
527-
for guardrail in self.guardrails:
528-
guardrail.inputs.reference = self.reference
515+
self.tries.append(retry_clone(self.node, i + 2))
516+
self.tries[i + 1].guardrail = (
517+
registration_guardrail_node(
518+
f'{self.tries[i + 1].name}_guardrail',
519+
raise_on_failure=(i + 1 == num_retries)))
520+
for i, _try in enumerate(self.tries):
521+
self.wf.connect(_try, registered, _try.guardrail, 'registered')
522+
_try.guardrail.inputs.reference = self.reference
529523

530524
def __getattr__(self, __name):
531525
"""Get attributes from the node that is guardrailed if that
@@ -684,7 +678,8 @@ def guardrail(self):
684678
"""
685679
return any(registration_guardrails.thresholds.values())
686680

687-
def guardrailed_node(self, node, reference, registered, pipe_num):
681+
def guardrailed_node(self, node, reference, registered, pipe_num,
682+
retry=True):
688683
"""Method to return a GuardrailedNode in the given Workflow.
689684
690685
.. seealso:: Workflow.GuardrailedNode
@@ -702,9 +697,12 @@ def guardrailed_node(self, node, reference, registered, pipe_num):
702697
703698
pipe_num : int
704699
int
700+
701+
retry : bool
702+
retry if run is so configured
705703
"""
706704
return self.GuardrailedNode(self, node, reference, registered,
707-
pipe_num)
705+
pipe_num, retry)
708706

709707
def guardrail_selection(self, node: 'Workflow.GuardrailedNode',
710708
output_key: str) -> Node:

CPAC/registration/registration.py

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from nipype.interfaces.afni import utils as afni_utils
2323
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
2424
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
25-
from CPAC.qc.globals import registration_guardrails
2625
from CPAC.registration.utils import seperate_warps_list, \
2726
check_transforms, \
2827
generate_inverse_transform_flags, \
@@ -922,7 +921,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor, name, bbreg_status,
922921
Functional data in anatomical space
923922
"""
924923
suffix = f'{bbreg_status.title()}_{pipe_num}'
925-
retry = bbreg_status == 'On'
926924
register_bbregister_func_to_anat = pe.Workflow(name=f'{name}_{suffix}')
927925
inputspec = pe.Node(util.IdentityInterface(fields=['func',
928926
'anat',
@@ -961,7 +959,8 @@ def bbreg_args(bbreg_target):
961959
name=f'bbreg_func_to_anat_{suffix}')
962960
bbreg_func_to_anat.inputs.dof = 6
963961
bbreg_func_to_anat = register_bbregister_func_to_anat.guardrailed_node(
964-
bbreg_func_to_anat, 'reference', 'out_file', pipe_num)
962+
bbreg_func_to_anat, 'reference', 'out_file', pipe_num,
963+
retry=bbreg_status == 'On')
965964
register_bbregister_func_to_anat.connect([
966965
(inputspec, bbreg_func_to_anat, [
967966
('bbr_schedule', 'schedule'),
@@ -979,21 +978,14 @@ def bbreg_args(bbreg_target):
979978
('fieldmapmask', 'fieldmapmask')]),
980979
(inputNode_echospacing, bbreg_func_to_anat, [
981980
('echospacing', 'echospacing')])])
982-
if retry and registration_guardrails.retry_on_first_failure:
983-
outfile = register_bbregister_func_to_anat.guardrail_selection(
984-
bbreg_func_to_anat, 'out_file')
985-
matrix = register_bbregister_func_to_anat.guardrail_selection(
986-
bbreg_func_to_anat, 'out_matrix_file')
987-
register_bbregister_func_to_anat.connect(
988-
matrix, 'out', outputspec, 'func_to_anat_linear_xfm')
989-
register_bbregister_func_to_anat.connect(outfile, 'out',
990-
outputspec, 'anat_func')
991-
else:
992-
register_bbregister_func_to_anat.connect(
993-
bbreg_func_to_anat, 'out_matrix_file',
994-
outputspec, 'func_to_anat_linear_xfm')
995-
register_bbregister_func_to_anat.connect(
996-
bbreg_func_to_anat, 'out_file', outputspec, 'anat_func')
981+
outfile = register_bbregister_func_to_anat.guardrail_selection(
982+
bbreg_func_to_anat, 'out_file')
983+
matrix = register_bbregister_func_to_anat.guardrail_selection(
984+
bbreg_func_to_anat, 'out_matrix_file')
985+
register_bbregister_func_to_anat.connect(
986+
matrix, 'out', outputspec, 'func_to_anat_linear_xfm')
987+
register_bbregister_func_to_anat.connect(outfile, 'out',
988+
outputspec, 'anat_func')
997989
return register_bbregister_func_to_anat
998990

999991

@@ -1827,7 +1819,7 @@ def bold_to_T1template_xfm_connector(wf_name, cfg, reg_tool, symmetric=False,
18271819
name='change_transform_type')
18281820

18291821
wf.connect(fsl_reg_2_itk, 'itk_transform',
1830-
change_transform, 'input_affine_file')
1822+
change_transform, 'input_affine_file')
18311823

18321824
# combine ALL xfm's into one - makes it easier downstream
18331825
write_composite_xfm = pe.Node(
@@ -2862,10 +2854,6 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
28622854
cfg.registration_workflows['functional_registration'][
28632855
'coregistration']['boundary_based_registration'][
28642856
'bbr_wm_mask_args']
2865-
if fallback:
2866-
bbreg_guardrail = pe.registration_guardrail_node(
2867-
f'bbreg{bbreg_status}_guardrail_{pipe_num}',
2868-
raise_on_failure=False)
28692857

28702858
node, out = strat_pool.get_data('sbref')
28712859
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func')
@@ -2875,16 +2863,12 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
28752863
'reference'] == 'whole-head':
28762864
node, out = strat_pool.get_data('desc-head_T1w')
28772865
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
2878-
if fallback:
2879-
wf.connect(node, out, bbreg_guardrail, 'reference')
28802866

28812867
elif cfg.registration_workflows['functional_registration'][
28822868
'coregistration']['boundary_based_registration'][
28832869
'reference'] == 'brain':
28842870
node, out = strat_pool.get_data('desc-preproc_T1w')
28852871
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
2886-
if fallback:
2887-
wf.connect(node, out, bbreg_guardrail, 'reference')
28882872

28892873
wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg',
28902874
func_to_anat_bbreg, 'inputspec.linear_reg_matrix')

0 commit comments

Comments
 (0)