Skip to content

Commit f573915

Browse files
authored
Merge branch 'main' into flaml-fix-groups-issue-for-regression-tasks
2 parents d2e7d46 + 42d1dcf commit f573915

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

flaml/automl/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2066,8 +2066,8 @@ def __init__(
20662066
self.estimator_class = CatBoostRegressor
20672067

20682068
def fit(self, X_train, y_train, budget=None, free_mem_ratio=0, **kwargs):
2069-
if "is_retrain" in kwargs:
2070-
kwargs.pop("is_retrain")
2069+
kwargs.pop("is_retrain", None)
2070+
kwargs.pop("groups", None)
20712071
start_time = time.time()
20722072
deadline = start_time + budget if budget else np.inf
20732073
train_dir = f"catboost_{str(start_time)}"

test/automl/test_split.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def test_groups_for_classification_task():
6969
"model_history": True,
7070
"eval_method": "cv",
7171
"groups": np.random.randint(low=0, high=10, size=len(y)),
72-
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
72+
"estimator_list": ["catboost", "lgbm", "rf", "xgboost", "kneighbor"],
7373
"learner_selector": "roundrobin",
7474
}
7575
automl.fit(X, y, **automl_settings)
@@ -138,6 +138,7 @@ def test_stratified_groupkfold():
138138
"split_type": splitter,
139139
"groups": X_train["Airline"],
140140
"estimator_list": [
141+
"catboost",
141142
"lgbm",
142143
"rf",
143144
"xgboost",

0 commit comments

Comments
 (0)