Skip to content

Commit c83ed62

Browse files
committed
♻️ Rewire guardrail for anat_mni_ants_register
1 parent 88b5589 commit c83ed62

File tree

2 files changed

+90
-57
lines changed

2 files changed

+90
-57
lines changed

CPAC/registration/guardrails.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from copy import deepcopy
2121
from nipype.interfaces.ants import Registration
2222
from nipype.interfaces.fsl import FLIRT
23-
from nipype.interfaces.utility import Function
23+
from nipype.interfaces.utility import Function, Merge, Select
2424
from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow
2525
# from CPAC.pipeline.nipype_pipeline_engine.utils import connect_from_spec
2626
from CPAC.qc import qc_masks, registration_guardrail_thresholds
@@ -56,6 +56,36 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
5656
super().__init__(msg, *args, **kwargs)
5757

5858

59+
def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node',
60+
) -> Node:
61+
"""Generate requisite Nodes for choosing a path through the graph
62+
with retries
63+
64+
Parameters
65+
----------
66+
wf : Workflow
67+
68+
node1, node2 : Node
69+
try guardrail, retry guardrail
70+
71+
Returns
72+
-------
73+
select : Node
74+
"""
75+
# pylint: disable=redefined-outer-name,reimported,unused-import
76+
from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow
77+
name = node1.name
78+
choices = Node(Merge(2), run_without_submitting=True,
79+
name=f'{name}_choices')
80+
select = Node(Select(), run_without_submitting=True,
81+
name=f'choose_{name}')
82+
wf.connect([(node1, choices, [('registered', 'in1')]),
83+
(node2, choices, [('registered', 'in2')]),
84+
(choices, select, [('out', 'inlist')]),
85+
(node1, select, [('failed_qc', 'index')])])
86+
return select
87+
88+
5989
def registration_guardrail(registered: str, reference: str,
6090
retry: bool = False, retry_num: int = 0
6191
) -> Tuple[str, int]:

