Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/nifreeze/cli/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
)
parser.add_argument(
"--nthreads",
"--omp-nthreads",
"--ncpus",
action="store",
type=int,
default=None,
Expand Down
10 changes: 10 additions & 0 deletions src/nifreeze/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def __getitem__(
affine = self.motion_affines[idx] if self.motion_affines is not None else None
return self.dataobj[..., idx], affine, *self._getextra(idx)

@property
def shape3d(self):
"""Get the shape of the 3D volume."""
return self.dataobj.shape[:3]

@property
def size3d(self):
"""Get the number of voxels in the 3D volume."""
return np.prod(self.dataobj.shape[:3])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like better volume_shape and volume_size for the properties, but I prefer to have this merged and change them at a latter stage if we agree on better naming.

@classmethod
def from_filename(cls, filename: Path | str) -> Self:
"""
Expand Down
37 changes: 32 additions & 5 deletions src/nifreeze/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from os import cpu_count
from pathlib import Path
from tempfile import TemporaryDirectory
from timeit import default_timer as timer
from typing import TypeVar

from tqdm import tqdm
Expand All @@ -42,6 +43,7 @@

DatasetT = TypeVar("DatasetT", bound=BaseDataset)

DEFAULT_CHUNK_SIZE: int = int(1e6)
FIT_MSG = "Fit&predict"
REG_MSG = "Realign"

Expand Down Expand Up @@ -109,13 +111,24 @@
dataset = result # type: ignore[assignment]

n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
n_threads = kwargs.pop("omp_nthreads", None) or ((cpu_count() or 2) - 1)

num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)

# Prepare iterator
iterfunc = getattr(iterators, f"{self._strategy}_iterator")
index_iter = iterfunc(len(dataset), seed=kwargs.get("seed", None))

# Initialize model
if isinstance(self._model, str):
if self._model.endswith("dti"):
self._model_kwargs["step"] = chunk_size

Check warning on line 126 in src/nifreeze/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/estimator.py#L126

Added line #L126 was not covered by tests

# Example: change model parameters only for DKI
# if self._model.endswith("dki"):
# self._model_kwargs["fit_model"] = "CWLS"

# Factory creates the appropriate model and pipes arguments
model = ModelFactory.init(
model=self._model,
Expand All @@ -125,10 +138,25 @@
else:
model = self._model

# Prepare fit/predict keyword arguments
fit_pred_kwargs = {
"n_jobs": n_jobs,
"omp_nthreads": n_threads,
}
if model.__class__.__name__ == "DTIModel":
fit_pred_kwargs["step"] = chunk_size

Check warning on line 147 in src/nifreeze/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/estimator.py#L147

Added line #L147 was not covered by tests

print(f"Dataset size: {num_voxels}x{len(dataset)}.")
print(f"Parallel execution: {fit_pred_kwargs}.")
print(f"Model: {model}.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe that having proper logging instead of using print statements will be required in the mid-term.


if self._single_fit:
model.fit_predict(None, n_jobs=n_jobs)
print("Fitting 'single' model started ...")
start = timer()
model.fit_predict(None, **fit_pred_kwargs)
print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")

Check warning on line 157 in src/nifreeze/estimator.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/estimator.py#L154-L157

Added lines #L154 - L157 were not covered by tests

kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
kwargs["num_threads"] = n_threads
kwargs = self._align_kwargs | kwargs

dataset_length = len(dataset)
Expand All @@ -151,15 +179,14 @@
pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")

# fit the model
test_set = dataset[i]
predicted = model.fit_predict( # type: ignore[union-attr]
i,
n_jobs=n_jobs,
**fit_pred_kwargs,
)

# prepare data for running ANTs
predicted_path, volume_path, init_path = _prepare_registration_data(
test_set[0],
dataset[i][0], # Access the target volume
predicted,
dataset.affine,
i,
Expand Down
25 changes: 18 additions & 7 deletions src/nifreeze/model/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#

from importlib import import_module
from typing import Any

import numpy as np
from dipy.core.gradients import gradient_table_from_bvals_bvecs
Expand All @@ -38,9 +39,8 @@
B_MIN = 50


def _exec_fit(model, data, chunk=None):
retval = model.fit(data)
return retval, chunk
def _exec_fit(model, data, chunk=None, **kwargs):
return model.fit(data, **kwargs), chunk


def _exec_predict(model, chunk=None, **kwargs):
Expand Down Expand Up @@ -104,7 +104,7 @@

super().__init__(dataset, **kwargs)

def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
def _fit(self, index: int | None = None, n_jobs: int | None = None, **kwargs):

Check warning on line 107 in src/nifreeze/model/dmri.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/model/dmri.py#L107

Added line #L107 was not covered by tests
"""Fit the model chunk-by-chunk asynchronously"""

n_jobs = n_jobs or 1
Expand Down Expand Up @@ -136,9 +136,18 @@
class_name,
)(gtab, **kwargs)

fit_kwargs: dict[str, Any] = {} # Add here keyword arguments

is_dki = model_str == "dipy.reconst.dki.DiffusionKurtosisModel"

# One single CPU - linear execution (full model)
if n_jobs == 1:
_modelfit, _ = _exec_fit(model, data)
# DKI model does not allow parallelization as implemented here
if n_jobs == 1 or is_dki:

Check warning on line 145 in src/nifreeze/model/dmri.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/model/dmri.py#L144-L145

Added lines #L144 - L145 were not covered by tests
_modelfit, _ = _exec_fit(model, data, **fit_kwargs)
self._models = [_modelfit]
return 1
elif is_dki:

Check warning on line 149 in src/nifreeze/model/dmri.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/model/dmri.py#L147-L149

Added lines #L147 - L149 were not covered by tests
_modelfit = model.multi_fit(data, **fit_kwargs)
self._models = [_modelfit]
return 1

Expand All @@ -150,7 +159,8 @@
# Parallelize process with joblib
with Parallel(n_jobs=n_jobs) as executor:
results = executor(
delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
for i, dchunk in enumerate(data_chunks)

Check warning on line 163 in src/nifreeze/model/dmri.py

View check run for this annotation

Codecov / codecov/patch

src/nifreeze/model/dmri.py#L163

Added line #L163 was not covered by tests
)
for submodel, rindex in results:
self._models[rindex] = submodel
Expand All @@ -168,6 +178,7 @@

"""

kwargs.pop("omp_nthreads", None) # Drop omp_nthreads
n_models = self._fit(
index,
n_jobs=kwargs.pop("n_jobs"),
Expand Down
11 changes: 8 additions & 3 deletions test/test_data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,16 @@

from nifreeze.data import NFDH5_EXT, BaseDataset, load

DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))


@pytest.fixture
def random_dataset(request) -> BaseDataset:
def random_dataset(request, size=DEFAULT_RANDOM_DATASET_SHAPE) -> BaseDataset:
"""Create a BaseDataset with random data for testing."""

rng = request.node.rng
data = rng.random((32, 32, 32, 5)).astype(np.float32)
data = rng.random(size).astype(np.float32)
affine = np.eye(4, dtype=np.float32)
return BaseDataset(dataobj=data, affine=affine)

Expand All @@ -47,8 +50,10 @@ def test_base_dataset_init(random_dataset: BaseDataset):
"""Test that the BaseDataset can be initialized with random data."""
assert random_dataset.dataobj is not None
assert random_dataset.affine is not None
assert random_dataset.dataobj.shape == (32, 32, 32, 5)
assert random_dataset.dataobj.shape == DEFAULT_RANDOM_DATASET_SHAPE
assert random_dataset.affine.shape == (4, 4)
assert random_dataset.size3d == DEFAULT_RANDOM_DATASET_SIZE
assert random_dataset.shape3d == DEFAULT_RANDOM_DATASET_SHAPE[:3]


def test_len(random_dataset: BaseDataset):
Expand Down