11from __future__ import annotations
22
3- from abc import ABC , abstractmethod
3+ from abc import ABC
44from typing import Any , Generic , Iterable , Sequence , TypeVar
55
66import warnings
2121)
2222from sklearn .utils .multiclass import type_of_target
2323from smac .runhistory .runhistory import RunInfo , RunValue
24- from typing_extensions import Literal , TypeAlias
24+ from typing_extensions import Literal
2525
2626from autosklearn .automl import AutoML , AutoMLClassifier , AutoMLRegressor
2727from autosklearn .data .validation import convert_if_sparse
3131from autosklearn .pipeline .base import BasePipeline
3232from autosklearn .util .smac_wrap import SMACCallback
3333
34- # Used to indicate what type the underlying AutoML instance is
35- TAutoML = TypeVar ("TAutoML" , bound = AutoML )
36- TParetoModel = TypeVar ("TParetoModel" , VotingClassifier , VotingRegressor )
37-
3834# Used to return self and give correct type information from subclasses,
3935# see `fit(self: Self) -> Self`
4036Self = TypeVar ("Self" , bound = "AutoSklearnEstimator" )
4137
42- ResampleOptions : TypeAlias = Literal [
43- "holdout" ,
44- "cv" ,
45- "holdout-iterative-fit" ,
46- "cv-iterative-fit" ,
47- "partial-cv" ,
48- ]
49- DisableEvaluatorOptions : TypeAlias = Literal ["y_optimization" , "model" ]
50-
38+ # Used to indicate what type the underlying AutoML instance is
39+ TParetoModel = TypeVar ("TParetoModel" , VotingClassifier , VotingRegressor )
40+ TAutoML = TypeVar ("TAutoML" , bound = AutoML )
5141
52- class AutoSklearnEstimator (ABC , Generic [TAutoML , TParetoModel ], BaseEstimator ):
5342
54- # List of target types supported by the estimator class
55- supported_target_types : list [str ]
43+ class AutoSklearnEstimator (ABC , BaseEstimator , Generic [TAutoML , TParetoModel ]):
5644
57- # The automl class used by the estimator class
58- _automl_class : type [TAutoML ]
45+ supported_target_types : list [ str ] # Support output types for the estimator
46+ _automl_class : type [TAutoML ] # The automl class used by the estimator class
5947
6048 def __init__ (
6149 self ,
50+ * ,
6251 time_left_for_this_task : int = 3600 ,
6352 per_run_time_limit : int | None = None , # TODO: allow percentage
6453 initial_configurations_via_metalearning : int = 25 , # TODO validate
@@ -71,7 +60,9 @@ def __init__(
7160 memory_limit : int | None = 3072 ,
7261 include : dict [str , list [str ]] | None = None ,
7362 exclude : dict [str , list [str ]] | None = None ,
74- resampling_strategy : ResampleOptions
63+ resampling_strategy : Literal [
64+ "holdout" , "cv" , "holdout-iterative-fit" , "cv-iterative-fit" , "partial-cv"
65+ ]
7566 | BaseCrossValidator
7667 | _RepeatedSplits
7768 | BaseShuffleSplit = "holdout" ,
@@ -81,7 +72,7 @@ def __init__(
8172 n_jobs : int = 1 ,
8273 dask_client : dask .distributed .Client | None = None ,
8374 disable_evaluator_output : bool
84- | Sequence [DisableEvaluatorOptions ] = False , # TODO fill in
75+ | Sequence [Literal [ "y_optimization" , "model" ]] = False , # TODO: fill in
8576 get_smac_object_callback : SMACCallback | None = None ,
8677 smac_scenario_args : dict [str , Any ] | None = None ,
8778 logging_config : dict [str , Any ] | None = None ,
@@ -490,7 +481,7 @@ def __init__(
490481 self .allow_string_features = allow_string_features
491482
492483 # Cached
493- self .automl_ : AutoML | None = None
484+ self .automl_ : TAutoML | None = None
494485
495486 # Handle the number of jobs and the time for them
496487 # Made private by `_n_jobs` to keep with sklearn compliance
@@ -504,21 +495,19 @@ def __init__(
504495 self .per_run_time_limit = self ._n_jobs * self .time_left_for_this_task // 10
505496
506497 @property
507- @abstractmethod
508498 def automl (self ) -> TAutoML :
509499 """Get the underlying Automl instance
510500
511501 Returns
512502 -------
513503 AutoML
514- The underlying AutoML instanec
504+ The underlying AutoML instance
515505 """
516506 if self .automl_ is not None :
517507 return self .automl_
518508
519509 initial_configs = self .initial_configurations_via_metalearning
520- cls = self ._get_automl_class ()
521- automl = cls (
510+ automl = self ._automl_class (
522511 temporary_directory = self .tmp_folder ,
523512 delete_tmp_folder_after_terminate = self .delete_tmp_folder_after_terminate ,
524513 time_left_for_this_task = self .time_left_for_this_task ,
@@ -568,16 +557,14 @@ def ensemble(self) -> AbstractEnsemble:
568557 NotFittedError
569558 If there this estimator has not been fitted
570559 """
571-
572- def __getstate__ (self ) -> dict [str , Any ]:
573- # Cannot serialize a client!
574- self .dask_client = None
575- return self .__dict__
560+ # TODO
561+ raise NotImplementedError ()
576562
577563 def fit (
578564 self : Self ,
579565 X : np .ndarray | pd .DataFrame | list | spmatrix ,
580566 y : np .ndarray | pd .DataFrame | pd .Series | list ,
567+ * ,
581568 X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
582569 y_test : np .ndarray | pd .DataFrame | pd .Series | list | None = None ,
583570 feat_type : list [str ] | None = None ,
@@ -697,6 +684,7 @@ def fit_pipeline(
697684 self ,
698685 X : np .ndarray | pd .DataFrame | list | spmatrix ,
699686 y : np .ndarray | pd .DataFrame | pd .Series | list ,
687+ * ,
700688 config : Configuration | dict [str , Any ],
701689 dataset_name : str | None = None ,
702690 X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
@@ -767,6 +755,7 @@ def fit_pipeline(
767755 def fit_ensemble (
768756 self : Self ,
769757 y : np .ndarray | pd .DataFrame | pd .Series | list ,
758+ * ,
770759 task : int | None = None ,
771760 precision : Literal [16 , 32 , 64 ] = 32 ,
772761 dataset_name : str | None = None ,
@@ -913,6 +902,7 @@ def refit(
913902 def predict (
914903 self ,
915904 X : np .ndarray | pd .DataFrame | list | spmatrix ,
905+ * ,
916906 batch_size : int | None = None ,
917907 n_jobs : int = 1 ,
918908 ) -> np .ndarray :
@@ -1088,6 +1078,7 @@ def sprint_statistics(self) -> str:
10881078
10891079 def leaderboard (
10901080 self ,
1081+ * ,
10911082 detailed : bool = False ,
10921083 ensemble_only : bool = True ,
10931084 top_k : int | Literal ["all" ] = "all" ,
@@ -1501,6 +1492,7 @@ def get_configuration_space(
15011492 self ,
15021493 X : np .ndarray | pd .DataFrame | list | spmatrix ,
15031494 y : np .ndarray | pd .DataFrame | pd .Series | list ,
1495+ * ,
15041496 X_test : np .ndarray | pd .DataFrame | list | spmatrix | None = None ,
15051497 y_test : np .ndarray | pd .DataFrame | pd .Series | list | None = None ,
15061498 dataset_name : str | None = None ,
@@ -1549,6 +1541,11 @@ def get_pareto_set(self) -> Sequence[TParetoModel]:
15491541 """
15501542 return self .automl ._load_pareto_set ()
15511543
1544+ def __getstate__ (self ) -> dict [str , Any ]:
1545+ # Cannot serialize a client!
1546+ self .dask_client = None
1547+ return self .__dict__
1548+
15521549 def __sklearn_is_fitted__ (self ) -> bool :
15531550 return self .automl_ is not None and self .automl .fitted
15541551
0 commit comments