Skip to content

Commit e32644b

Browse files
committed
🥅 Guardrail: Log error on first try, raise on second
1 parent 837533e commit e32644b

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

CPAC/registration/guardrails.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Guardrails to protect against bad registrations"""
1818
import logging
19+
from typing import Tuple
1920
from copy import deepcopy
2021
from nipype.interfaces.ants import Registration
2122
from nipype.interfaces.fsl import FLIRT
@@ -53,8 +54,9 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
5354
super().__init__(msg, *args, **kwargs)
5455

5556

56-
def registration_guardrail(registered: str, reference: str, retry: bool = False
57-
):
57+
def registration_guardrail(registered: str, reference: str,
58+
retry: bool = False, retry_num: int = 0
59+
) -> Tuple[str, int]:
5860
"""Check QC metrics post-registration and throw an exception if
5961
metrics are below given thresholds.
6062
@@ -71,9 +73,12 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
7173
registered, reference : str
7274
path to mask
7375
74-
retry : bool
76+
retry : bool, optional
7577
can retry?
7678
79+
retry_num : int, optional
80+
how many previous tries?
81+
7782
Returns
7883
-------
7984
registered_mask : str
@@ -99,8 +104,11 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
99104
if retry:
100105
registered = f'{registered}-failed'
101106
else:
102-
logger.error(str(BadRegistrationError(
103-
metric=metric, value=value, threshold=threshold)))
107+
bad_registration = BadRegistrationError(
108+
metric=metric, value=value, threshold=threshold)
109+
logger.error(str(bad_registration))
110+
if retry_num: # if we've already retried, raise the error
111+
raise bad_registration
104112
return registered, failed_qc
105113

106114

@@ -122,6 +130,7 @@ def registration_guardrail_node(name=None):
122130
output_names=['registered',
123131
'failed_qc'],
124132
imports=['import logging',
133+
'from typing import Tuple',
125134
'from CPAC.qc import qc_masks, '
126135
'registration_guardrail_thresholds',
127136
'from CPAC.registration.guardrails '

0 commit comments

Comments
 (0)