1515# You should have received a copy of the GNU Lesser General Public 
1616# License along with C-PAC. If not, see <https://www.gnu.org/licenses/>. 
1717"""Guardrails to protect against bad registrations""" 
18+ import  logging 
1819from  copy  import  deepcopy 
1920from  nipype .interfaces .ants  import  Registration 
2021from  nipype .interfaces .fsl  import  FLIRT 
2324from  CPAC .pipeline .nipype_pipeline_engine .utils  import  connect_from_spec 
2425from  CPAC .qc  import  qc_masks , REGISTRATION_GUARDRAIL_THRESHOLDS 
2526
26- 
27+ logger   =   logging . getLogger ( 'nipype.workflow' ) 
2728_SPEC_KEYS  =  {
2829    FLIRT : {'reference' : 'reference' , 'registered' : 'out_file' },
2930    Registration : {'reference' : 'reference' , 'registered' : 'out_file' }}
@@ -54,7 +55,7 @@ def __init__(self, *args, metric=None, value=None, threshold=None,
5455
5556
5657def  registration_guardrail (registered : str , reference : str , retry : bool  =  False 
57-                            )  ->   str :
58+                            ):
5859    """Check QC metrics post-registration and throw an exception if 
5960    metrics are below given thresholds. 
6061
@@ -78,23 +79,29 @@ def registration_guardrail(registered: str, reference: str, retry: bool = False
7879    ------- 
7980    registered_mask : str 
8081        path to mask 
82+ 
83+     failed_qc : int 
84+         metrics met specified thresholds?, used as index for selecting 
85+         outputs 
8186    """ 
8287    qc_metrics  =  qc_masks (registered , reference )
88+     failed_qc  =  0 
8389    for  metric , threshold  in  REGISTRATION_GUARDRAIL_THRESHOLDS .items ():
8490        if  threshold  is  not   None :
8591            value  =  qc_metrics .get (metric )
8692            if  isinstance (value , list ):
8793                value  =  value [0 ]
8894            if  value  <  threshold :
95+                 failed_qc  =  1 
8996                with  open (f'{ registered }  .failed_qc' , 'w' ,
9097                          encoding = 'utf-8' ) as  _f :
9198                    _f .write (f'{ metric }  : { value }   < { threshold }  ' )
9299                if  retry :
93100                    registered  =  f'{ registered }  -failed' 
94101                else :
95-                     raise   BadRegistrationError (metric = metric ,  value = value , 
96-                                                 threshold = threshold )
97-     return  registered 
102+                     logger . error ( str ( BadRegistrationError (
103+                         metric = metric ,  value = value ,  threshold = threshold )) )
104+     return  registered ,  failed_qc 
98105
99106
100107def  registration_guardrail_node (name = None ):
@@ -112,7 +119,8 @@ def registration_guardrail_node(name=None):
112119        name  =  'registration_guardrail' 
113120    return  Node (Function (input_names = ['registered' ,
114121                                      'reference' ],
115-                          output_names = ['registered' ],
122+                          output_names = ['registered' ,
123+                                        'failed_qc' ],
116124                         imports = ['from CPAC.qc import qc_masks, ' 
117125                                  'REGISTRATION_GUARDRAIL_THRESHOLDS' ,
118126                                  'from CPAC.registration.guardrails ' 
@@ -146,10 +154,10 @@ def registration_guardrail_workflow(registration_node, retry=True):
146154        (registration_node , guardrail , [(outkey , 'registered' )])])
147155    if  retry :
148156        wf  =  retry_registration (wf , registration_node ,
149-                                 guardrail .outputs .registered )
157+                                 guardrail .outputs .registered )[ 0 ] 
150158    else :
151159        wf .connect (guardrail , 'registered' , outputspec , outkey )
152-         wf   =   connect_from_spec (wf ,  outputspec , registration_node , outkey )
160+         connect_from_spec (outputspec , registration_node , outkey )
153161    return  wf 
154162
155163
@@ -167,6 +175,8 @@ def retry_registration(wf, registration_node, registered):
167175    Returns 
168176    ------- 
169177    Workflow 
178+ 
179+     Node 
170180    """ 
171181    name  =  f'retry_{ registration_node .name }  ' 
172182    retry_node  =  Node (Function (function = retry_registration_node ,
@@ -177,14 +187,14 @@ def retry_registration(wf, registration_node, registered):
177187    outputspec  =  registration_node .outputs 
178188    outkey  =  spec_key (registration_node , 'registered' )
179189    guardrail  =  registration_guardrail_node (f'{ name }  _guardrail' )
180-     wf   =   connect_from_spec (wf ,  inputspec , retry_node )
190+     connect_from_spec (inputspec , retry_node )
181191    wf .connect ([
182192        (inputspec , guardrail , [
183193            (spec_key (retry_node , 'reference' ), 'reference' )]),
184194        (retry_node , guardrail , [(outkey , 'registered' )]),
185195        (guardrail , outputspec , [('registered' , outkey )])])
186-     wf   =   connect_from_spec (wf ,  retry_node , outputspec , registered )
187-     return  wf 
196+     connect_from_spec (retry_node , outputspec , registered )
197+     return  wf ,  retry_node 
188198
189199
190200def  retry_registration_node (registered , registration_node ):
@@ -200,16 +210,12 @@ def retry_registration_node(registered, registration_node):
200210    ------- 
201211    Node 
202212    """ 
203-     from  CPAC .pipeline .random_state .seed  import  MAX_SEED , random_seed 
204-     seed  =  random_seed ()
213+     from  CPAC .pipeline .random_state .seed  import  seed_plus_1 
205214    if  registered .endswith ('-failed' ):
206215        retry_node  =  registration_node .clone (
207216            name = f'{ registration_node .name }  -retry' )
208-         if  isinstance (seed , int ):
209-             if  seed  <  MAX_SEED :  # increment random seed 
210-                 retry_node .seed  =  seed  +  1 
211-             else :  # loop back to minumum seed 
212-                 retry_node .seed  =  1 
217+         if  isinstance (retry_node .seed , int ):
218+             retry_node .seed  =  seed_plus_1 ()
213219        return  retry_node 
214220    return  registration_node 
215221
0 commit comments