diff --git a/src/nifreeze/cli/parser.py b/src/nifreeze/cli/parser.py
index 16ad2459c..93d6f9848 100644
--- a/src/nifreeze/cli/parser.py
+++ b/src/nifreeze/cli/parser.py
@@ -91,6 +91,8 @@ def build_parser() -> ArgumentParser:
     )
     parser.add_argument(
         "--nthreads",
+        "--omp-nthreads",
+        "--ncpus",
         action="store",
         type=int,
         default=None,
diff --git a/src/nifreeze/data/base.py b/src/nifreeze/data/base.py
index 4acb1e252..e61ddc923 100644
--- a/src/nifreeze/data/base.py
+++ b/src/nifreeze/data/base.py
@@ -127,6 +127,16 @@ def __getitem__(
         affine = self.motion_affines[idx] if self.motion_affines is not None else None
         return self.dataobj[..., idx], affine, *self._getextra(idx)
 
+    @property
+    def shape3d(self):
+        """Get the shape of the 3D volume."""
+        return self.dataobj.shape[:3]
+
+    @property
+    def size3d(self):
+        """Get the number of voxels in the 3D volume."""
+        return np.prod(self.dataobj.shape[:3])
+
     @classmethod
     def from_filename(cls, filename: Path | str) -> Self:
         """
diff --git a/src/nifreeze/estimator.py b/src/nifreeze/estimator.py
index 46d3f96f0..d37ce7aaf 100644
--- a/src/nifreeze/estimator.py
+++ b/src/nifreeze/estimator.py
@@ -27,6 +27,7 @@
 from os import cpu_count
 from pathlib import Path
 from tempfile import TemporaryDirectory
+from timeit import default_timer as timer
 from typing import TypeVar
 
 from tqdm import tqdm
@@ -42,6 +43,7 @@
 
 DatasetT = TypeVar("DatasetT", bound=BaseDataset)
 
+DEFAULT_CHUNK_SIZE: int = int(1e6)
 FIT_MSG = "Fit&predict"
 REG_MSG = "Realign"
 
@@ -109,6 +111,11 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
                 dataset = result  # type: ignore[assignment]
 
         n_jobs = kwargs.pop("n_jobs", None) or min(cpu_count() or 1, 8)
+        # Leave one CPU free by default, but never request fewer than one thread
+        n_threads = kwargs.pop("omp_nthreads", None) or max((cpu_count() or 2) - 1, 1)
+
+        num_voxels = dataset.brainmask.sum() if dataset.brainmask is not None else dataset.size3d
+        chunk_size = DEFAULT_CHUNK_SIZE * (n_threads or 1)
 
         # Prepare iterator
         iterfunc = getattr(iterators, f"{self._strategy}_iterator")
@@ -116,6 +123,13 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
 
         # Initialize model
         if isinstance(self._model, str):
+            if self._model.endswith("dti"):
+                self._model_kwargs["step"] = chunk_size
+
+            # Example: change model parameters only for DKI
+            # if self._model.endswith("dki"):
+            #     self._model_kwargs["fit_model"] = "CWLS"
+
             # Factory creates the appropriate model and pipes arguments
             model = ModelFactory.init(
                 model=self._model,
@@ -125,10 +139,25 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
         else:
             model = self._model
 
+        # Prepare fit/predict keyword arguments
+        fit_pred_kwargs = {
+            "n_jobs": n_jobs,
+            "omp_nthreads": n_threads,
+        }
+        if model.__class__.__name__ == "DTIModel":
+            fit_pred_kwargs["step"] = chunk_size
+
+        print(f"Dataset size: {num_voxels}x{len(dataset)}.")
+        print(f"Parallel execution: {fit_pred_kwargs}.")
+        print(f"Model: {model}.")
+
         if self._single_fit:
-            model.fit_predict(None, n_jobs=n_jobs)
+            print("Fitting 'single' model started ...")
+            start = timer()
+            model.fit_predict(None, **fit_pred_kwargs)
+            print(f"Fitting 'single' model finished, elapsed {timer() - start}s.")
 
-        kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
+        kwargs["num_threads"] = n_threads
         kwargs = self._align_kwargs | kwargs
 
         dataset_length = len(dataset)
@@ -151,15 +180,14 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
                     pbar.set_description_str(f"{FIT_MSG: <16} vol. <{i}>")
 
                     # fit the model
-                    test_set = dataset[i]
                     predicted = model.fit_predict(  # type: ignore[union-attr]
                         i,
-                        n_jobs=n_jobs,
+                        **fit_pred_kwargs,
                     )
 
                     # prepare data for running ANTs
                     predicted_path, volume_path, init_path = _prepare_registration_data(
-                        test_set[0],
+                        dataset[i][0],  # Access the target volume
                         predicted,
                         dataset.affine,
                         i,
diff --git a/src/nifreeze/model/dmri.py b/src/nifreeze/model/dmri.py
index ca8a78fa6..e2b16db3a 100644
--- a/src/nifreeze/model/dmri.py
+++ b/src/nifreeze/model/dmri.py
@@ -22,6 +22,7 @@
 #
 
 from importlib import import_module
+from typing import Any
 
 import numpy as np
 from dipy.core.gradients import gradient_table_from_bvals_bvecs
@@ -38,9 +39,8 @@
 B_MIN = 50
 
 
-def _exec_fit(model, data, chunk=None):
-    retval = model.fit(data)
-    return retval, chunk
+def _exec_fit(model, data, chunk=None, **kwargs):
+    return model.fit(data, **kwargs), chunk
 
 
 def _exec_predict(model, chunk=None, **kwargs):
@@ -104,7 +104,7 @@ def __init__(self, dataset: DWI, max_b: float | int | None = None, **kwargs):
 
         super().__init__(dataset, **kwargs)
 
-    def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
+    def _fit(self, index: int | None = None, n_jobs: int | None = None, **kwargs):
         """Fit the model chunk-by-chunk asynchronously"""
 
         n_jobs = n_jobs or 1
@@ -136,9 +136,14 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
                 class_name,
             )(gtab, **kwargs)
 
+        fit_kwargs: dict[str, Any] = {}  # Add here keyword arguments
+
+        is_dki = model_str == "dipy.reconst.dki.DiffusionKurtosisModel"
+
         # One single CPU - linear execution (full model)
-        if n_jobs == 1:
-            _modelfit, _ = _exec_fit(model, data)
+        # DKI model does not allow parallelization as implemented here
+        if n_jobs == 1 or is_dki:
+            _modelfit, _ = _exec_fit(model, data, **fit_kwargs)
             self._models = [_modelfit]
             return 1
 
@@ -150,7 +155,8 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
         # Parallelize process with joblib
         with Parallel(n_jobs=n_jobs) as executor:
             results = executor(
-                delayed(_exec_fit)(model, dchunk, i) for i, dchunk in enumerate(data_chunks)
+                delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
+                for i, dchunk in enumerate(data_chunks)
             )
         for submodel, rindex in results:
             self._models[rindex] = submodel
@@ -168,6 +174,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
 
         """
 
+        kwargs.pop("omp_nthreads", None)  # Drop omp_nthreads
         n_models = self._fit(
             index,
             n_jobs=kwargs.pop("n_jobs"),
diff --git a/test/test_data_base.py b/test/test_data_base.py
index bda22797c..a96abaa14 100644
--- a/test/test_data_base.py
+++ b/test/test_data_base.py
@@ -32,13 +32,16 @@
 
 from nifreeze.data import NFDH5_EXT, BaseDataset, load
 
+DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
+DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))
+
 
 @pytest.fixture
-def random_dataset(request) -> BaseDataset:
+def random_dataset(request, size=DEFAULT_RANDOM_DATASET_SHAPE) -> BaseDataset:
     """Create a BaseDataset with random data for testing."""
 
     rng = request.node.rng
-    data = rng.random((32, 32, 32, 5)).astype(np.float32)
+    data = rng.random(size).astype(np.float32)
     affine = np.eye(4, dtype=np.float32)
     return BaseDataset(dataobj=data, affine=affine)
 
@@ -47,8 +50,10 @@ def test_base_dataset_init(random_dataset: BaseDataset):
     """Test that the BaseDataset can be initialized with random data."""
     assert random_dataset.dataobj is not None
     assert random_dataset.affine is not None
-    assert random_dataset.dataobj.shape == (32, 32, 32, 5)
+    assert random_dataset.dataobj.shape == DEFAULT_RANDOM_DATASET_SHAPE
     assert random_dataset.affine.shape == (4, 4)
+    assert random_dataset.size3d == DEFAULT_RANDOM_DATASET_SIZE
+    assert random_dataset.shape3d == DEFAULT_RANDOM_DATASET_SHAPE[:3]
 
 
 def test_len(random_dataset: BaseDataset):