Skip to content

Commit be69c25

Browse files
committed
♻️ Refactor bbreg guardrails
1 parent e28b1a7 commit be69c25

File tree

2 files changed

+116
-127
lines changed

2 files changed

+116
-127
lines changed

CPAC/registration/guardrails.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,36 @@
3333
Registration: {'reference': 'reference', 'registered': 'out_file'}}
3434

3535

36+
def connect_retries(wf, nodes, connections):
37+
"""Function to generalize making the same connections to try and
38+
retry nodes.
39+
40+
For each 3-tuple (``conn``) in ``connections``, will do
41+
42+
.. code-block:: Python
43+
44+
wf.connect(conn[0], node, conn[1], conn[2])
45+
46+
for each node in nodes
47+
48+
Parameters
49+
----------
50+
wf : Workflow
51+
52+
nodes : iterable of Nodes
53+
54+
connections : iterable of 3-tuples of (Node, str or tuple, str)
55+
56+
Returns
57+
-------
58+
Workflow
59+
"""
60+
for node in nodes:
61+
for conn in connections:
62+
wf.connect(conn[0], node, conn[1], conn[2])
63+
return wf
64+
65+
3666
def guardrail_selection(wf: 'Workflow', node1: 'Node', node2: 'Node',
3767
output_key: str = 'registered',
3868
guardrail_node: 'Node' = None) -> Node:

CPAC/registration/registration.py

