diff --git a/docs/notebooks/pet_motion_estimation.ipynb b/docs/notebooks/pet_motion_estimation.ipynb index f695be6d1..6cd3c9082 100644 --- a/docs/notebooks/pet_motion_estimation.ipynb +++ b/docs/notebooks/pet_motion_estimation.ipynb @@ -53,343 +53,6 @@ "pet_dataset" ] }, - { - "cell_type": "code", - "execution_count": 4, - "id": "09e5053e-4565-4892-a0a9-a8410fbe6748", - "metadata": {}, - "outputs": [], - "source": [ - "data_train, data_test = pet_dataset.lofo_split(15)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "a51fcb74-0e57-4ee2-8c24-1f89ff6f879c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([[[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., -0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., -0., ..., -0., -0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [-0., 0., -0., ..., 0., -0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [ 0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., -0., ..., -0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., -0., ..., 0., -0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., -0., -0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [ 0., 0., -0., ..., 0., -0., 0.],\n", - " [ 0., -0., 0., ..., -0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., -0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., -0., 0., -0.],\n", - " [-0., 0., -0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., -0., ..., -0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " ...,\n", - "\n", - "\n", - " [[[ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., -0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., -0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., -0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., -0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., -0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., -0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., -0., 0., ..., 0., 0., 0.]]],\n", - "\n", - "\n", - " [[[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " ...,\n", - "\n", - " [[ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " ...,\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.]],\n", - "\n", - " [[-0., 0., 0., ..., 0., 0., -0.],\n", - " [ 0., 0., 0., ..., 0., 0., 0.],\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " ...,\n", - " [ 0., 0., 0., ..., 0., 0., -0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.],\n", - " [-0., 0., 0., ..., 0., 0., 0.]]]], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "data_train[0]" - ] - }, { "cell_type": "code", "execution_count": 5, diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index 5777753a2..a7a9d3db6 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -81,44 +81,6 @@ def __getitem__( """ return super().__getitem__(idx) - def lofo_split(self, index): - """ - Leave-one-frame-out (LOFO) for PET data. - - Parameters - ---------- - index : int - Index of the PET frame to be left out in this fold. - - Returns - ------- - (train_data, train_timings) : tuple - Training data and corresponding timings, excluding the left-out frame. - (test_data, test_timing) : tuple - Test data (one PET frame) and corresponding timing. - """ - - if not Path(self._filepath).exists(): - self.to_filename(self._filepath) - - # Read original PET data - with h5py.File(self._filepath, "r") as in_file: - root = in_file["/0"] - pet_frame = np.asanyarray(root["dataobj"][..., index]) - timing_frame = np.asanyarray(root["midframe"][..., index]) - - # Mask to exclude the selected frame - mask = np.ones(self.dataobj.shape[-1], dtype=bool) - mask[index] = False - - train_data = self.dataobj[..., mask] - train_timings = self.midframe[mask] - - test_data = pet_frame - test_timing = timing_frame - - return (train_data, train_timings), (test_data, test_timing) - def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None: """Set an affine, and update data object and gradients.""" ImageGrid = namedtuple("ImageGrid", ("shape", "affine")) diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py index 4c38faf57..35b43f59b 100644 --- a/src/nifreeze/estimator.py +++ b/src/nifreeze/estimator.py @@ -40,7 +40,6 @@ from nifreeze.data.base import BaseDataset from nifreeze.data.pet import PET from nifreeze.model.base import BaseModel, ModelFactory -from nifreeze.model.pet import PETModel from nifreeze.registration.ants import ( Registration, _prepare_registration_data, @@ -222,91 +221,94 @@ def run(self, dataset: DatasetT, **kwargs) -> Self: class PETMotionEstimator: """Estimates motion within PET imaging data aligned with generic Estimator workflow.""" - def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"): - self.align_kwargs = align_kwargs or {} - self.strategy = strategy + def __init__( + self, + model: BaseModel | str, + strategy="linear", + model_kwargs: dict | None = None, + align_kwargs: dict | None = None, + ): + self._model = model + self._strategy = strategy + self._model_kwargs = model_kwargs or {} + self._align_kwargs = align_kwargs or {} - def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list: - n_frames = len(pet_dataset) - frame_indices = np.arange(n_frames) + def run(self, dataset: PET, omp_nthreads: int | None = None) -> list: + # Prepare iterator + iterfunc = getattr(iterators, f"{self._strategy}_iterator") + index_iter = iterfunc(len(dataset), seed=self._align_kwargs.get("seed", None)) + + # Initialize model + if isinstance(self._model, str): + model = ModelFactory.init( + model=self._model, + dataset=dataset, + **self._model_kwargs, + ) + else: + model = self._model + + dataset_length = len(dataset) if omp_nthreads: - self.align_kwargs["num_threads"] = omp_nthreads + self._align_kwargs["num_threads"] = omp_nthreads affine_matrices = [] with TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) - for idx in tqdm(frame_indices, desc="Estimating PET motion"): - (train_data, train_times), (test_data, test_time) = pet_dataset.lofo_split(idx) + with tqdm(total=dataset_length, unit="vols.") as pbar: + for i in index_iter: + pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>") + + # Fit the model once on the training dataset + model.fit_predict(None) + + # Predict the reference volume at the test frame's timepoint + predicted = model.fit_predict(dataset.midframe[i]) - if train_times is None: - raise ValueError( - f"train_times is None at index {idx}, check midframe initialization." + fixed_image_path = tmp_path / f"fixed_frame_{i:03d}.nii.gz" + moving_image_path = tmp_path / f"moving_frame_{i:03d}.nii.gz" + + fixed_img = nb.Nifti1Image(predicted, dataset.affine) + moving_img = nb.Nifti1Image(dataset[i][0], dataset.affine) + + moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True) + + fixed_img.to_filename(fixed_image_path) + moving_img.to_filename(moving_image_path) + + registration_config = files("nifreeze.registration.config").joinpath( + "pet-to-pet_level1.json" ) - # Build a temporary dataset excluding the test frame - train_dataset = PET( - dataobj=train_data, - affine=pet_dataset.affine, - brainmask=pet_dataset.brainmask, - midframe=train_times, - total_duration=pet_dataset.total_duration, - ) - - # Instantiate PETModel explicitly - model = PETModel( - dataset=train_dataset, - timepoints=train_times, - xlim=pet_dataset.total_duration, - ) - - # Fit the model once on the training dataset - model.fit_predict(None) - - # Predict the reference volume at the test frame's timepoint - predicted = model.fit_predict(test_time) - - fixed_image_path = tmp_path / f"fixed_frame_{idx:03d}.nii.gz" - moving_image_path = tmp_path / f"moving_frame_{idx:03d}.nii.gz" - - fixed_img = nb.Nifti1Image(predicted, pet_dataset.affine) - moving_img = nb.Nifti1Image(test_data, pet_dataset.affine) - - moving_img = nb.as_closest_canonical(moving_img, enforce_diag=True) - - fixed_img.to_filename(fixed_image_path) - moving_img.to_filename(moving_image_path) - - registration_config = files("nifreeze.registration.config").joinpath( - "pet-to-pet_level1.json" - ) - - registration = Registration( - from_file=registration_config, - fixed_image=str(fixed_image_path), - moving_image=str(moving_image_path), - output_warped_image=True, - output_transform_prefix=f"ants_{idx:03d}", - **self.align_kwargs, - ) - - try: - result = registration.run(cwd=str(tmp_path)) - if result.outputs.forward_transforms: - transform = nt.io.itk.ITKLinearTransform.from_filename( - result.outputs.forward_transforms[0] - ) - matrix = transform.to_ras( - reference=str(fixed_image_path), moving=str(moving_image_path) - ) - affine_matrices.append(matrix) - else: + registration = Registration( + from_file=registration_config, + fixed_image=str(fixed_image_path), + moving_image=str(moving_image_path), + output_warped_image=True, + output_transform_prefix=f"ants_{i:03d}", + **self._align_kwargs, + ) + + try: + result = registration.run(cwd=str(tmp_path)) + if result.outputs.forward_transforms: + transform = nt.io.itk.ITKLinearTransform.from_filename( + result.outputs.forward_transforms[0] + ) + matrix = transform.to_ras( + reference=str(fixed_image_path), moving=str(moving_image_path) + ) + affine_matrices.append(matrix) + else: + affine_matrices.append(np.eye(4)) + print(f"No transforms produced for index {i}") + except Exception as e: affine_matrices.append(np.eye(4)) - print(f"No transforms produced for index {idx}") - except Exception as e: - affine_matrices.append(np.eye(4)) - print(f"Failed to process frame {idx} due to {e}") + print(f"Failed to process frame {i} due to {e}") + + pbar.update() - return affine_matrices + return affine_matrices diff --git a/test/test_integration_pet.py b/test/test_integration_pet.py index 68b2a318e..52d3d0993 100644 --- a/test/test_integration_pet.py +++ b/test/test_integration_pet.py @@ -28,6 +28,7 @@ from nifreeze.data.pet import PET from nifreeze.estimator import PETMotionEstimator +from nifreeze.model.base import BaseModel @pytest.fixture @@ -51,16 +52,6 @@ def random_dataset(setup_random_pet_data) -> PET: ) -@pytest.mark.random_pet_data(4, (2, 2, 2), np.asarray([1.0, 2.0, 3.0, 4.0]), 5.0) -def test_lofo_split_shapes(random_dataset, tmp_path): - idx = 2 - (train_data, train_times), (test_data, test_time) = random_dataset.lofo_split(idx) - assert train_data.shape[-1] == random_dataset.dataobj.shape[-1] - 1 - np.testing.assert_array_equal(test_data, random_dataset.dataobj[..., idx]) - np.testing.assert_array_equal(train_times, np.delete(random_dataset.midframe, idx)) - assert test_time == random_dataset.midframe[idx] - - @pytest.mark.random_pet_data(3, (2, 2, 2), np.asarray([1.0, 2.0, 3.0]), 4.0) def test_to_from_filename_roundtrip(random_dataset, tmp_path): out_file = tmp_path / "petdata" @@ -75,16 +66,16 @@ def test_to_from_filename_roundtrip(random_dataset, tmp_path): @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) def test_pet_motion_estimator_run(random_dataset, monkeypatch): - class DummyModel: + class DummyModel(BaseModel): def __init__(self, dataset, timepoints, xlim): self.dataset = dataset - def fit_predict(self, index): + def fit_predict(self, index = None, **kawargs): if index is None: return None return np.zeros(self.dataset.shape3d, dtype=np.float32) - monkeypatch.setattr("nifreeze.estimator.PETModel", DummyModel) + model = DummyModel(random_dataset, None, None) class DummyRegistration: def __init__(self, *args, **kwargs): @@ -95,7 +86,7 @@ def run(self, cwd=None): monkeypatch.setattr("nifreeze.estimator.Registration", DummyRegistration) - estimator = PETMotionEstimator(None) + estimator = PETMotionEstimator(model) affines = estimator.run(random_dataset) assert len(affines) == len(random_dataset) for mat in affines: