Skip to content

Commit 7c77f4b

Browse files
committed
🚧 WIP 🥅 Iterate guardrail installation
1 parent b33c472 commit 7c77f4b

File tree

7 files changed

+137
-68
lines changed

7 files changed

+137
-68
lines changed

CPAC/pipeline/nipype_pipeline_engine/utils.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,11 @@
1717
"""Custom nipype utilities"""
1818

1919

20-
def connect_from_spec(wf, spec, original_spec, exclude=None):
20+
def connect_from_spec(spec, original_spec, exclude=None):
2121
"""Function to connect all original inputs to a new spec"""
2222
for _item, _value in original_spec.items():
2323
if isinstance(exclude, (list, tuple)):
2424
if _item not in exclude:
2525
setattr(spec.inputs, _item, _value)
2626
elif _item != exclude:
2727
setattr(spec.inputs, _item, _value)
28-
return wf

CPAC/pipeline/random_state/seed.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def random_seed():
6262
-------
6363
seed : int or None
6464
'''
65-
if _seed['seed'] == 'random':
65+
if _seed['seed'] in ['random', None]:
6666
_seed['seed'] = random_random_seed()
6767
return _seed['seed']
6868

@@ -153,6 +153,24 @@ def _reusable_flags():
153153
}
154154

155155

156+
def seed_plus_1(seed=None):
157+
'''Increment seed, looping back to 1 at MAX_SEED
158+
159+
Parameters
160+
----------
161+
seed : int, optional
162+
Uses configured seed if not specified
163+
164+
Returns
165+
-------
166+
int
167+
'''
168+
seed = random_seed() if seed is None else int(seed)
169+
if seed < MAX_SEED: # increment random seed
170+
return seed + 1
171+
return 1 # loop back to 1
172+
173+
156174
def set_up_random_state(seed):
157175
'''Set global random seed
158176

CPAC/pipeline/schema.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,8 @@
1818
# pylint: disable=too-many-lines
1919
import re
2020
from itertools import chain, permutations
21-
import numpy as np
2221
from pathvalidate import sanitize_filename
23-
from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, \
22+
from voluptuous import All, ALLOW_EXTRA, Any, Capitalize, Coerce, Equal, \
2423
ExactSequence, ExclusiveInvalid, In, Length, Lower, \
2524
Match, Maybe, Optional, Range, Required, Schema
2625
from CPAC import docs_prefix
@@ -492,7 +491,6 @@ def sanitize(filename):
492491
'interpolation': In({'trilinear', 'sinc', 'spline'}),
493492
'using': str,
494493
'input': str,
495-
'interpolation': str,
496494
'cost': str,
497495
'dof': int,
498496
'arguments': Maybe(str),
@@ -510,11 +508,14 @@ def sanitize(filename):
510508
},
511509
},
512510
'boundary_based_registration': {
513-
'run': forkable,
511+
'run': All(Coerce(ListFromItem),
512+
[Any(bool, All(Lower, Equal('fallback')))],
513+
Length(max=3)),
514514
'bbr_schedule': str,
515-
'bbr_wm_map': In({'probability_map', 'partial_volume_map'}),
515+
'bbr_wm_map': In(('probability_map',
516+
'partial_volume_map')),
516517
'bbr_wm_mask_args': str,
517-
'reference': In({'whole-head', 'brain'})
518+
'reference': In(('whole-head', 'brain'))
518519
},
519520
},
520521
'EPI_registration': {

CPAC/registration/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
create_fsl_fnirt_nonlinear_reg_nhp, \
44
create_register_func_to_anat, \
55
create_register_func_to_anat_use_T2, \
6-
create_bbregister_func_to_anat, \
76
create_wf_calculate_ants_warp
87

98
from .output_func_to_standard import output_func_to_standard
@@ -13,6 +12,5 @@
1312
'create_fsl_fnirt_nonlinear_reg_nhp',
1413
'create_register_func_to_anat',
1514
'create_register_func_to_anat_use_T2',
16-
'create_bbregister_func_to_anat',
1715
'create_wf_calculate_ants_warp',
1816
'output_func_to_standard']

