@@ -34,6 +34,9 @@ def init_epi_reference_wf(
3434    omp_nthreads ,
3535    auto_bold_nss = False ,
3636    name = 'epi_reference_wf' ,
37+     n4_iterations = (50 ,)* 5 ,
38+     n4_shrink_factor = 4 ,
39+     calculate_bspline_grid = False ,
3740):
3841    """ 
3942    Build a workflow that generates a reference map from a set of EPI images. 
@@ -121,7 +124,7 @@ def init_epi_reference_wf(
121124    from  ...interfaces .header  import  ValidateImage 
122125    from  ...interfaces .images  import  RobustAverage 
123126    from  ...interfaces .nibabel  import  IntensityClip 
124-     from  ...utils .connections  import  listify 
127+     from  ...utils .connections  import  listify ,  pop_file 
125128
126129    wf  =  Workflow (name = name )
127130
@@ -153,9 +156,9 @@ def init_epi_reference_wf(
153156        N4BiasFieldCorrection (
154157            dimension = 3 ,
155158            copy_header = True ,
156-             n_iterations = [ 50 ]  *   5 ,
159+             n_iterations = list ( n4_iterations ) ,
157160            convergence_threshold = 1e-7 ,
158-             shrink_factor = 4 ,
161+             shrink_factor = n4_shrink_factor ,
159162        ),
160163        n_procs = omp_nthreads ,
161164        name = 'n4_avgs' ,
@@ -220,6 +223,13 @@ def _set_threads(in_list, maximum):
220223    else :
221224        wf .connect (inputnode , 't_masks' , per_run_avgs , 't_mask' )
222225
226+     if  calculate_bspline_grid :
227+         bspline_grid  =  pe .Node (niu .Function (function = _bspline_grid ), name = 'bspline_grid' )
228+         wf .connect ([
229+             (inputnode , bspline_grid , [(('in_files' , pop_file ), 'in_file' )]),
230+             (bspline_grid , n4_avgs , [('out' , 'args' )]),
231+         ])  # fmt:skip 
232+ 
223233    return  wf 
224234
225235
@@ -256,3 +266,16 @@ def _post_merge(in_file, in_xfms):
256266    img  =  nb .load (in_file )
257267    nb .Nifti1Image (img .dataobj , img .affine , None ).to_filename (out_file )
258268    return  _advanced_clip (out_file , p_min = 0.0 , p_max = 100.0 )
269+ 
270+ 
271+ def  _bspline_grid (in_file ):
272+     import  nibabel  as  nb 
273+     import  numpy  as  np 
274+     import  math 
275+ 
276+     img  =  nb .load (in_file )
277+     zooms  =  img .header .get_zooms ()[:3 ]
278+     extent  =  (np .array (img .shape [:3 ]) -  1 ) *  zooms 
279+     # get mesh resolution ratio 
280+     retval  =  [f'{ math .ceil (i  /  extent [np .argmin (extent )])}  '  for  i  in  extent ]
281+     return  f"-b [{ 'x' .join (retval )}  ]" 
0 commit comments