CPAC/registration/registration.py

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,10 @@
2121
from CPAC.pipeline import nipype_pipeline_engine as pe
2222
from nipype.interfaces import afni, ants, c3, fsl, utility as util
2323
from nipype.interfaces.afni import utils as afni_utils
24-
from nipype.interfaces.utility import Merge, Select
2524
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
2625
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
27-
from CPAC.registration.guardrails import registration_guardrail_node
26+
from CPAC.registration.guardrails import guardrail_selection, \
27+
registration_guardrail_node
2828
from CPAC.registration.utils import seperate_warps_list, \
2929
check_transforms, \
3030
generate_inverse_transform_flags, \
@@ -1174,20 +1174,14 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
11741174
.. image::
11751175
:width: 500
11761176
'''
1177-
1177+
from CPAC.registration.guardrails import retry_hardcoded_reg
11781178
calc_ants_warp_wf = pe.Workflow(name=name)
11791179

1180-
inputspec = pe.Node(util.IdentityInterface(
1181-
fields=['moving_brain',
1182-
'reference_brain',
1183-
'moving_skull',
1184-
'reference_skull',
1185-
'reference_mask',
1186-
'moving_mask',
1187-
'fixed_image_mask',
1188-
'ants_para',
1189-
'interp']),
1190-
name='inputspec')
1180+
warp_inputs = ['moving_brain', 'reference_brain', 'moving_skull',
1181+
'reference_skull', 'ants_para', 'moving_mask',
1182+
'reference_mask', 'fixed_image_mask', 'interp']
1183+
inputspec = pe.Node(util.IdentityInterface(fields=warp_inputs),
1184+
name='inputspec')
11911185

11921186
outputspec = pe.Node(util.IdentityInterface(
11931187
fields=['ants_initial_xfm',
@@ -1208,27 +1202,30 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12081202
calculate_ants_warp.inputs.initial_moving_transform_com = 0
12091203
'''
12101204
reg_imports = ['import os', 'import subprocess']
1211-
calculate_ants_warp = \
1212-
pe.Node(interface=util.Function(input_names=['moving_brain',
1213-
'reference_brain',
1214-
'moving_skull',
1215-
'reference_skull',
1216-
'ants_para',
1217-
'moving_mask',
1218-
'reference_mask',
1219-
'fixed_image_mask',
1220-
'interp',
1221-
'reg_with_skull'],
1222-
output_names=['warp_list',
1223-
'warped_image'],
1224-
function=hardcoded_reg,
1225-
imports=reg_imports),
1226-
name='calc_ants_warp',
1227-
mem_gb=2.8,
1228-
mem_x=(2e-7, 'moving_brain', 'xyz'))
1205+
warp_inputs += ['reg_with_skull']
1206+
warp_outputs = ['warp_list', 'warped_image']
1207+
calculate_ants_warp = pe.Node(
1208+
interface=util.Function(input_names=warp_inputs,
1209+
output_names=warp_outputs,
1210+
function=hardcoded_reg,
1211+
imports=reg_imports),
1212+
name='calc_ants_warp', mem_gb=2.8,
1213+
mem_x=(2e-7, 'moving_brain', 'xyz'))
1214+
retry_calculate_ants_warp = pe.Node(
1215+
interface=util.Function(input_names=[*warp_inputs, 'previous_failure'],
1216+
output_names=warp_outputs,
1217+
function=retry_hardcoded_reg,
1218+
imports=['from CPAC.registration.utils '
1219+
'import hardcoded_reg',
1220+
'from CPAC.utils.docs import '
1221+
'retry_docstring']),
1222+
name='retry_calc_ants_warp', mem_gb=2.8,
1223+
mem_x=(2e-7, 'moving_brain', 'xyz'))
1224+
guardrails = tuple(registration_guardrail_node(
1225+
f'{_try}{name}_guardrail', i) for i, _try in enumerate(('', 'retry_')))
12291226

12301227
calculate_ants_warp.interface.num_threads = num_threads
1231-
1228+
retry_calculate_ants_warp.interface.num_threads = num_threads
12321229
select_forward_initial = pe.Node(util.Function(
12331230
input_names=['warp_list', 'selection'],
12341231
output_names=['selected_warp'],
@@ -1264,13 +1261,10 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12641261

12651262
select_inverse_warp.inputs.selection = "Inverse"
12661263

1267-
guardrail = registration_guardrail_node(f'{name}_guardrail')
12681264
calc_ants_warp_wf.connect(inputspec, 'moving_brain',
12691265
calculate_ants_warp, 'moving_brain')
12701266
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
12711267
calculate_ants_warp, 'reference_brain')
1272-
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
1273-
guardrail, 'reference')
12741268

12751269
if reg_ants_skull == 1:
12761270
calculate_ants_warp.inputs.reg_with_skull = 1
@@ -1279,11 +1273,17 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12791273
calculate_ants_warp, 'moving_skull')
12801274
calc_ants_warp_wf.connect(inputspec, 'reference_skull',
12811275
calculate_ants_warp, 'reference_skull')
1276+
for guardrail in guardrails:
1277+
calc_ants_warp_wf.connect(inputspec, 'reference_skull',
1278+
guardrail, 'reference')
12821279
else:
12831280
calc_ants_warp_wf.connect(inputspec, 'moving_brain',
12841281
calculate_ants_warp, 'moving_skull')
12851282
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
12861283
calculate_ants_warp, 'reference_skull')
1284+
for guardrail in guardrails:
1285+
calc_ants_warp_wf.connect(inputspec, 'reference_brain',
1286+
guardrail, 'reference')
12871287

