Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,15 @@ manual) fixed during robust template estimation to improve reproducibility.
Iterations are automatically disabled to reduce runtime when :option:`--hmc-init-frame-fix` is
used.

When motion correction is undesirable, use :option:`--hmc-off` to disable head motion
correction entirely and keep the data unmodified apart from downstream
processing steps.

Examples: ::

$ petprep /data/bids_root /out participant --hmc-fwhm 8 --hmc-start-time 60
$ petprep /data/bids_root /out participant --hmc-init-frame 10 --hmc-init-frame-fix
$ petprep /data/bids_root /out participant --hmc-off

Segmentation
----------------
Expand Down
6 changes: 6 additions & 0 deletions petprep/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,12 @@ def _bids_filter(value, parser):
action='store_true',
help=('Keep the chosen initial reference frame fixed during head-motion estimation.'),
)
g_hmc.add_argument(
'--hmc-off',
dest='hmc_off',
action='store_true',
help='Disable head-motion correction and use the uncorrected data.',
)

g_seg = parser.add_argument_group('Segmentation options')
g_seg.add_argument(
Expand Down
16 changes: 16 additions & 0 deletions petprep/cli/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,3 +366,19 @@ def test_hmc_init_frame_parsing(tmp_path):
opts = parser.parse_args(base_args + ['--hmc-init-frame', '3', '--hmc-init-frame-fix'])
assert opts.hmc_init_frame == 3
assert opts.hmc_fix_frame is True


def test_hmc_off_flag(tmp_path):
"""Ensure disabling motion correction is parsed correctly."""
datapath = tmp_path / 'data'
outpath = tmp_path / 'out'
datapath.mkdir()

parser = _build_parser()
base_args = [str(datapath), str(outpath), 'participant']

opts = parser.parse_args(base_args)
assert opts.hmc_off is False

opts = parser.parse_args(base_args + ['--hmc-off'])
assert opts.hmc_off is True
2 changes: 2 additions & 0 deletions petprep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,8 @@ class workflow(_Config):
"""Index of initial frame for head-motion estimation ('auto' selects highest uptake)."""
hmc_fix_frame: bool = False
"""Whether to fix the reference frame during head-motion estimation."""
hmc_off: bool = False
"""Disable head-motion correction and keep data uncorrected."""
seg = 'gtm'
"""Segmentation approach ('gtm', 'brainstem', 'thalamicNuclei',
'hippocampusAmygdala', 'wm', 'raphe', 'limbic')."""
Expand Down
114 changes: 100 additions & 14 deletions petprep/workflows/pet/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
#
# https://www.nipreps.org/community/licensing/
#
from collections.abc import Sequence
from pathlib import Path

import nibabel as nb
import numpy as np
from nipype.interfaces import utility as niu
from nipype.pipeline import engine as pe
from nitransforms.linear import Affine
from nitransforms.linear import Affine, LinearTransformsMapping
from niworkflows.interfaces.header import ValidateImage
from niworkflows.utils.connections import listify

Expand Down Expand Up @@ -53,6 +55,70 @@
from .registration import init_pet_reg_wf


def _extract_twa_image(
pet_file: str,
output_dir: Path,
frame_start_times: Sequence[float] | None,
frame_durations: Sequence[float] | None,
) -> str:
"""Return a time-weighted average (twa) reference image from a 4D PET series."""

output_dir.mkdir(parents=True, exist_ok=True)
img = nb.load(pet_file)
if img.ndim < 4 or img.shape[-1] == 1:
return pet_file

if frame_start_times is None or frame_durations is None:
raise ValueError(
'Frame timing metadata are required to compute a time-weighted reference image.'
)

frame_start_times = np.asarray(frame_start_times, dtype=float)
frame_durations = np.asarray(frame_durations, dtype=float)

if frame_start_times.ndim != 1 or frame_durations.ndim != 1:
raise ValueError('Frame timing metadata must be one-dimensional sequences.')

if len(frame_start_times) != len(frame_durations):
raise ValueError('FrameTimesStart and FrameDuration must have the same length.')

if len(frame_durations) != img.shape[-1]:
raise ValueError(
'Frame timing metadata must match the number of frames in the PET series.'
)

if np.any(frame_durations <= 0):
raise ValueError('FrameDuration values must all be positive.')

if np.any(np.diff(frame_start_times) < 0):
raise ValueError('FrameTimesStart values must be non-decreasing.')

hdr = img.header.copy()
data = np.asanyarray(img.dataobj)
weighted_average = np.average(data, axis=-1, weights=frame_durations).astype(np.float32)
hdr.set_data_shape(weighted_average.shape)

pet_path = Path(pet_file)
# Drop all suffixes (e.g., `.nii.gz`) before appending the reference label
pet_stem = pet_path
while pet_stem.suffix:
pet_stem = pet_stem.with_suffix('')

out_file = output_dir / f'{pet_stem.name}_timeavgref.nii.gz'
img.__class__(weighted_average, img.affine, hdr).to_filename(out_file)
return str(out_file)


def _write_identity_xforms(num_frames: int, filename: Path) -> Path:
"""Write ``num_frames`` identity transforms to ``filename``."""

filename = Path(filename)
filename.parent.mkdir(parents=True, exist_ok=True)
n_xforms = max(int(num_frames or 0), 1)
LinearTransformsMapping([Affine() for _ in range(n_xforms)]).to_filename(filename, fmt='itk')
return filename


def init_pet_fit_wf(
*,
pet_series: list[str],
Expand Down Expand Up @@ -158,6 +224,13 @@ def init_pet_fit_wf(
if (petref is None) ^ (hmc_xforms is None):
raise ValueError("Both 'petref' and 'hmc' transforms must be provided together.")

if config.workflow.hmc_off and (petref or hmc_xforms):
config.loggers.workflow.warning(
'Ignoring precomputed motion correction derivatives because --hmc-off was set.'
)
petref = None
hmc_xforms = None

workflow = Workflow(name=name)

inputnode = pe.Node(
Expand Down Expand Up @@ -202,19 +275,6 @@ def init_pet_fit_wf(
)
hmc_buffer = pe.Node(niu.IdentityInterface(fields=['hmc_xforms']), name='hmc_buffer')

if pet_tlen <= 1: # 3D PET
petref = pet_file
idmat_fname = config.execution.work_dir / 'idmat.tfm'
Affine().to_filename(idmat_fname, fmt='itk')
hmc_xforms = idmat_fname
config.loggers.workflow.debug('3D PET file - motion correction not needed')
if petref:
petref_buffer.inputs.petref = petref
config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}')
if hmc_xforms:
hmc_buffer.inputs.hmc_xforms = hmc_xforms
config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}')

timing_parameters = prepare_timing_parameters(metadata)
frame_durations = timing_parameters.get('FrameDuration')
frame_start_times = timing_parameters.get('FrameTimesStart')
Expand All @@ -226,6 +286,32 @@ def init_pet_fit_wf(
'Please check your BIDS JSON sidecar.'
)

hmc_disabled = bool(config.workflow.hmc_off)
if hmc_disabled:
config.execution.work_dir.mkdir(parents=True, exist_ok=True)
petref = petref or _extract_twa_image(
pet_file,
config.execution.work_dir,
frame_start_times,
frame_durations,
)
idmat_fname = config.execution.work_dir / 'idmat.tfm'
n_frames = len(frame_durations)
hmc_xforms = _write_identity_xforms(n_frames, idmat_fname)
config.loggers.workflow.info('Head motion correction disabled; using identity transforms.')

if pet_tlen <= 1: # 3D PET
petref = pet_file
idmat_fname = config.execution.work_dir / 'idmat.tfm'
hmc_xforms = _write_identity_xforms(pet_tlen, idmat_fname)
config.loggers.workflow.debug('3D PET file - motion correction not needed')
if petref:
petref_buffer.inputs.petref = petref
config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}')
if hmc_xforms:
hmc_buffer.inputs.hmc_xforms = hmc_xforms
config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}')

