3333from bluecast .experimentation .tracking import ExperimentTracker
3434from bluecast .general_utils .general_utils import save_out_of_fold_data
3535from bluecast .ml_modelling .catboost import CatboostModel
36- from bluecast .ml_modelling .xgboost import XgboostModel
3736from bluecast .preprocessing .category_encoder_orchestration import (
3837 CategoryEncoderOrchestrator ,
3938)
@@ -69,7 +68,7 @@ class BlueCast:
6968 BlueCast will infer these automatically.
7069 :param :time_split_column: Takes a string containing the name of the time split column. If not provided,
7170 BlueCast will not split the data by time or order, but do a random split instead.
72- :param :ml_model: Takes an instance of a XgboostModel class. If not provided, BlueCast will instantiate one.
71+ :param :ml_model: Takes an instance of a CatboostModel class. If not provided, BlueCast will instantiate one.
7372 This is an API to pass any model class. Inherit the baseclass from ml_modelling.base_model.BaseModel.
7473 :param custom_in_fold_preprocessor: Takes an instance of a CustomPreprocessing class. Allows users to execute
7574 preprocessing after the train test split within cv folds. This will be executed only if precise_cv_tuning in
@@ -92,7 +91,7 @@ def __init__(
9291 cat_columns : Optional [List [Union [str , float , int ]]] = None ,
9392 date_columns : Optional [List [Union [str , float , int ]]] = None ,
9493 time_split_column : Optional [str ] = None ,
95- ml_model : Optional [Union [XgboostModel , Any ]] = None ,
94+ ml_model : Optional [Union [CatboostModel , Any ]] = None ,
9695 custom_in_fold_preprocessor : Optional [CustomPreprocessing ] = None ,
9796 custom_last_mile_computation : Optional [CustomPreprocessing ] = None ,
9897 custom_preprocessor : Optional [CustomPreprocessing ] = None ,
@@ -132,7 +131,7 @@ def __init__(
132131 self .target_label_encoder : Optional [TargetLabelEncoder ] = None
133132 self .schema_detector : Optional [SchemaDetector ] = None
134133 self .date_part_extractor : Optional [DatePartExtractor ] = None
135- self .ml_model : Optional [XgboostModel ] = ml_model
134+ self .ml_model : Optional [CatboostModel ] = ml_model
136135 self .custom_in_fold_preprocessor = custom_in_fold_preprocessor
137136 self .custom_last_mile_computation = custom_last_mile_computation
138137 self .custom_preprocessor = custom_preprocessor
@@ -149,25 +148,21 @@ def __init__(
149148 self .experiment_tracker = ExperimentTracker ()
150149
151150 if not self .conf_params_xgboost :
152- self .conf_params_xgboost = XgboostFinalParamConfig ()
151+ self .conf_params_xgboost = CatboostFinalParamConfig ()
153152
154153 self .conf_training : TrainingConfig = conf_training or TrainingConfig ()
155154
156155 if not self .conf_xgboost :
157- self .conf_xgboost = XgboostTuneParamsConfig ()
158-
156+ self .conf_xgboost = CatboostTuneParamsConfig ()
159157 if not self .single_fold_eval_metric_func :
160158 self .single_fold_eval_metric_func = ClassificationEvalWrapper ()
161-
162159 logging .basicConfig (
163160 filename = self .conf_training .logging_file_path ,
164161 filemode = "w" ,
165162 format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" ,
166163 level = logging .INFO ,
167- # stream=sys.stdout,
168164 force = True ,
169165 )
170-
171166 logging .info ("BlueCast blueprint initialized." )
172167
173168 def initial_checks (self , df : pd .DataFrame ) -> None :
@@ -197,12 +192,14 @@ def initial_checks(self, df: pd.DataFrame) -> None:
197192 many features have been removed. Otherwise, consider disabling feature selection or providing a custom
198193 feature selector."""
199194 warnings .warn (message , UserWarning , stacklevel = 2 )
195+
200196 if not self .conf_xgboost :
201- message = """No XgboostTuneParamsConfig has been provided. Falling back to default values. Default values
197+ message = """No CatboostTuneParamsConfig has been provided. Falling back to default values. Default values
202198 have been chosen to speed up the prototyping. For robust hyperparameter tuning consider providing a custom
203- XgboostTuneParamsConfig with a deeper hyperparameter search space and a custom TrainingConfig to enable
199+ CatboostTuneParamsConfig with a deeper hyperparameter search space and a custom TrainingConfig to enable
204200 cross-validation."""
205201 warnings .warn (message , UserWarning , stacklevel = 2 )
202+
206203 if (
207204 self .conf_training .cat_encoding_via_ml_algorithm
208205 and self .conf_training .calculate_shap_values
@@ -217,11 +214,13 @@ def initial_checks(self, df: pd.DataFrame) -> None:
217214 required. Alternatively use Xgboost as a custom model and calculate shap values manually via
218215 pred_contribs=True."""
219216 warnings .warn (message , UserWarning , stacklevel = 2 )
217+
220218 if self .conf_training .cat_encoding_via_ml_algorithm and self .ml_model :
221219 message = """Categorical encoding via ML algorithm is enabled. Make sure to handle categorical features
222220 within the provided ml model or consider disabling categorical encoding via ML algorithm in the
223221 TrainingConfig alternatively."""
224222 warnings .warn (message , UserWarning , stacklevel = 2 )
223+
225224 if (
226225 self .conf_training .cat_encoding_via_ml_algorithm
227226 and self .custom_last_mile_computation
@@ -230,11 +229,13 @@ def initial_checks(self, df: pd.DataFrame) -> None:
230229 within the provided last mile computation or consider disabling categorical encoding via ML algorithm in the
231230 TrainingConfig alternatively."""
232231 warnings .warn (message , UserWarning , stacklevel = 2 )
232+
233233 if self .conf_training .precise_cv_tuning :
234234 message = """Precise fine tuning has been enabled. Please make sure to transform your data to a normal
235235 distribution (yeo-johnson). This is an experimental feature as it includes a special
236236 evaluation (see more in the docs). If you plan to use this feature, please make sure to read the docs."""
237237 warnings .warn (message , UserWarning , stacklevel = 2 )
238+
238239 if (
239240 self .conf_training .precise_cv_tuning
240241 and not self .custom_in_fold_preprocessor
@@ -245,6 +246,7 @@ def initial_checks(self, df: pd.DataFrame) -> None:
245246 using precise_cv_tuning. Otherwise disable precise_cv_tuning to benefit from early pruning of unpromising
246247 hyperparameter sets."""
247248 warnings .warn (message , UserWarning , stacklevel = 2 )
249+
248250 if (
249251 self .conf_training .precise_cv_tuning
250252 and self .conf_training .hypertuning_cv_folds < 2
@@ -253,21 +255,21 @@ def initial_checks(self, df: pd.DataFrame) -> None:
253255 less than 2 folds precise_cv_tuning will not have any impact. Consider raising the number of folds to two
254256 or higher or disable precise_cv_tuning."""
255257 warnings .warn (message , UserWarning , stacklevel = 2 )
258+
256259 if self .class_problem == "binary" and df [self .target_column ].nunique () > 2 :
257260 message = """During class instantiation class_problem = 'binary' has been passed. However more than 2
258261 unique target classes have been found. Did you mean 'multiclass' instead?"""
259262 warnings .warn (message , UserWarning , stacklevel = 2 )
263+
260264 if self .class_problem == "multiclass" and df [self .target_column ].nunique () < 3 :
261265 message = """During class instantiation class_problem = 'multiclass' has been passed. However less than 3
262266 unique target classes have been found. Did you mean 'binary' instead?"""
263267 warnings .warn (message , UserWarning , stacklevel = 2 )
264268
265269 if self .conf_xgboost and isinstance (self .conf_xgboost , XgboostTuneParamsConfig ):
266- if (
267- self .conf_training .cat_encoding_via_ml_algorithm
268- and "exact" in self .conf_xgboost .tree_method
269- ):
270- self .conf_xgboost .tree_method .remove ("exact" )
270+ if self .conf_training .cat_encoding_via_ml_algorithm :
271+ if "exact" in self .conf_xgboost .tree_method :
272+ self .conf_xgboost .tree_method .remove ("exact" )
271273 message = f"""Categorical encoding via ML algorithm is enabled. The tree method 'exact' is not supported with categorical encoding within Xgboost. The tree method 'exact' has been removed. Using { self .conf_xgboost .tree_method } only during hyperparameter tuning."""
272274 warnings .warn (message , UserWarning , stacklevel = 2 )
273275
@@ -432,40 +434,44 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
432434 )
433435
434436 if not self .ml_model :
435- self .ml_model = XgboostModel (
437+ self .ml_model = CatboostModel (
436438 self .class_problem ,
437439 conf_training = self .conf_training ,
438- conf_xgboost = (
440+ conf_catboost = (
439441 self .conf_xgboost
440- if isinstance (self .conf_xgboost , XgboostTuneParamsConfig )
441- else XgboostTuneParamsConfig ()
442+ if isinstance (self .conf_xgboost , CatboostTuneParamsConfig )
443+ else CatboostTuneParamsConfig ()
442444 ),
443- conf_params_xgboost = (
445+ conf_params_catboost = (
444446 self .conf_params_xgboost
445- if isinstance (self .conf_params_xgboost , XgboostFinalParamConfig )
446- else XgboostFinalParamConfig ()
447+ if isinstance (self .conf_params_xgboost , CatboostFinalParamConfig )
448+ else CatboostFinalParamConfig ()
447449 ),
448450 experiment_tracker = self .experiment_tracker ,
449451 custom_in_fold_preprocessor = self .custom_in_fold_preprocessor ,
450452 cat_columns = self .cat_columns ,
451453 single_fold_eval_metric_func = self .single_fold_eval_metric_func ,
452454 )
453455
454- if not getattr (self .ml_model , "cat_columns" , None ):
455- self .ml_model .experiment_tracker = self .experiment_tracker
456- self .ml_model .custom_in_fold_preprocessor = self .custom_in_fold_preprocessor
457- self .ml_model .cat_columns = [
458- col
459- for col in self .feat_type_detector .cat_columns
460- if col != self .target_column
461- ]
462- if self .single_fold_eval_metric_func is not None :
463- self .ml_model .single_fold_eval_metric_func = (
464- self .single_fold_eval_metric_func
465- )
466- self .ml_model .conf_training = self .conf_training
467- if isinstance (self .ml_model , CatboostModel ):
456+ # Always override model wiring based on detected schema; ensure target is excluded
457+ self .ml_model .experiment_tracker = self .experiment_tracker
458+ self .ml_model .custom_in_fold_preprocessor = self .custom_in_fold_preprocessor
459+ self .ml_model .cat_columns = [
460+ col
461+ for col in self .feat_type_detector .cat_columns
462+ if col != self .target_column
463+ ]
464+ if self .single_fold_eval_metric_func is not None :
465+ self .ml_model .single_fold_eval_metric_func = (
466+ self .single_fold_eval_metric_func
467+ )
468+ self .ml_model .conf_training = self .conf_training
469+ if isinstance (self .ml_model , CatboostModel ):
470+ # Ensure CatBoost final params config exists and is of correct type
471+ if isinstance (self .conf_params_xgboost , CatboostFinalParamConfig ):
468472 self .ml_model .conf_params_catboost = self .conf_params_xgboost
473+ else :
474+ self .ml_model .conf_params_catboost = CatboostFinalParamConfig ()
469475
470476 self .ml_model .fit (x_train , x_test , y_train , y_test )
471477
@@ -524,8 +530,15 @@ def fit_eval(
524530 if not self .conf_training :
525531 raise ValueError ("Could not find any training config" )
526532
527- if not self .conf_params_xgboost :
528- raise ValueError ("Could not find Xgboost params" )
533+ # Ensure final params exist depending on model backend
534+ if isinstance (self .ml_model , CatboostModel ):
535+ if not getattr (self .ml_model , "conf_params_catboost" , None ):
536+ raise ValueError ("Could not find CatBoost params" )
537+ final_params_for_log = self .ml_model .conf_params_catboost .params
538+ else :
539+ if not self .conf_params_xgboost :
540+ raise ValueError ("Could not find Xgboost params" )
541+ final_params_for_log = self .conf_params_xgboost .params
529542
530543 if len (self .experiment_tracker .experiment_id ) == 0 :
531544 self .experiment_tracker .experiment_id .append (0 )
@@ -564,7 +577,7 @@ def fit_eval(
564577 experiment_id = experiment_id ,
565578 score_category = "oof_score" ,
566579 training_config = self .conf_training ,
567- model_parameters = self . conf_params_xgboost . params , # noqa
580+ model_parameters = final_params_for_log , # noqa
568581 eval_scores = self .eval_metrics ["accuracy" ],
569582 metric_used = metric ,
570583 metric_higher_is_better = higher_is_better ,
0 commit comments