Skip to content

Commit 42d1dcf

Browse files
dannycg1996Daniel Grindrod
andauthored
fix: Fixed bug with catboost and groups (#1383)
Co-authored-by: Daniel Grindrod <daniel.grindrod@evotec.com>
1 parent b83c8a7 commit 42d1dcf

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

flaml/automl/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2066,8 +2066,8 @@ def __init__(
20662066
self.estimator_class = CatBoostRegressor
20672067

20682068
def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
2069-
if "is_retrain" in kwargs:
2070-
kwargs.pop("is_retrain")
2069+
kwargs.pop("is_retrain", None)
2070+
kwargs.pop("groups", None)
20712071
start_time = time.time()
20722072
deadline = start_time + budget if budget else np.inf
20732073
train_dir = f"catboost_{str(start_time)}"

test/automl/test_split.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_groups():
6868
"model_history": True,
6969
"eval_method": "cv",
7070
"groups": np.random.randint(low=0, high=10, size=len(y)),
71-
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
71+
"estimator_list": ["catboost", "lgbm", "rf", "xgboost", "kneighbor"],
7272
"learner_selector": "roundrobin",
7373
}
7474
automl.fit(X, y, **automl_settings)
@@ -108,6 +108,7 @@ def test_stratified_groupkfold():
108108
"split_type": splitter,
109109
"groups": X_train["Airline"],
110110
"estimator_list": [
111+
"catboost",
111112
"lgbm",
112113
"rf",
113114
"xgboost",

0 commit comments

Comments
 (0)