summary = pe.Node(
FunctionalSummary(
registration=('Precomputed' if petref2anat_xform else 'mri_coreg'),
Expand Down
105 changes: 104 additions & 1 deletion petprep/workflows/pet/tests/test_fit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import nibabel as nb
import nitransforms as nt
import numpy as np
import pytest
import yaml
Expand All @@ -11,7 +12,7 @@
from ....utils import bids
from ...tests import mock_config
from ...tests.test_base import BASE_LAYOUT
from ..fit import init_pet_fit_wf, init_pet_native_wf
from ..fit import _extract_twa_image, init_pet_fit_wf, init_pet_native_wf
from ..outputs import init_refmask_report_wf


Expand Down Expand Up @@ -350,6 +351,108 @@ def test_pet_fit_stage1_with_cached_baseline(bids_root: Path, tmp_path: Path):
assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names())


def test_pet_fit_hmc_off_disables_stage1(bids_root: Path, tmp_path: Path):
"""Disabling HMC should skip Stage 1 and use identity transforms."""
pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
data = np.stack(
(
np.ones((2, 2, 2), dtype=np.float32),
np.full((2, 2, 2), 3.0, dtype=np.float32),
),
axis=-1,
)
img = nb.Nifti1Image(data, np.eye(4))
for path in pet_series:
img.to_filename(path)

sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json')
sidecar.write_text('{"FrameTimesStart": [0, 2], "FrameDuration": [2, 4]}')

with mock_config(bids_dir=bids_root):
config.workflow.hmc_off = True
wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1)

assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names())
hmc_buffer = wf.get_node('hmc_buffer')
assert str(hmc_buffer.inputs.hmc_xforms).endswith('idmat.tfm')
hmc = nt.linear.load(hmc_buffer.inputs.hmc_xforms)
assert hmc.matrix.shape[0] == data.shape[-1]
assert np.allclose(hmc.matrix, np.tile(np.eye(4), (data.shape[-1], 1, 1)))
petref_buffer = wf.get_node('petref_buffer')
petref_name = Path(petref_buffer.inputs.petref).name
assert petref_name.endswith('_timeavgref.nii.gz')
assert '.nii_timeavgref' not in petref_name
petref_img = nb.load(petref_buffer.inputs.petref)
assert np.allclose(petref_img.get_fdata(), 14.0 / 6.0)


@pytest.mark.parametrize(
('frame_start_times', 'frame_durations', 'message'),
[
(None, [1, 1], 'Frame timing metadata are required'),
([0, 1], None, 'Frame timing metadata are required'),
([[0, 1]], [1, 1], 'must be one-dimensional'),
([0, 1], [1], 'the same length'),
([0, 1, 2], [1, 1, 1], 'match the number of frames'),
([0, 1], [1, -1], 'must all be positive'),
([1, 0], [1, 1], 'must be non-decreasing'),
],
)
def test_extract_twa_image_validation(
tmp_path: Path, frame_start_times, frame_durations, message: str
):
"""Validate error handling for malformed frame timing metadata."""

pet_img = nb.Nifti1Image(np.zeros((2, 2, 2, 2), dtype=np.float32), np.eye(4))
pet_file = tmp_path / 'pet.nii.gz'
pet_img.to_filename(pet_file)

with pytest.raises(ValueError, match=message): # noqa: PT011
_extract_twa_image(
str(pet_file),
tmp_path / 'out',
frame_start_times,
frame_durations,
)


def test_pet_fit_hmc_off_ignores_precomputed(bids_root: Path, tmp_path: Path):
"""Precomputed derivatives are ignored when ``--hmc-off`` is set."""

pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')]
data = np.stack((np.ones((2, 2, 2)), np.full((2, 2, 2), 2.0)), axis=-1)
img = nb.Nifti1Image(data, np.eye(4))
for path in pet_series:
img.to_filename(path)

sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json')
sidecar.write_text('{"FrameTimesStart": [0, 1], "FrameDuration": [1, 1]}')

precomputed_petref = tmp_path / 'precomputed_petref.nii.gz'
precomputed_hmc = tmp_path / 'precomputed_hmc.txt'
img.to_filename(precomputed_petref)
np.savetxt(precomputed_hmc, np.eye(4))

with mock_config(bids_dir=bids_root):
config.workflow.hmc_off = True
wf = init_pet_fit_wf(
pet_series=pet_series,
precomputed={
'petref': str(precomputed_petref),
'transforms': {'hmc': str(precomputed_hmc)},
},
omp_nthreads=1,
)

petref_buffer = wf.get_node('petref_buffer')
hmc_buffer = wf.get_node('hmc_buffer')

assert petref_buffer.inputs.petref != str(precomputed_petref)
assert Path(petref_buffer.inputs.petref).name.endswith('_timeavgref.nii.gz')
assert hmc_buffer.inputs.hmc_xforms != str(precomputed_hmc)
assert Path(hmc_buffer.inputs.hmc_xforms).name == 'idmat.tfm'


def test_init_refmask_report_wf(tmp_path: Path):
"""Ensure the refmask report workflow initializes without errors."""
wf = init_refmask_report_wf(output_dir=str(tmp_path), ref_name='test')
Expand Down