CPAC/registration/guardrails.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# You should have received a copy of the GNU Lesser General Public
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Guardrails to protect against bad registrations"""
18+
import logging
1819
from copy import deepcopy
1920
from nipype.interfaces.ants import Registration
2021
from nipype.interfaces.fsl import FLIRT
@@ -23,7 +24,7 @@
2324
from CPAC.pipeline.nipype_pipeline_engine.utils import connect_from_spec
2425
from CPAC.qc import qc_masks, REGISTRATION_GUARDRAIL_THRESHOLDS
2526

26-
27+
logger = logging.getLogger('nipype.workflow')
2728
_SPEC_KEYS = {
2829
FLIRT: {'reference': 'reference', 'registered': 'out_file'},
2930
Registration: {'reference': 'reference', 'registered': 'out_file'}}
@@ -54,7 +55,7 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
5455

5556

5657
def registration_guardrail(registered: str, reference: str, retry: bool = False
57-
) -> str:
58+
):
5859
"""Check QC metrics post-registration and throw an exception if
5960
metrics are below given thresholds.
6061
@@ -78,23 +79,29 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
7879
-------
7980
registered_mask : str
8081
path to mask
82+
83+
failed_qc : int
84+
metrics met specified thresholds?, used as index for selecting
85+
outputs
8186
"""
8287
qc_metrics = qc_masks(registered, reference)
88+
failed_qc = 0
8389
for metric, threshold in REGISTRATION_GUARDRAIL_THRESHOLDS.items():
8490
if threshold is not None:
8591
value = qc_metrics.get(metric)
8692
if isinstance(value, list):
8793
value = value[0]
8894
if value < threshold:
95+
failed_qc = 1
8996
with open(f'{registered}.failed_qc', 'w',
9097
encoding='utf-8') as _f:
9198
_f.write(f'{metric}: {value} < {threshold}')
9299
if retry:
93100
registered = f'{registered}-failed'
94101
else:
95-
raise BadRegistrationError(metric=metric, value=value,
96-
threshold=threshold)
97-
return registered
102+
logger.error(str(BadRegistrationError(
103+
metric=metric, value=value, threshold=threshold)))
104+
return registered, failed_qc
98105

99106

100107
def registration_guardrail_node(name=None):
@@ -112,7 +119,8 @@ def registration_guardrail_node(name=None):
112119
name = 'registration_guardrail'
113120
return Node(Function(input_names=['registered',
114121
'reference'],
115-
output_names=['registered'],
122+
output_names=['registered',
123+
'failed_qc'],
116124
imports=['from CPAC.qc import qc_masks, '
117125
'REGISTRATION_GUARDRAIL_THRESHOLDS',
118126
'from CPAC.registration.guardrails '
@@ -146,10 +154,10 @@ def registration_guardrail_workflow(registration_node, retry=True):
146154
(registration_node, guardrail, [(outkey, 'registered')])])
147155
if retry:
148156
wf = retry_registration(wf, registration_node,
149-
guardrail.outputs.registered)
157+
guardrail.outputs.registered)[0]
150158
else:
151159
wf.connect(guardrail, 'registered', outputspec, outkey)
152-
wf = connect_from_spec(wf, outputspec, registration_node, outkey)
160+
connect_from_spec(outputspec, registration_node, outkey)
153161
return wf
154162

155163

@@ -167,6 +175,8 @@ def retry_registration(wf, registration_node, registered):
167175
Returns
168176
-------
169177
Workflow
178+
179+
Node
170180
"""
171181
name = f'retry_{registration_node.name}'
172182
retry_node = Node(Function(function=retry_registration_node,
@@ -177,14 +187,14 @@ def retry_registration(wf, registration_node, registered):
177187
outputspec = registration_node.outputs
178188
outkey = spec_key(registration_node, 'registered')
179189
guardrail = registration_guardrail_node(f'{name}_guardrail')
180-
wf = connect_from_spec(wf, inputspec, retry_node)
190+
connect_from_spec(inputspec, retry_node)
181191
wf.connect([
182192
(inputspec, guardrail, [
183193
(spec_key(retry_node, 'reference'), 'reference')]),
184194
(retry_node, guardrail, [(outkey, 'registered')]),
185195
(guardrail, outputspec, [('registered', outkey)])])
186-
wf = connect_from_spec(wf, retry_node, outputspec, registered)
187-
return wf
196+
connect_from_spec(retry_node, outputspec, registered)
197+
return wf, retry_node
188198

189199

190200
def retry_registration_node(registered, registration_node):
@@ -200,16 +210,12 @@ def retry_registration_node(registered, registration_node):
200210
-------
201211
Node
202212
"""
203-
from CPAC.pipeline.random_state.seed import MAX_SEED, random_seed
204-
seed = random_seed()
213+
from CPAC.pipeline.random_state.seed import seed_plus_1
205214
if registered.endswith('-failed'):
206215
retry_node = registration_node.clone(
207216
name=f'{registration_node.name}-retry')
208-
if isinstance(seed, int):
209-
if seed < MAX_SEED: # increment random seed
210-
retry_node.seed = seed + 1
211-
else: # loop back to minumum seed
212-
retry_node.seed = 1
217+
if isinstance(retry_node.seed, int):
218+
retry_node.seed = seed_plus_1()
213219
return retry_node
214220
return registration_node
215221

0 commit comments

Comments
 (0)