33import pandas as pd
44import sklearn
55from sklearn .base import ClassifierMixin
6+ import torch
67
78from pytabkit import TabM_D_Classifier , RealMLP_HPO_Classifier , Ensemble_HPO_Classifier , TabM_HPO_Regressor , \
89 TabM_HPO_Classifier , LGBM_HPO_Classifier , CatBoost_HPO_Classifier , XGB_HPO_Classifier , Ensemble_HPO_Regressor , \
910 LGBM_HPO_TPE_Regressor , RealMLP_TD_Regressor , RealMLP_HPO_Regressor , TabM_D_Regressor
1011
1112
1213@pytest .mark .parametrize ('estimator' , [
13- RealMLP_TD_Regressor (n_cv = 2 , n_refit = 2 , n_repeats = 2 , device = 'cpu' ),
14- RealMLP_HPO_Regressor (device = 'cpu' , n_hyperopt_steps = 2 , train_metric_name = 'multi_pinball(0.1,0.9)' ,
14+ RealMLP_TD_Regressor (n_cv = 2 , n_refit = 2 , n_repeats = 2 ),
15+ RealMLP_HPO_Regressor (n_hyperopt_steps = 2 , train_metric_name = 'multi_pinball(0.1,0.9)' ,
1516 val_metric_name = 'multi_pinball(0.1,0.9)' ),
16- TabM_D_Classifier (val_metric_name = 'cross_entropy' , num_emb_type = 'pwl' , tabm_k = 16 , device = 'cpu' , random_state = 0 ),
17- TabM_D_Regressor (val_metric_name = 'cross_entropy' , num_emb_type = 'pwl' , tabm_k = 16 , device = 'cpu' , random_state = 0 ),
18- TabM_HPO_Regressor (val_metric_name = 'mae' , n_hyperopt_steps = 2 , hpo_space_name = 'tabarena' , device = 'cpu' ,
17+ TabM_D_Classifier (val_metric_name = 'cross_entropy' , num_emb_type = 'pwl' , tabm_k = 16 , random_state = 0 ),
18+ TabM_D_Regressor (val_metric_name = 'cross_entropy' , num_emb_type = 'pwl' , tabm_k = 16 , random_state = 0 ),
19+ TabM_HPO_Regressor (val_metric_name = 'mae' , n_hyperopt_steps = 2 , hpo_space_name = 'tabarena' ,
1920 random_state = 0 ),
20- TabM_HPO_Classifier (val_metric_name = 'mae' , n_hyperopt_steps = 2 , hpo_space_name = 'default' , device = 'cpu' ,
21+ TabM_HPO_Classifier (val_metric_name = 'mae' , n_hyperopt_steps = 2 , hpo_space_name = 'default' ,
2122 random_state = 0 , use_caruana_ensembling = True ),
23+ # use CPU since GPU might not support some features in the search space (it has problems with rsm for catboost)
2224 LGBM_HPO_Classifier (use_caruana_ensembling = True , n_hyperopt_steps = 2 , hpo_space_name = 'tabarena' , device = 'cpu' ),
2325 XGB_HPO_Classifier (use_caruana_ensembling = True , n_hyperopt_steps = 2 , hpo_space_name = 'tabarena' , device = 'cpu' ),
2426 CatBoost_HPO_Classifier (use_caruana_ensembling = True , n_hyperopt_steps = 2 , hpo_space_name = 'tabarena' , device = 'cpu' ),
2527 RealMLP_HPO_Classifier (val_metric_name = 'cross_entropy' , n_hyperopt_steps = 3 , use_caruana_ensembling = True ,
26- hpo_space_name = 'tabarena' , n_caruana_steps = 10 , random_state = 0 , device = 'cpu' ),
27- Ensemble_HPO_Classifier (val_metric_name = 'brier' , device = 'cpu' , n_hpo_steps = 2 , use_full_caruana_ensembling = True ,
28+ hpo_space_name = 'tabarena' , n_caruana_steps = 10 , random_state = 0 ),
29+ Ensemble_HPO_Classifier (val_metric_name = 'brier' , n_hpo_steps = 2 , use_full_caruana_ensembling = True ,
2830 use_tabarena_spaces = True ),
29- Ensemble_HPO_Regressor (val_metric_name = 'brier' , device = 'cpu' , n_hpo_steps = 2 , use_full_caruana_ensembling = True ,
31+ Ensemble_HPO_Regressor (val_metric_name = 'brier' , n_hpo_steps = 2 , use_full_caruana_ensembling = True ,
3032 use_tabarena_spaces = True ),
3133 LGBM_HPO_TPE_Regressor (n_cv = 2 , n_refit = 2 , n_hyperopt_steps = 2 ),
3234])
@@ -37,7 +39,9 @@ def test_sklearn_not_crash(estimator):
3739 X ['b' ] = X ['b' ].astype ('category' )
3840
3941 est = sklearn .base .clone (estimator )
40- est .device = 'cpu'
42+ if not torch .cuda .is_available ():
43+ # don't use mps even if it's available
44+ est .device = 'cpu'
4145 if isinstance (est , ClassifierMixin ):
4246 y = np .random .randint (3 , size = (n_train ,))
4347 else :
0 commit comments