2121from  CPAC .pipeline  import  nipype_pipeline_engine  as  pe 
2222from  nipype .interfaces  import  afni , ants , c3 , fsl , utility  as  util 
2323from  nipype .interfaces .afni  import  utils  as  afni_utils 
24- from  nipype .interfaces .utility  import  Merge , Select 
2524from  CPAC .anat_preproc .lesion_preproc  import  create_lesion_preproc 
2625from  CPAC .func_preproc .utils  import  chunk_ts , split_ts_chunks 
27- from  CPAC .registration .guardrails  import  registration_guardrail_node 
26+ from  CPAC .registration .guardrails  import  guardrail_selection , \
27+                                          registration_guardrail_node 
2828from  CPAC .registration .utils  import  seperate_warps_list , \
2929                                    check_transforms , \
3030                                    generate_inverse_transform_flags , \
@@ -1174,20 +1174,14 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
11741174    .. image:: 
11751175        :width: 500 
11761176    ''' 
1177- 
1177+      from   CPAC . registration . guardrails   import   retry_hardcoded_reg 
11781178    calc_ants_warp_wf  =  pe .Workflow (name = name )
11791179
1180-     inputspec  =  pe .Node (util .IdentityInterface (
1181-         fields = ['moving_brain' ,
1182-                 'reference_brain' ,
1183-                 'moving_skull' ,
1184-                 'reference_skull' ,
1185-                 'reference_mask' ,
1186-                 'moving_mask' ,
1187-                 'fixed_image_mask' ,
1188-                 'ants_para' ,
1189-                 'interp' ]),
1190-                 name = 'inputspec' )
1180+     warp_inputs  =  ['moving_brain' , 'reference_brain' , 'moving_skull' ,
1181+                    'reference_skull' , 'ants_para' , 'moving_mask' ,
1182+                    'reference_mask' , 'fixed_image_mask' , 'interp' ]
1183+     inputspec  =  pe .Node (util .IdentityInterface (fields = warp_inputs ),
1184+                         name = 'inputspec' )
11911185
11921186    outputspec  =  pe .Node (util .IdentityInterface (
11931187        fields = ['ants_initial_xfm' ,
@@ -1208,27 +1202,30 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12081202    calculate_ants_warp.inputs.initial_moving_transform_com = 0 
12091203    ''' 
12101204    reg_imports  =  ['import os' , 'import subprocess' ]
1211-     calculate_ants_warp  =  \
1212-         pe .Node (interface = util .Function (input_names = ['moving_brain' ,
1213-                                                      'reference_brain' ,
1214-                                                      'moving_skull' ,
1215-                                                      'reference_skull' ,
1216-                                                      'ants_para' ,
1217-                                                      'moving_mask' ,
1218-                                                      'reference_mask' ,
1219-                                                      'fixed_image_mask' ,
1220-                                                      'interp' ,
1221-                                                      'reg_with_skull' ],
1222-                                         output_names = ['warp_list' ,
1223-                                                       'warped_image' ],
1224-                                         function = hardcoded_reg ,
1225-                                         imports = reg_imports ),
1226-                 name = 'calc_ants_warp' ,
1227-                 mem_gb = 2.8 ,
1228-                 mem_x = (2e-7 , 'moving_brain' , 'xyz' ))
1205+     warp_inputs  +=  ['reg_with_skull' ]
1206+     warp_outputs  =  ['warp_list' , 'warped_image' ]
1207+     calculate_ants_warp  =  pe .Node (
1208+         interface = util .Function (input_names = warp_inputs ,
1209+                                 output_names = warp_outputs ,
1210+                                 function = hardcoded_reg ,
1211+                                 imports = reg_imports ),
1212+         name = 'calc_ants_warp' , mem_gb = 2.8 ,
1213+         mem_x = (2e-7 , 'moving_brain' , 'xyz' ))
1214+     retry_calculate_ants_warp  =  pe .Node (
1215+         interface = util .Function (input_names = [* warp_inputs , 'previous_failure' ],
1216+                                 output_names = warp_outputs ,
1217+                                 function = retry_hardcoded_reg ,
1218+                                 imports = ['from CPAC.registration.utils ' 
1219+                                          'import hardcoded_reg' ,
1220+                                          'from CPAC.utils.docs import ' 
1221+                                          'retry_docstring' ]),
1222+         name = 'retry_calc_ants_warp' , mem_gb = 2.8 ,
1223+         mem_x = (2e-7 , 'moving_brain' , 'xyz' ))
1224+     guardrails  =  tuple (registration_guardrail_node (
1225+         f'{ _try } { name }  _guardrail' , i ) for  i , _try  in  enumerate (('' , 'retry_' )))
12291226
12301227    calculate_ants_warp .interface .num_threads  =  num_threads 
1231- 
1228+      retry_calculate_ants_warp . interface . num_threads   =   num_threads 
12321229    select_forward_initial  =  pe .Node (util .Function (
12331230        input_names = ['warp_list' , 'selection' ],
12341231        output_names = ['selected_warp' ],
@@ -1264,13 +1261,10 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12641261
12651262    select_inverse_warp .inputs .selection  =  "Inverse" 
12661263
1267-     guardrail  =  registration_guardrail_node (f'{ name }  _guardrail' )
12681264    calc_ants_warp_wf .connect (inputspec , 'moving_brain' ,
12691265                              calculate_ants_warp , 'moving_brain' )
12701266    calc_ants_warp_wf .connect (inputspec , 'reference_brain' ,
12711267                              calculate_ants_warp , 'reference_brain' )
1272-     calc_ants_warp_wf .connect (inputspec , 'reference_brain' ,
1273-                               guardrail , 'reference' )
12741268
12751269    if  reg_ants_skull  ==  1 :
12761270        calculate_ants_warp .inputs .reg_with_skull  =  1 
@@ -1279,11 +1273,17 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
12791273                                  calculate_ants_warp , 'moving_skull' )
12801274        calc_ants_warp_wf .connect (inputspec , 'reference_skull' ,
12811275                                  calculate_ants_warp , 'reference_skull' )
1276+         for  guardrail  in  guardrails :
1277+             calc_ants_warp_wf .connect (inputspec , 'reference_skull' ,
1278+                                       guardrail , 'reference' )
12821279    else :
12831280        calc_ants_warp_wf .connect (inputspec , 'moving_brain' ,
12841281                                  calculate_ants_warp , 'moving_skull' )
12851282        calc_ants_warp_wf .connect (inputspec , 'reference_brain' ,
12861283                                  calculate_ants_warp , 'reference_skull' )
1284+         for  guardrail  in  guardrails :
1285+             calc_ants_warp_wf .connect (inputspec , 'reference_brain' ,
1286+                                       guardrail , 'reference' )
12871287
12881288    calc_ants_warp_wf .connect (inputspec , 'fixed_image_mask' ,
12891289                              calculate_ants_warp , 'fixed_image_mask' )
@@ -1317,9 +1317,11 @@ def create_wf_calculate_ants_warp(name='create_wf_calculate_ants_warp',
13171317                              outputspec , 'warp_field' )
13181318    calc_ants_warp_wf .connect (select_inverse_warp , 'selected_warp' ,
13191319                              outputspec , 'inverse_warp_field' )
1320-     calc_ants_warp_wf .connect (calculate_ants_warp , 'warped_image' ,
1321-                               guardrail , 'registered' )
1322-     calc_ants_warp_wf .connect (guardrail , 'registered' ,
1320+     for  guardrail  in  guardrails :
1321+         calc_ants_warp_wf .connect (calculate_ants_warp , 'warped_image' ,
1322+                                   guardrail , 'registered' )
1323+     select  =  guardrail_selection (calc_ants_warp_wf , * guardrails )
1324+     calc_ants_warp_wf .connect (select , 'out' ,
13231325                              outputspec , 'normalized_output_brain' )
13241326
13251327    return  calc_ants_warp_wf 
@@ -2928,38 +2930,39 @@ def coregistration(wf, cfg, strat_pool, pipe_num, opt=None):
29282930            wf .connect (func_to_anat_bbreg , 'outputspec.anat_func' ,
29292931                       retry_guardrail , 'registered' )
29302932
2931-         mean_bolds  =  pe .Node (Merge (2 ), run_without_submitting = True ,
2933+         mean_bolds  =  pe .Node (util . Merge (2 ), run_without_submitting = True ,
29322934                             name = f'bbreg_mean_bold_choices_{ pipe_num }  ' )
2933-         xfms  =  pe .Node (Merge (2 ), run_without_submitting = True ,
2935+         xfms  =  pe .Node (util . Merge (2 ), run_without_submitting = True ,
29342936                       name = f'bbreg_xfm_choices_{ pipe_num }  ' )
2935-         fallback_mean_bolds  =  pe .Node (Select , run_without_submitting = True ,
2937+         fallback_mean_bolds  =  pe .Node (util .Select (),
2938+                                       run_without_submitting = True ,
29362939                                      name = f'bbreg_choose_mean_bold_{ pipe_num }  ' 
29372940                                      )
2938-         fallback_xfms  =  pe .Node (Select , run_without_submitting = True ,
2941+         fallback_xfms  =  pe .Node (util . Select () , run_without_submitting = True ,
29392942                                name = f'bbreg_choose_xfm_{ pipe_num }  ' )
29402943        if  opt  is  True :
29412944            wf .connect ([
2942-                 (bbreg_guardrail , mean_bolds , ['registered' , 'in1' ]),
2943-                 (retry_guardrail , mean_bolds , ['registered' , 'in1'  ]),
2945+                 (bbreg_guardrail , mean_bolds , [( 'registered' , 'in1' ) ]),
2946+                 (retry_guardrail , mean_bolds , [( 'registered' , 'in2'  ) ]),
29442947                (func_to_anat_bbreg , xfms , [
2945-                     'outputspec.func_to_anat_linear_xfm' , 'in2'  ]),
2948+                     ( 'outputspec.func_to_anat_linear_xfm' , 'in1'  ) ]),
29462949                (retry_node , xfms , [
2947-                     'outputspec.func_to_anat_linear_xfm_nobbreg ' , 'in2' ])])
2950+                     ( 'outputspec.func_to_anat_linear_xfm ' , 'in2' ) ])])
29482951        else :
29492952            # Fall back to no-BBReg 
29502953            wf .connect ([
2951-                 (bbreg_guardrail , mean_bolds , ['registered' , 'in1' ]),
2952-                 (func_to_anat , mean_bolds , ['outputspec.anat_func_nobbreg' ,
2953-                                             'in1' ]),
2954+                 (bbreg_guardrail , mean_bolds , [( 'registered' , 'in1' ) ]),
2955+                 (func_to_anat , mean_bolds , [( 'outputspec.anat_func_nobbreg' ,
2956+                                               'in2' ) ]),
29542957                (func_to_anat_bbreg , xfms , [
2955-                     'outputspec.func_to_anat_linear_xfm' , 'in2'  ]),
2958+                     ( 'outputspec.func_to_anat_linear_xfm' , 'in1'  ) ]),
29562959                (func_to_anat , xfms , [
2957-                     'outputspec.func_to_anat_linear_xfm_nobbreg' , 'in2' ])])
2960+                     ( 'outputspec.func_to_anat_linear_xfm_nobbreg' , 'in2' ) ])])
29582961        wf .connect ([
2959-             (mean_bolds , fallback_mean_bolds , ['out' , 'inlist' ]),
2960-             (xfms , fallback_xfms , ['out' , 'inlist' ]),
2961-             (bbreg_guardrail , fallback_mean_bolds , ['failed_qc' , 'index' ]),
2962-             (bbreg_guardrail , fallback_xfms , ['failed_qc' , 'index' ])])
2962+             (mean_bolds , fallback_mean_bolds , [( 'out' , 'inlist' ) ]),
2963+             (xfms , fallback_xfms , [( 'out' , 'inlist' ) ]),
2964+             (bbreg_guardrail , fallback_mean_bolds , [( 'failed_qc' , 'index' ) ]),
2965+             (bbreg_guardrail , fallback_xfms , [( 'failed_qc' , 'index' ) ])])
29632966        outputs  =  {
29642967            'space-T1w_desc-mean_bold' : (fallback_mean_bolds , 'out' ),
29652968            'from-bold_to-T1w_mode-image_desc-linear_xfm' : (fallback_xfms ,
0 commit comments