Skip to content

Commit 5e23050

Browse files
committed
enh: update CLI and estimator to enable single fit
1 parent 3dd2b65 commit 5e23050

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

src/nifreeze/cli/run.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,11 @@ def main(argv=None) -> None:
6969

7070
prev_model: Estimator | None = None
7171
for _model in args.models:
72+
single_fit = _model.lower().startswith("single")
7273
estimator: Estimator = Estimator(
73-
_model,
74+
_model.lower().replace("single", ""),
7475
prev=prev_model,
76+
single_fit=single_fit,
7577
)
7678
prev_model = estimator
7779

src/nifreeze/estimator.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,21 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
7070
class Estimator:
7171
"""Orchestrates components for a single estimation step."""
7272

73-
__slots__ = ("_model", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
73+
__slots__ = ("_model", "_single_fit", "_strategy", "_prev", "_model_kwargs", "_align_kwargs")
7474

7575
def __init__(
7676
self,
7777
model: BaseModel | str,
7878
strategy: str = "random",
7979
prev: Estimator | Filter | None = None,
8080
model_kwargs: dict | None = None,
81+
single_fit: bool = False,
8182
**kwargs,
8283
):
8384
self._model = model
8485
self._prev = prev
8586
self._strategy = strategy
87+
self._single_fit = single_fit
8688
self._model_kwargs = model_kwargs or {}
8789
self._align_kwargs = kwargs or {}
8890

@@ -115,11 +117,16 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
115117
# Initialize model
116118
if isinstance(self._model, str):
117119
# Factory creates the appropriate model and pipes arguments
118-
self._model = ModelFactory.init(
120+
model = ModelFactory.init(
119121
model=self._model,
120122
dataset=dataset,
121123
**self._model_kwargs,
122124
)
125+
else:
126+
model = self._model
127+
128+
if self._single_fit:
129+
model.fit_predict(None, n_jobs=n_jobs)
123130

124131
kwargs["num_threads"] = kwargs.pop("omp_nthreads", None) or kwargs.pop("num_threads", None)
125132
kwargs = self._align_kwargs | kwargs
@@ -145,7 +152,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
145152

146153
# fit the model
147154
test_set = dataset[i]
148-
predicted = self._model.fit_predict( # type: ignore[union-attr]
155+
predicted = model.fit_predict( # type: ignore[union-attr]
149156
i,
150157
n_jobs=n_jobs,
151158
)

0 commit comments

Comments
 (0)