Skip to content

Commit 837533e

Browse files
committed
♻️ Better handling of global thresholds
1 parent 328b569 commit 837533e

File tree

4 files changed

+37
-11
lines changed

4 files changed

+37
-11
lines changed

CPAC/qc/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# You should have received a copy of the GNU Lesser General Public
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Quality control utilities for C-PAC"""
18-
from CPAC.qc.globals import REGISTRATION_GUARDRAIL_THRESHOLDS
18+
from CPAC.qc.globals import registration_guardrail_thresholds, \
19+
update_thresholds
1920
from CPAC.qc.qcmetrics import qc_masks
20-
__all__ = ['qc_masks', 'REGISTRATION_GUARDRAIL_THRESHOLDS']
21+
__all__ = ['qc_masks', 'registration_guardrail_thresholds',
22+
'update_thresholds']

CPAC/qc/globals.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,28 @@
1515
# You should have received a copy of the GNU Lesser General Public
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Global QC values"""
18-
REGISTRATION_GUARDRAIL_THRESHOLDS = {}
18+
_REGISTRATION_GUARDRAIL_THRESHOLDS = {'thresholds': {}}
19+
20+
21+
def registration_guardrail_thresholds() -> dict:
22+
"""Get registration guardrail thresholds
23+
24+
Returns
25+
-------
26+
dict
27+
"""
28+
return _REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds']
29+
30+
31+
def update_thresholds(thresholds) -> None:
32+
"""Set a registration guardrail threshold
33+
34+
Parameters
35+
----------
36+
thresholds : dict of {str: float or int}
37+
38+
Returns
39+
-------
40+
None
41+
"""
42+
_REGISTRATION_GUARDRAIL_THRESHOLDS['thresholds'].update(thresholds)

CPAC/registration/guardrails.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
# You should have received a copy of the GNU Lesser General Public
1616
# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717
"""Guardrails to protect against bad registrations"""
18+
import logging
1819
from copy import deepcopy
1920
from nipype.interfaces.ants import Registration
2021
from nipype.interfaces.fsl import FLIRT
2122
from nipype.interfaces.utility import Function
2223
from CPAC.pipeline.nipype_pipeline_engine import Node, Workflow
2324
from CPAC.pipeline.nipype_pipeline_engine.utils import connect_from_spec
24-
from CPAC.qc import qc_masks, REGISTRATION_GUARDRAIL_THRESHOLDS
25+
from CPAC.qc import qc_masks, registration_guardrail_thresholds
2526

2627
_SPEC_KEYS = {
2728
FLIRT: {'reference': 'reference', 'registered': 'out_file'},
@@ -82,11 +83,10 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
8283
metrics met specified thresholds?, used as index for selecting
8384
outputs
8485
"""
85-
import logging
8686
logger = logging.getLogger('nipype.workflow')
8787
qc_metrics = qc_masks(registered, reference)
8888
failed_qc = 0
89-
for metric, threshold in REGISTRATION_GUARDRAIL_THRESHOLDS.items():
89+
for metric, threshold in registration_guardrail_thresholds().items():
9090
if threshold is not None:
9191
value = qc_metrics.get(metric)
9292
if isinstance(value, list):
@@ -121,8 +121,9 @@ def registration_guardrail_node(name=None):
121121
'reference'],
122122
output_names=['registered',
123123
'failed_qc'],
124-
imports=['from CPAC.qc import qc_masks, '
125-
'REGISTRATION_GUARDRAIL_THRESHOLDS',
124+
imports=['import logging',
125+
'from CPAC.qc import qc_masks, '
126+
'registration_guardrail_thresholds',
126127
'from CPAC.registration.guardrails '
127128
'import BadRegistrationError'],
128129
function=registration_guardrail), name=name)

CPAC/utils/configuration/configuration.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from typing import Optional, Tuple
2424
from warnings import warn
2525
import yaml
26-
from CPAC.qc import REGISTRATION_GUARDRAIL_THRESHOLDS
26+
from CPAC.qc import update_thresholds
2727
from CPAC.utils.utils import load_preconfig
2828
from .diff import dct_diff
2929

@@ -152,8 +152,7 @@ def __init__(self, config_map=None):
152152
setattr(self, key, set_from_ENV(config_map[key]))
153153

154154
# set global QC thresholds
155-
REGISTRATION_GUARDRAIL_THRESHOLDS.update(self[
156-
'registration_workflows', 'quality_thresholds'])
155+
update_thresholds(self['registration_workflows', 'quality_thresholds'])
157156

158157
self.__update_attr()
159158

0 commit comments

Comments
 (0)