Skip to content

Commit 6bda5ce

Browse files
authored
ENH: Restore resampling to T1w target (#3116)
## Changes proposed in this pull request Adds T1w resampling. Unconditionally resamples in T1w space, but conditionally outputs.
2 parents 978ae51 + 0d09cc0 commit 6bda5ce

File tree

6 files changed

+470
-712
lines changed

6 files changed

+470
-712
lines changed

fmriprep/interfaces/resampling.py

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ class ResampleSeriesInputSpec(TraitedSpec):
2828
in_file = File(exists=True, mandatory=True, desc="3D or 4D image file to resample")
2929
ref_file = File(exists=True, mandatory=True, desc="File to resample in_file to")
3030
transforms = InputMultiObject(
31-
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
31+
File(exists=True),
32+
mandatory=True,
33+
desc="Transform files, from in_file to ref_file (image mode)",
3234
)
3335
inverse = InputMultiObject(
3436
traits.Bool,
@@ -48,6 +50,16 @@ class ResampleSeriesInputSpec(TraitedSpec):
4850
desc="the phase-encoding direction corresponding to in_data",
4951
)
5052
num_threads = traits.Int(1, usedefault=True, desc="Number of threads to use for resampling")
53+
output_data_type = traits.Str("float32", usedefault=True, desc="Data type of output image")
54+
order = traits.Int(3, usedefault=True, desc="Order of interpolation (0=nearest, 3=cubic)")
55+
mode = traits.Str(
56+
'constant',
57+
usedefault=True,
58+
desc="How data is extended beyond its boundaries. "
59+
"See scipy.ndimage.map_coordinates for more details.",
60+
)
61+
cval = traits.Float(0.0, usedefault=True, desc="Value to fill past edges of data")
62+
prefilter = traits.Bool(True, usedefault=True, desc="Spline-prefilter data if order > 1")
5163

5264

5365
class ResampleSeriesOutputSpec(TraitedSpec):
@@ -87,13 +99,18 @@ def _run_interface(self, runtime):
8799

88100
pe_info = [(pe_axis, -ro_time if (axis_flip ^ pe_flip) else ro_time)] * nvols
89101

90-
resampled = resample_bold(
102+
resampled = resample_image(
91103
source=source,
92104
target=target,
93105
transforms=transforms,
94106
fieldmap=fieldmap,
95107
pe_info=pe_info,
96108
nthreads=self.inputs.num_threads,
109+
output_dtype=self.inputs.output_data_type,
110+
order=self.inputs.order,
111+
mode=self.inputs.mode,
112+
cval=self.inputs.cval,
113+
prefilter=self.inputs.prefilter,
97114
)
98115
resampled.to_filename(out_path)
99116

@@ -105,10 +122,16 @@ class ReconstructFieldmapInputSpec(TraitedSpec):
105122
in_coeffs = InputMultiObject(
106123
File(exists=True), mandatory=True, desc="SDCflows-style spline coefficient files"
107124
)
108-
target_ref_file = File(exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with")
109-
fmap_ref_file = File(exists=True, mandatory=True, desc="Reference file aligned with coefficients")
125+
target_ref_file = File(
126+
exists=True, mandatory=True, desc="Image to reconstruct the field in alignment with"
127+
)
128+
fmap_ref_file = File(
129+
exists=True, mandatory=True, desc="Reference file aligned with coefficients"
130+
)
110131
transforms = InputMultiObject(
111-
File(exists=True), mandatory=True, desc="Transform files, from in_file to ref_file (image mode)"
132+
File(exists=True),
133+
mandatory=True,
134+
desc="Transform files, from in_file to ref_file (image mode)",
112135
)
113136
inverse = InputMultiObject(
114137
traits.Bool,
@@ -252,6 +275,9 @@ def resample_vol(
252275
coordinates = nb.affines.apply_affine(
253276
hmc_xfm, coordinates.reshape(coords_shape[0], -1).T
254277
).T.reshape(coords_shape)
278+
else:
279+
# Copy coordinates to avoid interfering with other calls
280+
coordinates = coordinates.copy()
255281

256282
vsm = fmap_hz * pe_info[1]
257283
coordinates[pe_info[0], ...] += vsm
@@ -346,15 +372,17 @@ async def resample_series_async(
346372

347373
semaphore = asyncio.Semaphore(max_concurrent)
348374

349-
out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype)
375+
# Order F ensures individual volumes are contiguous in memory
376+
# Also matches NIfTI, making final save more efficient
377+
out_array = np.zeros(coordinates.shape[1:] + data.shape[-1:], dtype=output_dtype, order='F')
350378

351379
tasks = [
352380
asyncio.create_task(
353381
worker(
354382
partial(
355383
resample_vol,
356384
data=volume,
357-
coordinates=coordinates.copy(),
385+
coordinates=coordinates,
358386
pe_info=pe_info[volid],
359387
hmc_xfm=hmc_xfms[volid] if hmc_xfms else None,
360388
fmap_hz=fmap_hz,
@@ -451,21 +479,26 @@ def resample_series(
451479
)
452480

453481

454-
def resample_bold(
482+
def resample_image(
455483
source: nb.Nifti1Image,
456484
target: nb.Nifti1Image,
457485
transforms: nt.TransformChain,
458486
fieldmap: nb.Nifti1Image | None,
459487
pe_info: list[tuple[int, float]] | None,
460488
nthreads: int = 1,
489+
output_dtype: np.dtype | str | None = 'f4',
490+
order: int = 3,
491+
mode: str = 'constant',
492+
cval: float = 0.0,
493+
prefilter: bool = True,
461494
) -> nb.Nifti1Image:
462-
"""Resample a 4D bold series into a target space, applying head-motion
495+
"""Resample a 3- or 4D image into a target space, applying head-motion
463496
and susceptibility-distortion correction simultaneously.
464497
465498
Parameters
466499
----------
467500
source
468-
The 4D bold series to resample.
501+
The 3D bold image or 4D bold series to resample.
469502
target
470503
An image sampled in the target space.
471504
transforms
@@ -480,6 +513,17 @@ def resample_bold(
480513
of the data array in the second dimension.
481514
nthreads
482515
Number of threads to use for parallel resampling
516+
output_dtype
517+
The dtype of the output array.
518+
order
519+
Order of interpolation (default: 3 = cubic)
520+
mode
521+
How ``data`` is extended beyond its boundaries. See
522+
:func:`scipy.ndimage.map_coordinates` for more details.
523+
cval
524+
Value to fill past edges of ``data`` if ``mode`` is ``'constant'``.
525+
prefilter
526+
Determines if ``data`` is pre-filtered before interpolation.
483527
484528
Returns
485529
-------
@@ -527,8 +571,12 @@ def resample_bold(
527571
pe_info=pe_info,
528572
hmc_xfms=hmc_xfms,
529573
fmap_hz=fieldmap.get_fdata(dtype='f4'),
530-
output_dtype='f4',
574+
output_dtype=output_dtype,
531575
nthreads=nthreads,
576+
order=order,
577+
mode=mode,
578+
cval=cval,
579+
prefilter=prefilter,
532580
)
533581
resampled_img = nb.Nifti1Image(resampled_data, target.affine, target.header)
534582
resampled_img.set_data_dtype('f4')

fmriprep/workflows/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,6 +483,9 @@ def init_single_subject_wf(subject_id: str):
483483
precomputed=functional_cache,
484484
fieldmap_id=fieldmap_id,
485485
)
486+
if bold_wf is None:
487+
continue
488+
486489
bold_wf.__desc__ = func_pre_desc + (bold_wf.__desc__ or "")
487490

488491
workflow.connect([

fmriprep/workflows/bold/apply.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import nipype.interfaces.utility as niu
88
import nipype.pipeline.engine as pe
99
from niworkflows.interfaces.header import ValidateImage
10+
from niworkflows.interfaces.nibabel import GenerateSamplingReference
1011
from niworkflows.interfaces.utility import KeySelect
1112
from niworkflows.utils.connections import listify
1213

@@ -25,6 +26,110 @@
2526
from niworkflows.utils.spaces import SpatialReferences
2627

2728

29+
def init_bold_volumetric_resample_wf(
30+
*,
31+
metadata: dict,
32+
fieldmap_id: str | None = None,
33+
omp_nthreads: int = 1,
34+
name: str = 'bold_volumetric_resample_wf',
35+
) -> pe.Workflow:
36+
workflow = pe.Workflow(name=name)
37+
38+
inputnode = pe.Node(
39+
niu.IdentityInterface(
40+
fields=[
41+
"bold_file",
42+
"bold_ref_file",
43+
"target_ref_file",
44+
"target_mask",
45+
# HMC
46+
"motion_xfm",
47+
# SDC
48+
"boldref2fmap_xfm",
49+
"fmap_ref",
50+
"fmap_coeff",
51+
"fmap_id",
52+
# Anatomical
53+
"boldref2anat_xfm",
54+
# Template
55+
"anat2std_xfm",
56+
],
57+
),
58+
name='inputnode',
59+
)
60+
61+
outputnode = pe.Node(niu.IdentityInterface(fields=["bold_file"]), name='outputnode')
62+
63+
gen_ref = pe.Node(GenerateSamplingReference(), name='gen_ref', mem_gb=0.3)
64+
65+
boldref2target = pe.Node(niu.Merge(2), name='boldref2target')
66+
bold2target = pe.Node(niu.Merge(2), name='bold2target')
67+
resample = pe.Node(ResampleSeries(), name="resample", n_procs=omp_nthreads)
68+
69+
workflow.connect([
70+
(inputnode, gen_ref, [
71+
('bold_ref_file', 'moving_image'),
72+
('target_ref_file', 'fixed_image'),
73+
('target_mask', 'fov_mask'),
74+
]),
75+
(inputnode, boldref2target, [
76+
('boldref2anat_xfm', 'in1'),
77+
('anat2std_xfm', 'in2'),
78+
]),
79+
(inputnode, bold2target, [('motion_xfm', 'in1')]),
80+
(inputnode, resample, [('bold_file', 'in_file')]),
81+
(gen_ref, resample, [('out_file', 'ref_file')]),
82+
(boldref2target, bold2target, [('out', 'in2')]),
83+
(bold2target, resample, [('out', 'transforms')]),
84+
(resample, outputnode, [('out_file', 'bold_file')]),
85+
]) # fmt:skip
86+
87+
if not fieldmap_id:
88+
return workflow
89+
90+
fmap_select = pe.Node(
91+
KeySelect(fields=["fmap_ref", "fmap_coeff"], key=fieldmap_id),
92+
name="fmap_select",
93+
run_without_submitting=True,
94+
)
95+
distortion_params = pe.Node(
96+
DistortionParameters(metadata=metadata),
97+
name="distortion_params",
98+
run_without_submitting=True,
99+
)
100+
fmap2target = pe.Node(niu.Merge(2), name='fmap2target')
101+
inverses = pe.Node(niu.Function(function=_gen_inverses), name='inverses')
102+
103+
fmap_recon = pe.Node(ReconstructFieldmap(), name="fmap_recon")
104+
105+
workflow.connect([
106+
(inputnode, fmap_select, [
107+
("fmap_ref", "fmap_ref"),
108+
("fmap_coeff", "fmap_coeff"),
109+
("fmap_id", "keys"),
110+
]),
111+
(inputnode, distortion_params, [('bold_file', 'in_file')]),
112+
(inputnode, fmap2target, [('boldref2fmap_xfm', 'in1')]),
113+
(gen_ref, fmap_recon, [('out_file', 'target_ref_file')]),
114+
(boldref2target, fmap2target, [('out', 'in2')]),
115+
(boldref2target, inverses, [('out', 'inlist')]),
116+
(fmap_select, fmap_recon, [
117+
("fmap_coeff", "in_coeffs"),
118+
("fmap_ref", "fmap_ref_file"),
119+
]),
120+
(fmap2target, fmap_recon, [('out', 'transforms')]),
121+
(inverses, fmap_recon, [('out', 'inverse')]),
122+
# Inject fieldmap correction into resample node
123+
(distortion_params, resample, [
124+
("readout_time", "ro_time"),
125+
("pe_direction", "pe_dir"),
126+
]),
127+
(fmap_recon, resample, [('out_file', 'fieldmap')]),
128+
]) # fmt:skip
129+
130+
return workflow
131+
132+
28133
def init_bold_apply_wf(
29134
*,
30135
spaces: SpatialReferences,
@@ -49,3 +154,16 @@ def init_bold_apply_wf(
49154
# )
50155

51156
return workflow
157+
158+
159+
def _gen_inverses(inlist: list) -> list[bool]:
160+
"""Create a list indicating the first transform should be inverted.
161+
162+
The input list is the collection of transforms that follow the
163+
inverted one.
164+
"""
165+
from niworkflows.utils.connections import listify
166+
167+
if not inlist:
168+
return [True]
169+
return [True] + [False] * len(listify(inlist))

0 commit comments

Comments
 (0)