Skip to content

Commit 02a0737

Browse files
committed
feat: Parameterize epi_reference_wf to better support rodents
1 parent b746cba commit 02a0737

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

niworkflows/workflows/epi/refmap.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ def init_epi_reference_wf(
3434
omp_nthreads,
3535
auto_bold_nss=False,
3636
name='epi_reference_wf',
37+
n4_iterations=(50,)*5,
38+
n4_shrink_factor=4,
39+
calculate_bspline_grid=False,
3740
):
3841
"""
3942
Build a workflow that generates a reference map from a set of EPI images.
@@ -121,7 +124,7 @@ def init_epi_reference_wf(
121124
from ...interfaces.header import ValidateImage
122125
from ...interfaces.images import RobustAverage
123126
from ...interfaces.nibabel import IntensityClip
124-
from ...utils.connections import listify
127+
from ...utils.connections import listify, pop_file
125128

126129
wf = Workflow(name=name)
127130

@@ -153,9 +156,9 @@ def init_epi_reference_wf(
153156
N4BiasFieldCorrection(
154157
dimension=3,
155158
copy_header=True,
156-
n_iterations=[50] * 5,
159+
n_iterations=list(n4_iterations),
157160
convergence_threshold=1e-7,
158-
shrink_factor=4,
161+
shrink_factor=n4_shrink_factor,
159162
),
160163
n_procs=omp_nthreads,
161164
name='n4_avgs',
@@ -220,6 +223,13 @@ def _set_threads(in_list, maximum):
220223
else:
221224
wf.connect(inputnode, 't_masks', per_run_avgs, 't_mask')
222225

226+
if calculate_bspline_grid:
227+
bspline_grid = pe.Node(niu.Function(function=_bspline_grid), name='bspline_grid')
228+
wf.connect([
229+
(inputnode, bspline_grid, [(('in_files', pop_file), 'in_file')]),
230+
(bspline_grid, n4_avgs, [('out', 'args')]),
231+
]) # fmt:skip
232+
223233
return wf
224234

225235

@@ -256,3 +266,16 @@ def _post_merge(in_file, in_xfms):
256266
img = nb.load(in_file)
257267
nb.Nifti1Image(img.dataobj, img.affine, None).to_filename(out_file)
258268
return _advanced_clip(out_file, p_min=0.0, p_max=100.0)
269+
270+
271+
def _bspline_grid(in_file):
272+
import nibabel as nb
273+
import numpy as np
274+
import math
275+
276+
img = nb.load(in_file)
277+
zooms = img.header.get_zooms()[:3]
278+
extent = (np.array(img.shape[:3]) - 1) * zooms
279+
# get mesh resolution ratio
280+
retval = [f'{math.ceil(i / extent[np.argmin(extent)])}' for i in extent]
281+
return f"-b [{'x'.join(retval)}]"

0 commit comments

Comments
 (0)