12881288
calc_ants_warp_wf.connect(inputspec, 'fixed_image_mask',
12891289
calculate_ants_warp, 'fixed_image_mask')
@@ -1317,9 +1317,11 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
13171317
outputspec, 'warp_field')
13181318
calc_ants_warp_wf.connect(select_inverse_warp, 'selected_warp',
13191319
outputspec, 'inverse_warp_field')
1320-
calc_ants_warp_wf.connect(calculate_ants_warp, 'warped_image',
1321-
guardrail, 'registered')
1322-
calc_ants_warp_wf.connect(guardrail, 'registered',
1320+
for guardrail in guardrails:
1321+
calc_ants_warp_wf.connect(calculate_ants_warp, 'warped_image',
1322+
guardrail, 'registered')
1323+
select = guardrail_selection(calc_ants_warp_wf, *guardrails)
1324+
calc_ants_warp_wf.connect(select, 'out',
13231325
outputspec, 'normalized_output_brain')
13241326

13251327
return calc_ants_warp_wf
@@ -2928,38 +2930,39 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
29282930
wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
29292931
retry_guardrail, 'registered')
29302932

2931-
mean_bolds = pe.Node(Merge(2), run_without_submitting=True,
2933+
mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
29322934
name=f'bbreg_mean_bold_choices_{pipe_num}')
2933-
xfms = pe.Node(Merge(2), run_without_submitting=True,
2935+
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
29342936
name=f'bbreg_xfm_choices_{pipe_num}')
2935-
fallback_mean_bolds = pe.Node(Select, run_without_submitting=True,
2937+
fallback_mean_bolds = pe.Node(util.Select(),
2938+
run_without_submitting=True,
29362939
name=f'bbreg_choose_mean_bold_{pipe_num}'
29372940
)
2938-
fallback_xfms = pe.Node(Select, run_without_submitting=True,
2941+
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
29392942
name=f'bbreg_choose_xfm_{pipe_num}')
29402943
if opt is True:
29412944
wf.connect([
2942-
(bbreg_guardrail, mean_bolds, ['registered', 'in1']),
2943-
(retry_guardrail, mean_bolds, ['registered', 'in1']),
2945+
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
2946+
(retry_guardrail, mean_bolds, [('registered', 'in2')]),
29442947
(func_to_anat_bbreg, xfms, [
2945-
'outputspec.func_to_anat_linear_xfm', 'in2']),
2948+
('outputspec.func_to_anat_linear_xfm', 'in1')]),
29462949
(retry_node, xfms, [
2947-
'outputspec.func_to_anat_linear_xfm_nobbreg', 'in2'])])
2950+
('outputspec.func_to_anat_linear_xfm', 'in2')])])
29482951
else:
29492952
# Fall back to no-BBReg
29502953
wf.connect([
2951-
(bbreg_guardrail, mean_bolds, ['registered', 'in1']),
2952-
(func_to_anat, mean_bolds, ['outputspec.anat_func_nobbreg',
2953-
'in1']),
2954+
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
2955+
(func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg',
2956+
'in2')]),
29542957
(func_to_anat_bbreg, xfms, [
2955-
'outputspec.func_to_anat_linear_xfm', 'in2']),
2958+
('outputspec.func_to_anat_linear_xfm', 'in1')]),
29562959
(func_to_anat, xfms, [
2957-
'outputspec.func_to_anat_linear_xfm_nobbreg', 'in2'])])
2960+
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')])])
29582961
wf.connect([
2959-
(mean_bolds, fallback_mean_bolds, ['out', 'inlist']),
2960-
(xfms, fallback_xfms, ['out', 'inlist']),
2961-
(bbreg_guardrail, fallback_mean_bolds, ['failed_qc', 'index']),
2962-
(bbreg_guardrail, fallback_xfms, ['failed_qc', 'index'])])
2962+
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
2963+
(xfms, fallback_xfms, [('out', 'inlist')]),
2964+
(bbreg_guardrail, fallback_mean_bolds, [('failed_qc', 'index')]),
2965+
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
29632966
outputs = {
29642967
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
29652968
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,

0 commit comments

Comments
 (0)