Skip to content

Commit ce7c65f

Browse files
authored
ENH: Restore resampling BOLD to volumetric templates (#3121)
Builds on #3116.
2 parents 6bda5ce + 2fd3012 commit ce7c65f

File tree

3 files changed

+151
-133
lines changed

3 files changed

+151
-133
lines changed

fmriprep/utils/transforms.py

Lines changed: 66 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Utilities for loading transforms for resampling"""
2+
import warnings
23
from pathlib import Path
34

45
import h5py
@@ -36,57 +37,76 @@ def load_transforms(xfm_paths: list[Path], inverse: list[bool]) -> nt.base.Trans
3637
return chain
3738

3839

39-
def load_ants_h5(filename: Path) -> nt.TransformChain:
40-
"""Load ANTs H5 files as a nitransforms TransformChain"""
41-
affine, warp, warp_affine = parse_combined_hdf5(filename)
42-
warp_transform = nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine))
43-
return nt.TransformChain([warp_transform, nt.Affine(affine)])
40+
FIXED_PARAMS = np.array([
41+
193.0, 229.0, 193.0, # Size
42+
96.0, 132.0, -78.0, # Origin
43+
1.0, 1.0, 1.0, # Spacing
44+
-1.0, 0.0, 0.0, # Directions
45+
0.0, -1.0, 0.0,
46+
0.0, 0.0, 1.0,
47+
]) # fmt:skip
4448

4549

46-
def parse_combined_hdf5(h5_fn, to_ras=True):
50+
def load_ants_h5(filename: Path) -> nt.base.TransformBase:
51+
"""Load ANTs H5 files as a nitransforms TransformChain"""
4752
# Borrowed from https://github.com/feilong/process
4853
# process.resample.parse_combined_hdf5()
49-
h = h5py.File(h5_fn)
54+
#
55+
# Changes:
56+
# * Tolerate a missing displacement field
57+
# * Return the original affine without a round-trip
58+
# * Always return a nitransforms TransformChain
59+
#
60+
# This should be upstreamed into nitransforms
61+
h = h5py.File(filename)
5062
xform = ITKCompositeH5.from_h5obj(h)
51-
affine = xform[0].to_ras()
63+
64+
# nt.Affine
65+
transforms = [nt.Affine(xform[0].to_ras())]
66+
67+
if '2' not in h['TransformGroup']:
68+
return transforms[0]
69+
70+
transform2 = h['TransformGroup']['2']
71+
5272
# Confirm these transformations are applicable
53-
assert (
54-
h['TransformGroup']['2']['TransformType'][:][0] == b'DisplacementFieldTransform_float_3_3'
55-
)
56-
assert np.array_equal(
57-
h['TransformGroup']['2']['TransformFixedParameters'][:],
58-
np.array(
59-
[
60-
193.0,
61-
229.0,
62-
193.0,
63-
96.0,
64-
132.0,
65-
-78.0,
66-
1.0,
67-
1.0,
68-
1.0,
69-
-1.0,
70-
0.0,
71-
0.0,
72-
0.0,
73-
-1.0,
74-
0.0,
75-
0.0,
76-
0.0,
77-
1.0,
78-
]
79-
),
80-
)
73+
if transform2['TransformType'][:][0] != b'DisplacementFieldTransform_float_3_3':
74+
msg = 'Unknown transform type [2]\n'
75+
for i in h['TransformGroup'].keys():
76+
msg += f'[{i}]: {h["TransformGroup"][i]["TransformType"][:][0]}\n'
77+
raise ValueError(msg)
78+
79+
fixed_params = transform2['TransformFixedParameters'][:]
80+
if not np.array_equal(fixed_params, FIXED_PARAMS):
81+
msg = 'Unexpected fixed parameters\n'
82+
msg += f'Expected: {FIXED_PARAMS}\n'
83+
msg += f'Found: {fixed_params}'
84+
if not np.array_equal(fixed_params[6:], FIXED_PARAMS[6:]):
85+
raise ValueError(msg)
86+
warnings.warn(msg)
87+
88+
shape = tuple(fixed_params[:3].astype(int))
8189
warp = h['TransformGroup']['2']['TransformParameters'][:]
82-
warp = warp.reshape((193, 229, 193, 3)).transpose(2, 1, 0, 3)
90+
warp = warp.reshape((*shape, 3)).transpose(2, 1, 0, 3)
8391
warp *= np.array([-1, -1, 1])
84-
warp_affine = np.array(
85-
[
86-
[1.0, 0.0, 0.0, -96.0],
87-
[0.0, 1.0, 0.0, -132.0],
88-
[0.0, 0.0, 1.0, -78.0],
89-
[0.0, 0.0, 0.0, 1.0],
90-
]
91-
)
92-
return affine, warp, warp_affine
92+
93+
warp_affine = np.eye(4)
94+
warp_affine[:3, :3] = fixed_params[9:].reshape((3, 3))
95+
warp_affine[:3, 3] = fixed_params[3:6]
96+
lps_to_ras = np.eye(4) * np.array([-1, -1, 1, 1])
97+
warp_affine = lps_to_ras @ warp_affine
98+
if np.array_equal(fixed_params, FIXED_PARAMS):
99+
# Confirm that we construct the right affine when fixed parameters are known
100+
assert np.array_equal(
101+
warp_affine,
102+
np.array(
103+
[
104+
[1.0, 0.0, 0.0, -96.0],
105+
[0.0, 1.0, 0.0, -132.0],
106+
[0.0, 0.0, 1.0, -78.0],
107+
[0.0, 0.0, 0.0, 1.0],
108+
]
109+
),
110+
)
111+
transforms.insert(0, nt.DenseFieldTransform(nb.Nifti1Image(warp, warp_affine)))
112+
return nt.TransformChain(transforms)

fmriprep/workflows/base.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def init_single_subject_wf(subject_id: str):
151151
from niworkflows.utils.misc import fix_multi_T1w_source_name
152152
from niworkflows.utils.spaces import Reference
153153
from smriprep.workflows.anatomical import init_anat_fit_wf
154+
from smriprep.workflows.outputs import init_template_iterator_wf
154155

155156
from fmriprep.workflows.bold.base import init_bold_wf
156157

@@ -310,7 +311,6 @@ def init_single_subject_wf(subject_id: str):
310311
skull_strip_fixed_seed=config.workflow.skull_strip_fixed_seed,
311312
)
312313

313-
# fmt:off
314314
workflow.connect([
315315
(inputnode, anat_fit_wf, [('subjects_dir', 'inputnode.subjects_dir')]),
316316
(bidssrc, bids_info, [(('t1w', fix_multi_T1w_source_name), 'in_file')]),
@@ -329,8 +329,18 @@ def init_single_subject_wf(subject_id: str):
329329
(bidssrc, ds_report_about, [(('t1w', fix_multi_T1w_source_name), 'source_file')]),
330330
(summary, ds_report_summary, [('out_report', 'in_file')]),
331331
(about, ds_report_about, [('out_report', 'in_file')]),
332-
])
333-
# fmt:on
332+
]) # fmt:skip
333+
334+
# Set up the template iterator once, if used
335+
if config.workflow.level == "full":
336+
if spaces.get_spaces(nonstandard=False, dim=(3,)):
337+
template_iterator_wf = init_template_iterator_wf(spaces=spaces)
338+
workflow.connect([
339+
(anat_fit_wf, template_iterator_wf, [
340+
('outputnode.template', 'inputnode.template'),
341+
('outputnode.anat2std_xfm', 'inputnode.anat2std_xfm'),
342+
]),
343+
]) # fmt:skip
334344

335345
if config.workflow.anat_only:
336346
return clean_datasinks(workflow)
@@ -510,6 +520,18 @@ def init_single_subject_wf(subject_id: str):
510520
]),
511521
]) # fmt:skip
512522

523+
if config.workflow.level == "full":
524+
workflow.connect([
525+
(template_iterator_wf, bold_wf, [
526+
("outputnode.anat2std_xfm", "inputnode.anat2std_xfm"),
527+
("outputnode.space", "inputnode.std_space"),
528+
("outputnode.resolution", "inputnode.std_resolution"),
529+
("outputnode.cohort", "inputnode.std_cohort"),
530+
("outputnode.std_t1w", "inputnode.std_t1w"),
531+
("outputnode.std_mask", "inputnode.std_mask"),
532+
]),
533+
]) # fmt:skip
534+
513535
return clean_datasinks(workflow)
514536

515537

fmriprep/workflows/bold/base.py

Lines changed: 60 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,13 @@ def init_bold_wf(
198198
"fmap_mask",
199199
"fmap_id",
200200
"sdc_method",
201+
# Volumetric templates
202+
"anat2std_xfm",
203+
"std_space",
204+
"std_resolution",
205+
"std_cohort",
206+
"std_t1w",
207+
"std_mask",
201208
],
202209
),
203210
name="inputnode",
@@ -381,6 +388,59 @@ def init_bold_wf(
381388
(bold_anat_wf, ds_bold_t1_wf, [('outputnode.bold_file', 'inputnode.bold')]),
382389
]) # fmt:skip
383390

391+
if spaces.get_spaces(nonstandard=False, dim=(3,)):
392+
# Missing:
393+
# * Clipping BOLD after resampling
394+
# * Resampling parcellations
395+
bold_std_wf = init_bold_volumetric_resample_wf(
396+
metadata=all_metadata[0],
397+
fieldmap_id=fieldmap_id if not multiecho else None,
398+
omp_nthreads=omp_nthreads,
399+
name='bold_std_wf',
400+
)
401+
ds_bold_std_wf = init_ds_volumes_wf(
402+
bids_root=str(config.execution.bids_dir),
403+
output_dir=fmriprep_dir,
404+
multiecho=multiecho,
405+
metadata=all_metadata[0],
406+
name='ds_bold_std_wf',
407+
)
408+
ds_bold_std_wf.inputs.inputnode.source_files = bold_series
409+
410+
workflow.connect([
411+
(inputnode, bold_std_wf, [
412+
("std_t1w", "inputnode.target_ref_file"),
413+
("std_mask", "inputnode.target_mask"),
414+
("anat2std_xfm", "inputnode.anat2std_xfm"),
415+
("fmap_ref", "inputnode.fmap_ref"),
416+
("fmap_coeff", "inputnode.fmap_coeff"),
417+
("fmap_id", "inputnode.fmap_id"),
418+
]),
419+
(bold_fit_wf, bold_std_wf, [
420+
("outputnode.coreg_boldref", "inputnode.bold_ref_file"),
421+
("outputnode.boldref2fmap_xfm", "inputnode.boldref2fmap_xfm"),
422+
("outputnode.boldref2anat_xfm", "inputnode.boldref2anat_xfm"),
423+
]),
424+
(bold_native_wf, bold_std_wf, [
425+
("outputnode.bold_minimal", "inputnode.bold_file"),
426+
("outputnode.motion_xfm", "inputnode.motion_xfm"),
427+
]),
428+
(inputnode, ds_bold_std_wf, [
429+
('std_t1w', 'inputnode.ref_file'),
430+
('anat2std_xfm', 'inputnode.anat2std_xfm'),
431+
('std_space', 'inputnode.space'),
432+
('std_resolution', 'inputnode.resolution'),
433+
('std_cohort', 'inputnode.cohort'),
434+
]),
435+
(bold_fit_wf, ds_bold_std_wf, [
436+
('outputnode.bold_mask', 'inputnode.bold_mask'),
437+
('outputnode.coreg_boldref', 'inputnode.bold_ref'),
438+
('outputnode.boldref2anat_xfm', 'inputnode.boldref2anat_xfm'),
439+
]),
440+
(bold_native_wf, ds_bold_std_wf, [('outputnode.t2star_map', 'inputnode.t2star')]),
441+
(bold_std_wf, ds_bold_std_wf, [('outputnode.bold_file', 'inputnode.bold')]),
442+
]) # fmt:skip
443+
384444
# Fill-in datasinks of reportlets seen so far
385445
for node in workflow.list_node_names():
386446
if node.split(".")[-1].startswith("ds_report"):
@@ -629,90 +689,6 @@ def init_func_preproc_wf(bold_file, has_fieldmap=False):
629689
)
630690
bold_confounds_wf.get_node("inputnode").inputs.t1_transform_flags = [False]
631691

632-
if spaces.get_spaces(nonstandard=False, dim=(3,)):
633-
# Apply transforms in 1 shot
634-
bold_std_trans_wf = init_bold_std_trans_wf(
635-
freesurfer=freesurfer,
636-
mem_gb=mem_gb["resampled"],
637-
omp_nthreads=omp_nthreads,
638-
spaces=spaces,
639-
multiecho=multiecho,
640-
name="bold_std_trans_wf",
641-
use_compression=not config.execution.low_mem,
642-
)
643-
bold_std_trans_wf.inputs.inputnode.fieldwarp = "identity"
644-
645-
# fmt:off
646-
workflow.connect([
647-
(inputnode, bold_std_trans_wf, [
648-
("template", "inputnode.templates"),
649-
("anat2std_xfm", "inputnode.anat2std_xfm"),
650-
("bold_file", "inputnode.name_source"),
651-
("t1w_aseg", "inputnode.bold_aseg"),
652-
("t1w_aparc", "inputnode.bold_aparc"),
653-
]),
654-
(bold_final, bold_std_trans_wf, [
655-
("mask", "inputnode.bold_mask"),
656-
("t2star", "inputnode.t2star"),
657-
]),
658-
(bold_reg_wf, bold_std_trans_wf, [
659-
("outputnode.itk_bold_to_t1", "inputnode.itk_bold_to_t1"),
660-
]),
661-
(bold_std_trans_wf, outputnode, [
662-
("outputnode.bold_std", "bold_std"),
663-
("outputnode.bold_std_ref", "bold_std_ref"),
664-
("outputnode.bold_mask_std", "bold_mask_std"),
665-
]),
666-
])
667-
# fmt:on
668-
669-
if freesurfer:
670-
# fmt:off
671-
workflow.connect([
672-
(bold_std_trans_wf, func_derivatives_wf, [
673-
("outputnode.bold_aseg_std", "inputnode.bold_aseg_std"),
674-
("outputnode.bold_aparc_std", "inputnode.bold_aparc_std"),
675-
]),
676-
(bold_std_trans_wf, outputnode, [
677-
("outputnode.bold_aseg_std", "bold_aseg_std"),
678-
("outputnode.bold_aparc_std", "bold_aparc_std"),
679-
]),
680-
])
681-
# fmt:on
682-
683-
if not multiecho:
684-
# fmt:off
685-
workflow.connect([
686-
(bold_split, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]),
687-
(bold_hmc_wf, bold_std_trans_wf, [
688-
("outputnode.xforms", "inputnode.hmc_xforms"),
689-
]),
690-
])
691-
# fmt:on
692-
else:
693-
# fmt:off
694-
workflow.connect([
695-
(split_opt_comb, bold_std_trans_wf, [("out_files", "inputnode.bold_split")]),
696-
(bold_std_trans_wf, outputnode, [("outputnode.t2star_std", "t2star_std")]),
697-
])
698-
# fmt:on
699-
700-
# Already applied in bold_bold_trans_wf, which inputs to bold_t2s_wf
701-
bold_std_trans_wf.inputs.inputnode.hmc_xforms = "identity"
702-
703-
# fmt:off
704-
# func_derivatives_wf internally parametrizes over snapshotted spaces.
705-
workflow.connect([
706-
(bold_std_trans_wf, func_derivatives_wf, [
707-
("outputnode.template", "inputnode.template"),
708-
("outputnode.spatial_reference", "inputnode.spatial_reference"),
709-
("outputnode.bold_std_ref", "inputnode.bold_std_ref"),
710-
("outputnode.bold_std", "inputnode.bold_std"),
711-
("outputnode.bold_mask_std", "inputnode.bold_mask_std"),
712-
]),
713-
])
714-
# fmt:on
715-
716692
# SURFACES ##################################################################################
717693
# Freesurfer
718694
if freesurfer and freesurfer_spaces:

0 commit comments

Comments
 (0)