diff --git a/docs/usage.rst b/docs/usage.rst index b746a09c..593c5854 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -201,10 +201,15 @@ manual) fixed during robust template estimation to improve reproducibility. Iterations are automatically disabled to reduce runtime when :option:`--hmc-init-frame-fix` is used. +When motion correction is undesirable, use :option:`--hmc-off` to disable head motion +correction entirely and keep the data unmodified apart from downstream +processing steps. + Examples: :: $ petprep /data/bids_root /out participant --hmc-fwhm 8 --hmc-start-time 60 $ petprep /data/bids_root /out participant --hmc-init-frame 10 --hmc-init-frame-fix + $ petprep /data/bids_root /out participant --hmc-off Segmentation ---------------- diff --git a/petprep/cli/parser.py b/petprep/cli/parser.py index 7c470800..898be24b 100644 --- a/petprep/cli/parser.py +++ b/petprep/cli/parser.py @@ -563,6 +563,12 @@ def _bids_filter(value, parser): action='store_true', help=('Keep the chosen initial reference frame fixed during head-motion estimation.'), ) + g_hmc.add_argument( + '--hmc-off', + dest='hmc_off', + action='store_true', + help='Disable head-motion correction and use the uncorrected data.', + ) g_seg = parser.add_argument_group('Segmentation options') g_seg.add_argument( diff --git a/petprep/cli/tests/test_parser.py b/petprep/cli/tests/test_parser.py index d259f6f8..79f4ac27 100644 --- a/petprep/cli/tests/test_parser.py +++ b/petprep/cli/tests/test_parser.py @@ -366,3 +366,19 @@ def test_hmc_init_frame_parsing(tmp_path): opts = parser.parse_args(base_args + ['--hmc-init-frame', '3', '--hmc-init-frame-fix']) assert opts.hmc_init_frame == 3 assert opts.hmc_fix_frame is True + + +def test_hmc_off_flag(tmp_path): + """Ensure disabling motion correction is parsed correctly.""" + datapath = tmp_path / 'data' + outpath = tmp_path / 'out' + datapath.mkdir() + + parser = _build_parser() + base_args = [str(datapath), str(outpath), 'participant'] + + opts = parser.parse_args(base_args) + assert opts.hmc_off is False + + opts = parser.parse_args(base_args + ['--hmc-off']) + assert opts.hmc_off is True diff --git a/petprep/config.py b/petprep/config.py index 2cfc100b..73dfa969 100644 --- a/petprep/config.py +++ b/petprep/config.py @@ -602,6 +602,8 @@ class workflow(_Config): """Index of initial frame for head-motion estimation ('auto' selects highest uptake).""" hmc_fix_frame: bool = False """Whether to fix the reference frame during head-motion estimation.""" + hmc_off: bool = False + """Disable head-motion correction and keep data uncorrected.""" seg = 'gtm' """Segmentation approach ('gtm', 'brainstem', 'thalamicNuclei', 'hippocampusAmygdala', 'wm', 'raphe', 'limbic').""" diff --git a/petprep/workflows/pet/fit.py b/petprep/workflows/pet/fit.py index 5fa28e49..e63ab437 100644 --- a/petprep/workflows/pet/fit.py +++ b/petprep/workflows/pet/fit.py @@ -20,12 +20,14 @@ # # https://www.nipreps.org/community/licensing/ # +from collections.abc import Sequence from pathlib import Path import nibabel as nb +import numpy as np from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe -from nitransforms.linear import Affine +from nitransforms.linear import Affine, LinearTransformsMapping from niworkflows.interfaces.header import ValidateImage from niworkflows.utils.connections import listify @@ -53,6 +55,70 @@ from .registration import init_pet_reg_wf +def _extract_twa_image( + pet_file: str, + output_dir: Path, + frame_start_times: Sequence[float] | None, + frame_durations: Sequence[float] | None, +) -> str: + """Return a time-weighted average (twa) reference image from a 4D PET series.""" + + output_dir.mkdir(parents=True, exist_ok=True) + img = nb.load(pet_file) + if img.ndim < 4 or img.shape[-1] == 1: + return pet_file + + if frame_start_times is None or frame_durations is None: + raise ValueError( + 'Frame timing metadata are required to compute a time-weighted reference image.' + ) + + frame_start_times = np.asarray(frame_start_times, dtype=float) + frame_durations = np.asarray(frame_durations, dtype=float) + + if frame_start_times.ndim != 1 or frame_durations.ndim != 1: + raise ValueError('Frame timing metadata must be one-dimensional sequences.') + + if len(frame_start_times) != len(frame_durations): + raise ValueError('FrameTimesStart and FrameDuration must have the same length.') + + if len(frame_durations) != img.shape[-1]: + raise ValueError( + 'Frame timing metadata must match the number of frames in the PET series.' + ) + + if np.any(frame_durations <= 0): + raise ValueError('FrameDuration values must all be positive.') + + if np.any(np.diff(frame_start_times) < 0): + raise ValueError('FrameTimesStart values must be non-decreasing.') + + hdr = img.header.copy() + data = np.asanyarray(img.dataobj) + weighted_average = np.average(data, axis=-1, weights=frame_durations).astype(np.float32) + hdr.set_data_shape(weighted_average.shape) + + pet_path = Path(pet_file) + # Drop all suffixes (e.g., `.nii.gz`) before appending the reference label + pet_stem = pet_path + while pet_stem.suffix: + pet_stem = pet_stem.with_suffix('') + + out_file = output_dir / f'{pet_stem.name}_timeavgref.nii.gz' + img.__class__(weighted_average, img.affine, hdr).to_filename(out_file) + return str(out_file) + + +def _write_identity_xforms(num_frames: int, filename: Path) -> Path: + """Write ``num_frames`` identity transforms to ``filename``.""" + + filename = Path(filename) + filename.parent.mkdir(parents=True, exist_ok=True) + n_xforms = max(int(num_frames or 0), 1) + LinearTransformsMapping([Affine() for _ in range(n_xforms)]).to_filename(filename, fmt='itk') + return filename + + def init_pet_fit_wf( *, pet_series: list[str], @@ -158,6 +224,13 @@ def init_pet_fit_wf( if (petref is None) ^ (hmc_xforms is None): raise ValueError("Both 'petref' and 'hmc' transforms must be provided together.") + if config.workflow.hmc_off and (petref or hmc_xforms): + config.loggers.workflow.warning( + 'Ignoring precomputed motion correction derivatives because --hmc-off was set.' + ) + petref = None + hmc_xforms = None + workflow = Workflow(name=name) inputnode = pe.Node( @@ -202,19 +275,6 @@ def init_pet_fit_wf( ) hmc_buffer = pe.Node(niu.IdentityInterface(fields=['hmc_xforms']), name='hmc_buffer') - if pet_tlen <= 1: # 3D PET - petref = pet_file - idmat_fname = config.execution.work_dir / 'idmat.tfm' - Affine().to_filename(idmat_fname, fmt='itk') - hmc_xforms = idmat_fname - config.loggers.workflow.debug('3D PET file - motion correction not needed') - if petref: - petref_buffer.inputs.petref = petref - config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}') - if hmc_xforms: - hmc_buffer.inputs.hmc_xforms = hmc_xforms - config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}') - timing_parameters = prepare_timing_parameters(metadata) frame_durations = timing_parameters.get('FrameDuration') frame_start_times = timing_parameters.get('FrameTimesStart') @@ -226,6 +286,32 @@ def init_pet_fit_wf( 'Please check your BIDS JSON sidecar.' ) + hmc_disabled = bool(config.workflow.hmc_off) + if hmc_disabled: + config.execution.work_dir.mkdir(parents=True, exist_ok=True) + petref = petref or _extract_twa_image( + pet_file, + config.execution.work_dir, + frame_start_times, + frame_durations, + ) + idmat_fname = config.execution.work_dir / 'idmat.tfm' + n_frames = len(frame_durations) + hmc_xforms = _write_identity_xforms(n_frames, idmat_fname) + config.loggers.workflow.info('Head motion correction disabled; using identity transforms.') + + if pet_tlen <= 1: # 3D PET + petref = pet_file + idmat_fname = config.execution.work_dir / 'idmat.tfm' + hmc_xforms = _write_identity_xforms(pet_tlen, idmat_fname) + config.loggers.workflow.debug('3D PET file - motion correction not needed') + if petref: + petref_buffer.inputs.petref = petref + config.loggers.workflow.debug(f'(Re)using motion correction reference: {petref}') + if hmc_xforms: + hmc_buffer.inputs.hmc_xforms = hmc_xforms + config.loggers.workflow.debug(f'(Re)using motion correction transforms: {hmc_xforms}') + summary = pe.Node( FunctionalSummary( registration=('Precomputed' if petref2anat_xform else 'mri_coreg'), diff --git a/petprep/workflows/pet/tests/test_fit.py b/petprep/workflows/pet/tests/test_fit.py index e6fe6fe1..1e951917 100644 --- a/petprep/workflows/pet/tests/test_fit.py +++ b/petprep/workflows/pet/tests/test_fit.py @@ -1,6 +1,7 @@ from pathlib import Path import nibabel as nb +import nitransforms as nt import numpy as np import pytest import yaml @@ -11,7 +12,7 @@ from ....utils import bids from ...tests import mock_config from ...tests.test_base import BASE_LAYOUT -from ..fit import init_pet_fit_wf, init_pet_native_wf +from ..fit import _extract_twa_image, init_pet_fit_wf, init_pet_native_wf from ..outputs import init_refmask_report_wf @@ -350,6 +351,108 @@ def test_pet_fit_stage1_with_cached_baseline(bids_root: Path, tmp_path: Path): assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names()) +def test_pet_fit_hmc_off_disables_stage1(bids_root: Path, tmp_path: Path): + """Disabling HMC should skip Stage 1 and use identity transforms.""" + pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')] + data = np.stack( + ( + np.ones((2, 2, 2), dtype=np.float32), + np.full((2, 2, 2), 3.0, dtype=np.float32), + ), + axis=-1, + ) + img = nb.Nifti1Image(data, np.eye(4)) + for path in pet_series: + img.to_filename(path) + + sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json') + sidecar.write_text('{"FrameTimesStart": [0, 2], "FrameDuration": [2, 4]}') + + with mock_config(bids_dir=bids_root): + config.workflow.hmc_off = True + wf = init_pet_fit_wf(pet_series=pet_series, precomputed={}, omp_nthreads=1) + + assert not any(name.startswith('pet_hmc_wf') for name in wf.list_node_names()) + hmc_buffer = wf.get_node('hmc_buffer') + assert str(hmc_buffer.inputs.hmc_xforms).endswith('idmat.tfm') + hmc = nt.linear.load(hmc_buffer.inputs.hmc_xforms) + assert hmc.matrix.shape[0] == data.shape[-1] + assert np.allclose(hmc.matrix, np.tile(np.eye(4), (data.shape[-1], 1, 1))) + petref_buffer = wf.get_node('petref_buffer') + petref_name = Path(petref_buffer.inputs.petref).name + assert petref_name.endswith('_timeavgref.nii.gz') + assert '.nii_timeavgref' not in petref_name + petref_img = nb.load(petref_buffer.inputs.petref) + assert np.allclose(petref_img.get_fdata(), 14.0 / 6.0) + + +@pytest.mark.parametrize( + ('frame_start_times', 'frame_durations', 'message'), + [ + (None, [1, 1], 'Frame timing metadata are required'), + ([0, 1], None, 'Frame timing metadata are required'), + ([[0, 1]], [1, 1], 'must be one-dimensional'), + ([0, 1], [1], 'the same length'), + ([0, 1, 2], [1, 1, 1], 'match the number of frames'), + ([0, 1], [1, -1], 'must all be positive'), + ([1, 0], [1, 1], 'must be non-decreasing'), + ], +) +def test_extract_twa_image_validation( + tmp_path: Path, frame_start_times, frame_durations, message: str +): + """Validate error handling for malformed frame timing metadata.""" + + pet_img = nb.Nifti1Image(np.zeros((2, 2, 2, 2), dtype=np.float32), np.eye(4)) + pet_file = tmp_path / 'pet.nii.gz' + pet_img.to_filename(pet_file) + + with pytest.raises(ValueError, match=message): # noqa: PT011 + _extract_twa_image( + str(pet_file), + tmp_path / 'out', + frame_start_times, + frame_durations, + ) + + +def test_pet_fit_hmc_off_ignores_precomputed(bids_root: Path, tmp_path: Path): + """Precomputed derivatives are ignored when ``--hmc-off`` is set.""" + + pet_series = [str(bids_root / 'sub-01' / 'pet' / 'sub-01_task-rest_run-1_pet.nii.gz')] + data = np.stack((np.ones((2, 2, 2)), np.full((2, 2, 2), 2.0)), axis=-1) + img = nb.Nifti1Image(data, np.eye(4)) + for path in pet_series: + img.to_filename(path) + + sidecar = Path(pet_series[0]).with_suffix('').with_suffix('.json') + sidecar.write_text('{"FrameTimesStart": [0, 1], "FrameDuration": [1, 1]}') + + precomputed_petref = tmp_path / 'precomputed_petref.nii.gz' + precomputed_hmc = tmp_path / 'precomputed_hmc.txt' + img.to_filename(precomputed_petref) + np.savetxt(precomputed_hmc, np.eye(4)) + + with mock_config(bids_dir=bids_root): + config.workflow.hmc_off = True + wf = init_pet_fit_wf( + pet_series=pet_series, + precomputed={ + 'petref': str(precomputed_petref), + 'transforms': {'hmc': str(precomputed_hmc)}, + }, + omp_nthreads=1, + ) + + petref_buffer = wf.get_node('petref_buffer') + hmc_buffer = wf.get_node('hmc_buffer') + + assert petref_buffer.inputs.petref != str(precomputed_petref) + assert Path(petref_buffer.inputs.petref).name.endswith('_timeavgref.nii.gz') + assert hmc_buffer.inputs.hmc_xforms != str(precomputed_hmc) + assert Path(hmc_buffer.inputs.hmc_xforms).name == 'idmat.tfm' + + def test_init_refmask_report_wf(tmp_path: Path): """Ensure the refmask report workflow initializes without errors.""" wf = init_refmask_report_wf(output_dir=str(tmp_path), ref_name='test')