@@ -70,19 +70,21 @@ def run(self, dataset: DatasetT, **kwargs) -> DatasetT:
7070class Estimator :
7171 """Orchestrates components for a single estimation step."""
7272
73- __slots__ = ("_model" , "_strategy" , "_prev" , "_model_kwargs" , "_align_kwargs" )
73+ __slots__ = ("_model" , "_single_fit" , " _strategy" , "_prev" , "_model_kwargs" , "_align_kwargs" )
7474
7575 def __init__ (
7676 self ,
7777 model : BaseModel | str ,
7878 strategy : str = "random" ,
7979 prev : Estimator | Filter | None = None ,
8080 model_kwargs : dict | None = None ,
81+ single_fit : bool = False ,
8182 ** kwargs ,
8283 ):
8384 self ._model = model
8485 self ._prev = prev
8586 self ._strategy = strategy
87+ self ._single_fit = single_fit
8688 self ._model_kwargs = model_kwargs or {}
8789 self ._align_kwargs = kwargs or {}
8890
@@ -115,11 +117,16 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
115117 # Initialize model
116118 if isinstance (self ._model , str ):
117119 # Factory creates the appropriate model and pipes arguments
118- self . _model = ModelFactory .init (
120+ model = ModelFactory .init (
119121 model = self ._model ,
120122 dataset = dataset ,
121123 ** self ._model_kwargs ,
122124 )
125+ else :
126+ model = self ._model
127+
128+ if self ._single_fit :
129+ model .fit_predict (None , n_jobs = n_jobs )
123130
124131 kwargs ["num_threads" ] = kwargs .pop ("omp_nthreads" , None ) or kwargs .pop ("num_threads" , None )
125132 kwargs = self ._align_kwargs | kwargs
@@ -145,7 +152,7 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
145152
146153 # fit the model
147154 test_set = dataset [i ]
148- predicted = self . _model .fit_predict ( # type: ignore[union-attr]
155+ predicted = model .fit_predict ( # type: ignore[union-attr]
149156 i ,
150157 n_jobs = n_jobs ,
151158 )
0 commit comments