Lines changed: 86 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,16 @@
1717
"""Registration functions"""
1818
# pylint: disable=too-many-lines,ungrouped-imports,wrong-import-order
1919
# TODO: replace Tuple with tuple, Union with |, once Python >= 3.9, 3.10
20+
from sqlite3 import connect
2021
from typing import Optional, Tuple, Union
2122
from CPAC.pipeline import nipype_pipeline_engine as pe
2223
from nipype.interfaces import afni, ants, c3, fsl, utility as util
2324
from nipype.interfaces.afni import utils as afni_utils
2425
from CPAC.anat_preproc.lesion_preproc import create_lesion_preproc
2526
from CPAC.func_preproc.utils import chunk_ts, split_ts_chunks
2627
from CPAC.pipeline.random_state.seed import increment_seed
27-
from CPAC.registration.guardrails import guardrail_selection, \
28+
from CPAC.registration.guardrails import connect_retries, \
29+
guardrail_selection, \
2830
registration_guardrail_node
2931
from CPAC.registration.utils import seperate_warps_list, \
3032
check_transforms, \
@@ -974,7 +976,6 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False,
974976
outputspec.anat_func : string (nifti file)
975977
Functional data in anatomical space
976978
"""
977-
from CPAC.pipeline.random_state.seed import seed_plus_1
978979
register_bbregister_func_to_anat = pe.Workflow(name=name)
979980

980981
inputspec = pe.Node(util.IdentityInterface(fields=['func',
@@ -999,14 +1000,10 @@ def create_bbregister_func_to_anat(phase_diff_distcor=False,
9991000

10001001
wm_bb_mask = pe.Node(interface=fsl.ImageMaths(),
10011002
name='wm_bb_mask')
1002-
if retry:
1003-
seed = seed_plus_1()
1004-
wm_bb_mask.seed = seed
10051003

10061004
register_bbregister_func_to_anat.connect(
10071005
inputspec, 'bbr_wm_mask_args',
10081006
wm_bb_mask, 'op_string')
1009-
10101007
register_bbregister_func_to_anat.connect(inputspec,
10111008
'anat_wm_segmentation',
10121009
wm_bb_mask, 'in_file')
@@ -1017,50 +1014,52 @@ def bbreg_args(bbreg_target):
10171014
bbreg_func_to_anat = pe.Node(interface=fsl.FLIRT(),
10181015
name='bbreg_func_to_anat')
10191016
bbreg_func_to_anat.inputs.dof = 6
1017+
guardrail_bbreg_func_to_anat = registration_guardrail_node(
1018+
f'{bbreg_func_to_anat.name}_guardrail')
1019+
nodes = [bbreg_func_to_anat]
1020+
guardrails = [guardrail_bbreg_func_to_anat]
10201021
if retry:
1021-
bbreg_func_to_anat.seed = seed
1022-
1023-
register_bbregister_func_to_anat.connect(
1024-
inputspec, 'bbr_schedule',
1025-
bbreg_func_to_anat, 'schedule')
1026-
register_bbregister_func_to_anat.connect(
1027-
wm_bb_mask, ('out_file', bbreg_args),
1028-
bbreg_func_to_anat, 'args')
1029-
register_bbregister_func_to_anat.connect(
1030-
inputspec, 'func',
1031-
bbreg_func_to_anat, 'in_file')
1032-
register_bbregister_func_to_anat.connect(
1033-
inputspec, 'anat',
1034-
bbreg_func_to_anat, 'reference')
1035-
register_bbregister_func_to_anat.connect(
1036-
inputspec, 'linear_reg_matrix',
1037-
bbreg_func_to_anat, 'in_matrix_file')
1038-
1022+
retry_bbreg_func_to_anat = increment_seed(bbreg_func_to_anat.clone(
1023+
f'retry_{bbreg_func_to_anat.name}'))
1024+
guardrail_retry_bbreg_func_to_anat = registration_guardrail_node(
1025+
f'{retry_bbreg_func_to_anat.name}_guardrail')
1026+
nodes += [retry_bbreg_func_to_anat]
1027+
guardrails += [guardrail_retry_bbreg_func_to_anat]
1028+
register_bbregister_func_to_anat = connect_retries(
1029+
register_bbregister_func_to_anat, nodes, [
1030+
(inputspec, 'bbr_schedule', 'schedule'),
1031+
(wm_bb_mask, ('out_file', bbreg_args), 'args'),
1032+
(inputspec, 'func', 'in_file'),
1033+
(inputspec, 'anat', 'reference'),
1034+
(inputspec, 'linear_reg_matrix', 'in_matrix_file')])
10391035
if phase_diff_distcor:
1036+
register_bbregister_func_to_anat = connect_retries(
1037+
register_bbregister_func_to_anat, nodes, [
1038+
(inputNode_pedir, ('pedir', convert_pedir), 'pedir'),
1039+
(inputspec, 'fieldmap', 'fieldmap'),
1040+
(inputspec, 'fieldmapmask', 'fieldmapmask'),
1041+
(inputNode_echospacing, 'echospacing', 'echospacing')])
1042+
for i, node in enumerate(nodes):
1043+
register_bbregister_func_to_anat.connect(inputspec, 'anat',
1044+
guardrails[i], 'reference')
1045+
register_bbregister_func_to_anat.connect(node, 'out_file',
1046+
guardrails[i], 'registered')
1047+
if retry:
1048+
# pylint: disable=no-value-for-parameter
1049+
outfile = guardrail_selection(register_bbregister_func_to_anat,
1050+
*guardrails)
1051+
matrix = guardrail_selection(register_bbregister_func_to_anat, *nodes,
1052+
'out_matrix_file', guardrails[0])
10401053
register_bbregister_func_to_anat.connect(
1041-
inputNode_pedir, ('pedir', convert_pedir),
1042-
bbreg_func_to_anat, 'pedir')
1043-
register_bbregister_func_to_anat.connect(
1044-
inputspec, 'fieldmap',
1045-
bbreg_func_to_anat, 'fieldmap')
1046-
register_bbregister_func_to_anat.connect(
1047-
inputspec, 'fieldmapmask',
1048-
bbreg_func_to_anat, 'fieldmapmask')
1054+
matrix, 'out', outputspec, 'func_to_anat_linear_xfm')
1055+
register_bbregister_func_to_anat.connect(outfile, 'out',
1056+
outputspec, 'anat_func')
1057+
else:
10491058
register_bbregister_func_to_anat.connect(
1050-
inputNode_echospacing, 'echospacing',
1051-
bbreg_func_to_anat, 'echospacing')
1052-
1053-
guardrail = registration_guardrail_node(name=f'{name}_guardrail')
1054-
register_bbregister_func_to_anat.connect(inputspec, 'anat',
1055-
guardrail, 'reference')
1056-
register_bbregister_func_to_anat.connect(
1057-
bbreg_func_to_anat, 'out_matrix_file',
1058-
outputspec, 'func_to_anat_linear_xfm')
1059-
register_bbregister_func_to_anat.connect(bbreg_func_to_anat, 'out_file',
1060-
guardrail, 'registered')
1061-
register_bbregister_func_to_anat.connect(guardrail, 'registered',
1062-
outputspec, 'anat_func')
1063-
1059+
bbreg_func_to_anat, 'out_matrix_file',
1060+
outputspec, 'func_to_anat_linear_xfm')
1061+
register_bbregister_func_to_anat.connect(guardrails[0], 'registered',
1062+
outputspec, 'anat_func')
10641063
return register_bbregister_func_to_anat
10651064

10661065

@@ -2871,9 +2870,11 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
28712870
(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg')
28722871
}
28732872

2874-
if opt in [True, "fallback"]:
2873+
if opt in [True, 'fallback']:
2874+
fallback = opt == 'fallback'
28752875
func_to_anat_bbreg = create_bbregister_func_to_anat(
2876-
diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}')
2876+
diff_complete, f'func_to_anat_bbreg{bbreg_status}_{pipe_num}',
2877+
opt is True)
28772878
func_to_anat_bbreg.inputs.inputspec.bbr_schedule = \
28782879
cfg.registration_workflows['functional_registration'][
28792880
'coregistration']['boundary_based_registration'][
@@ -2882,56 +2883,31 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
28822883
cfg.registration_workflows['functional_registration'][
28832884
'coregistration']['boundary_based_registration'][
28842885
'bbr_wm_mask_args']
2885-
bbreg_guardrail = registration_guardrail_node(
2886-
f'bbreg{bbreg_status}_guardrail_{pipe_num}')
2887-
if opt is True:
2888-
# Retry once on failure
2889-
retry_node = create_bbregister_func_to_anat(diff_complete,
2890-
f'retry_func_to_anat_'
2891-
f'bbreg_{pipe_num}',
2892-
retry=True)
2893-
retry_node.inputs.inputspec.bbr_schedule = cfg[
2894-
'registration_workflows', 'functional_registration',
2895-
'coregistration', 'boundary_based_registration',
2896-
'bbr_schedule']
2897-
retry_node.inputs.inputspec.bbr_wm_mask_args = cfg[
2898-
'registration_workflows', 'functional_registration',
2899-
'coregistration', 'boundary_based_registration',
2900-
'bbr_wm_mask_args']
2901-
retry_guardrail = registration_guardrail_node(
2902-
f'retry_bbreg_guardrail_{pipe_num}')
2886+
if fallback:
2887+
bbreg_guardrail = registration_guardrail_node(
2888+
f'bbreg{bbreg_status}_guardrail_{pipe_num}')
29032889

29042890
node, out = strat_pool.get_data('desc-reginput_bold')
29052891
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.func')
2906-
if opt is True:
2907-
wf.connect(node, out, retry_node, 'inputspec.func')
29082892

29092893
if cfg.registration_workflows['functional_registration'][
29102894
'coregistration']['boundary_based_registration'][
29112895
'reference'] == 'whole-head':
29122896
node, out = strat_pool.get_data('T1w')
29132897
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
2914-
wf.connect(node, out, bbreg_guardrail, 'reference')
2915-
if opt is True:
2916-
wf.connect(node, out, retry_node, 'inputspec.anat')
2917-
wf.connect(node, out, retry_guardrail, 'reference')
2898+
if fallback:
2899+
wf.connect(node, out, bbreg_guardrail, 'reference')
29182900

29192901
elif cfg.registration_workflows['functional_registration'][
29202902
'coregistration']['boundary_based_registration'][
29212903
'reference'] == 'brain':
29222904
node, out = strat_pool.get_data('desc-brain_T1w')
29232905
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.anat')
2924-
wf.connect(node, out, bbreg_guardrail, 'reference')
2925-
if opt is True:
2926-
wf.connect(node, out, retry_node, 'inputspec.anat')
2927-
wf.connect(node, out, retry_guardrail, 'reference')
2906+
if fallback:
2907+
wf.connect(node, out, bbreg_guardrail, 'reference')
29282908

29292909
wf.connect(func_to_anat, 'outputspec.func_to_anat_linear_xfm_nobbreg',
29302910
func_to_anat_bbreg, 'inputspec.linear_reg_matrix')
2931-
if opt is True:
2932-
wf.connect(func_to_anat,
2933-
'outputspec.func_to_anat_linear_xfm_nobbreg',
2934-
retry_node, 'inputspec.linear_reg_matrix')
29352911

29362912
if strat_pool.check_rpool('space-bold_label-WM_mask'):
29372913
node, out = strat_pool.get_data(["space-bold_label-WM_mask"])
@@ -2948,76 +2924,59 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
29482924
"label-WM_mask"])
29492925
wf.connect(node, out,
29502926
func_to_anat_bbreg, 'inputspec.anat_wm_segmentation')
2951-
if opt is True:
2952-
wf.connect(node, out, retry_node, 'inputspec.anat_wm_segmentation')
29532927

29542928
if diff_complete:
29552929
node, out = strat_pool.get_data('effectiveEchoSpacing')
29562930
wf.connect(node, out,
29572931
func_to_anat_bbreg, 'echospacing_input.echospacing')
2958-
if opt is True:
2959-
wf.connect(node, out,
2960-
retry_node, 'echospacing_input.echospacing')
29612932

29622933
node, out = strat_pool.get_data('diffphase-pedir')
29632934
wf.connect(node, out, func_to_anat_bbreg, 'pedir_input.pedir')
2964-
if opt is True:
2965-
wf.connect(node, out, retry_node, 'pedir_input.pedir')
29662935

29672936
node, out = strat_pool.get_data("despiked-fieldmap")
29682937
wf.connect(node, out, func_to_anat_bbreg, 'inputspec.fieldmap')
2969-
if opt is True:
2970-
wf.connect(node, out, retry_node, 'inputspec.fieldmap')
29712938

29722939
node, out = strat_pool.get_data("fieldmap-mask")
29732940
wf.connect(node, out,
29742941
func_to_anat_bbreg, 'inputspec.fieldmapmask')
2975-
if opt is True:
2976-
wf.connect(node, out, retry_node, 'inputspec.fieldmapmask')
2977-
2978-
wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
2979-
bbreg_guardrail, 'registered')
2980-
if opt is True:
2981-
wf.connect(func_to_anat_bbreg, 'outputspec.anat_func',
2982-
retry_guardrail, 'registered')
2983-
2984-
mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
2985-
name=f'bbreg_mean_bold_choices_{pipe_num}')
2986-
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
2987-
name=f'bbreg_xfm_choices_{pipe_num}')
2988-
fallback_mean_bolds = pe.Node(util.Select(),
2989-
run_without_submitting=True,
2990-
name=f'bbreg_choose_mean_bold_{pipe_num}'
2991-
)
2992-
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
2993-
name=f'bbreg_choose_xfm_{pipe_num}')
2994-
if opt is True:
2995-
wf.connect([
2996-
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
2997-
(retry_guardrail, mean_bolds, [('registered', 'in2')]),
2998-
(func_to_anat_bbreg, xfms, [
2999-
('outputspec.func_to_anat_linear_xfm', 'in1')]),
3000-
(retry_node, xfms, [
3001-
('outputspec.func_to_anat_linear_xfm', 'in2')])])
3002-
else:
2942+
if fallback:
30032943
# Fall back to no-BBReg
2944+
mean_bolds = pe.Node(util.Merge(2), run_without_submitting=True,
2945+
name=f'bbreg_mean_bold_choices_{pipe_num}')
2946+
xfms = pe.Node(util.Merge(2), run_without_submitting=True,
2947+
name=f'bbreg_xfm_choices_{pipe_num}')
2948+
fallback_mean_bolds = pe.Node(util.Select(),
2949+
run_without_submitting=True,
2950+
name='bbreg_choose_mean_bold_'
2951+
f'{pipe_num}')
2952+
fallback_xfms = pe.Node(util.Select(), run_without_submitting=True,
2953+
name=f'bbreg_choose_xfm_{pipe_num}')
30042954
wf.connect([
2955+
(func_to_anat_bbreg, bbreg_guardrail, [
2956+
('outputspec.anat_func', 'registered')]),
30052957
(bbreg_guardrail, mean_bolds, [('registered', 'in1')]),
30062958
(func_to_anat, mean_bolds, [('outputspec.anat_func_nobbreg',
30072959
'in2')]),
30082960
(func_to_anat_bbreg, xfms, [
30092961
('outputspec.func_to_anat_linear_xfm', 'in1')]),
30102962
(func_to_anat, xfms, [
3011-
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')])])
3012-
wf.connect([
3013-
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
3014-
(xfms, fallback_xfms, [('out', 'inlist')]),
3015-
(bbreg_guardrail, fallback_mean_bolds, [('failed_qc', 'index')]),
3016-
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
3017-
outputs = {
3018-
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
3019-
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,
3020-
'out')}
2963+
('outputspec.func_to_anat_linear_xfm_nobbreg', 'in2')]),
2964+
(mean_bolds, fallback_mean_bolds, [('out', 'inlist')]),
2965+
(xfms, fallback_xfms, [('out', 'inlist')]),
2966+
(bbreg_guardrail, fallback_mean_bolds, [
2967+
('failed_qc', 'index')]),
2968+
(bbreg_guardrail, fallback_xfms, [('failed_qc', 'index')])])
2969+
outputs = {
2970+
'space-T1w_desc-mean_bold': (fallback_mean_bolds, 'out'),
2971+
'from-bold_to-T1w_mode-image_desc-linear_xfm': (fallback_xfms,
2972+
'out')}
2973+
else:
2974+
outputs = {
2975+
'space-T1w_desc-mean_bold': (func_to_anat_bbreg,
2976+
'outputspec.anat_func'),
2977+
'from-bold_to-T1w_mode-image_desc-linear_xfm': (
2978+
func_to_anat_bbreg,
2979+
'outputspec.func_to_anat_linear_xfm')}
30212980
return wf, outputs
30222981

30232982

0 commit comments

Comments
 (0)