1515# You should have received a copy of the GNU Lesser General Public
1616# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>.
1717"""Guardrails to protect against bad registrations"""
18+ import logging
1819from copy import deepcopy
1920from nipype .interfaces .ants import Registration
2021from nipype .interfaces .fsl import FLIRT
2324from CPAC .pipeline .nipype_pipeline_engine .utils import connect_from_spec
2425from CPAC .qc import qc_masks , REGISTRATION_GUARDRAIL_THRESHOLDS
2526
26-
27+ logger = logging . getLogger ( 'nipype.workflow' )
2728_SPEC_KEYS = {
2829 FLIRT : {'reference' : 'reference' , 'registered' : 'out_file' },
2930 Registration : {'reference' : 'reference' , 'registered' : 'out_file' }}
@@ -54,7 +55,7 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
5455
5556
5657def registration_guardrail (registered : str , reference : str , retry : bool = False
57- ) -> str :
58+ ):
5859 """Check QC metrics post-registration and throw an exception if
5960 metrics are below given thresholds.
6061
@@ -78,23 +79,29 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
7879 -------
7980 registered_mask : str
8081 path to mask
82+
83+ failed_qc : int
84+ metrics met specified thresholds?, used as index for selecting
85+ outputs
8186 """
8287 qc_metrics = qc_masks (registered , reference )
88+ failed_qc = 0
8389 for metric , threshold in REGISTRATION_GUARDRAIL_THRESHOLDS .items ():
8490 if threshold is not None :
8591 value = qc_metrics .get (metric )
8692 if isinstance (value , list ):
8793 value = value [0 ]
8894 if value < threshold :
95+ failed_qc = 1
8996 with open (f'{ registered } .failed_qc' , 'w' ,
9097 encoding = 'utf-8' ) as _f :
9198 _f .write (f'{ metric } : { value } < { threshold } ' )
9299 if retry :
93100 registered = f'{ registered } -failed'
94101 else :
95- raise BadRegistrationError (metric = metric , value = value ,
96- threshold = threshold )
97- return registered
102+ logger . error ( str ( BadRegistrationError (
103+ metric = metric , value = value , threshold = threshold )) )
104+ return registered , failed_qc
98105
99106
100107def registration_guardrail_node (name = None ):
@@ -112,7 +119,8 @@ def registration_guardrail_node(name=None):
112119 name = 'registration_guardrail'
113120 return Node (Function (input_names = ['registered' ,
114121 'reference' ],
115- output_names = ['registered' ],
122+ output_names = ['registered' ,
123+ 'failed_qc' ],
116124 imports = ['from CPAC.qc import qc_masks, '
117125 'REGISTRATION_GUARDRAIL_THRESHOLDS' ,
118126 'from CPAC.registration.guardrails '
@@ -146,10 +154,10 @@ def registration_guardrail_workflow(registration_node, retry=True):
146154 (registration_node , guardrail , [(outkey , 'registered' )])])
147155 if retry :
148156 wf = retry_registration (wf , registration_node ,
149- guardrail .outputs .registered )
157+ guardrail .outputs .registered )[ 0 ]
150158 else :
151159 wf .connect (guardrail , 'registered' , outputspec , outkey )
152- wf = connect_from_spec (wf , outputspec , registration_node , outkey )
160+ connect_from_spec (outputspec , registration_node , outkey )
153161 return wf
154162
155163
@@ -167,6 +175,8 @@ def retry_registration(wf, registration_node, registered):
167175 Returns
168176 -------
169177 Workflow
178+
179+ Node
170180 """
171181 name = f'retry_{ registration_node .name } '
172182 retry_node = Node (Function (function = retry_registration_node ,
@@ -177,14 +187,14 @@ def retry_registration(wf, registration_node, registered):
177187 outputspec = registration_node .outputs
178188 outkey = spec_key (registration_node , 'registered' )
179189 guardrail = registration_guardrail_node (f'{ name } _guardrail' )
180- wf = connect_from_spec (wf , inputspec , retry_node )
190+ connect_from_spec (inputspec , retry_node )
181191 wf .connect ([
182192 (inputspec , guardrail , [
183193 (spec_key (retry_node , 'reference' ), 'reference' )]),
184194 (retry_node , guardrail , [(outkey , 'registered' )]),
185195 (guardrail , outputspec , [('registered' , outkey )])])
186- wf = connect_from_spec (wf , retry_node , outputspec , registered )
187- return wf
196+ connect_from_spec (retry_node , outputspec , registered )
197+ return wf , retry_node
188198
189199
190200def retry_registration_node (registered , registration_node ):
@@ -200,16 +210,12 @@ def retry_registration_node(registered, registration_node):
200210 -------
201211 Node
202212 """
203- from CPAC .pipeline .random_state .seed import MAX_SEED , random_seed
204- seed = random_seed ()
213+ from CPAC .pipeline .random_state .seed import seed_plus_1
205214 if registered .endswith ('-failed' ):
206215 retry_node = registration_node .clone (
207216 name = f'{ registration_node .name } -retry' )
208- if isinstance (seed , int ):
209- if seed < MAX_SEED : # increment random seed
210- retry_node .seed = seed + 1
211- else : # loop back to minumum seed
212- retry_node .seed = 1
217+ if isinstance (retry_node .seed , int ):
218+ retry_node .seed = seed_plus_1 ()
213219 return retry_node
214220 return registration_node
215221
0 commit comments