Skip to content

Commit c038fbc

Browse files
dannycg1996, Daniel Grindrod, thinkall
authored
fix: KeyError no longer occurs when using groupfolds for regression tasks. (#1385)
* fix: Now resetting indexes for regression datasets when using group folds
* refactor: Simplified if statement to include all fold types
* docs: Updated docs to make it clear that group folds can be used for regression tasks

Co-authored-by: Daniel Grindrod <daniel.grindrod@evotec.com>
Co-authored-by: Li Jiang <bnujli@gmail.com>
1 parent 6a99202 commit c038fbc

File tree

4 files changed: 39 additions and 9 deletions.

flaml/automl/automl.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -203,7 +203,7 @@ def custom_metric(
203203
* Valid str options depend on different tasks.
204204
For classification tasks, valid choices are
205205
["auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified.
206-
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
206+
For regression tasks, valid choices are ["auto", 'uniform', 'time', 'group'].
207207
"auto" -> uniform.
208208
For time series forecast tasks, must be "auto" or 'time'.
209209
For ranking task, must be "auto" or 'group'.
@@ -739,7 +739,7 @@ def retrain_from_log(
739739
* Valid str options depend on different tasks.
740740
For classification tasks, valid choices are
741741
["auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified.
742-
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
742+
For regression tasks, valid choices are ["auto", 'uniform', 'time', 'group'].
743743
"auto" -> uniform.
744744
For time series forecast tasks, must be "auto" or 'time'.
745745
For ranking task, must be "auto" or 'group'.
@@ -1358,7 +1358,7 @@ def custom_metric(
13581358
* Valid str options depend on different tasks.
13591359
For classification tasks, valid choices are
13601360
["auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified.
1361-
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
1361+
For regression tasks, valid choices are ["auto", 'uniform', 'time', 'group'].
13621362
"auto" -> uniform.
13631363
For time series forecast tasks, must be "auto" or 'time'.
13641364
For ranking task, must be "auto" or 'group'.

flaml/automl/task/generic_task.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -442,8 +442,8 @@ def prepare_data(
442442
X_train_all, y_train_all = shuffle(X_train_all, y_train_all, random_state=RANDOM_SEED)
443443
if data_is_df:
444444
X_train_all.reset_index(drop=True, inplace=True)
445-
if isinstance(y_train_all, pd.Series):
446-
y_train_all.reset_index(drop=True, inplace=True)
445+
if isinstance(y_train_all, pd.Series):
446+
y_train_all.reset_index(drop=True, inplace=True)
447447

448448
X_train, y_train = X_train_all, y_train_all
449449
state.groups_all = state.groups

flaml/automl/task/task.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -192,7 +192,7 @@ def prepare_data(
192192
* Valid str options depend on different tasks.
193193
For classification tasks, valid choices are
194194
["auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified.
195-
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
195+
For regression tasks, valid choices are ["auto", 'uniform', 'time', 'group'].
196196
"auto" -> uniform.
197197
For time series forecast tasks, must be "auto" or 'time'.
198198
For ranking task, must be "auto" or 'group'.

test/automl/test_split.py

Lines changed: 33 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
from sklearn.datasets import fetch_openml
1+
import numpy as np
2+
from sklearn.datasets import fetch_openml, load_iris
23
from sklearn.metrics import accuracy_score
34
from sklearn.model_selection import GroupKFold, KFold, train_test_split
45

@@ -48,7 +49,7 @@ def test_time():
4849
_test(split_type="time")
4950

5051

51-
def test_groups():
52+
def test_groups_for_classification_task():
5253
from sklearn.externals._arff import ArffException
5354

5455
try:
@@ -88,6 +89,35 @@ def test_groups():
8889
automl.fit(X, y, **automl_settings)
8990

9091

92+
def test_groups_for_regression_task():
93+
"""Append nonsensical groups to iris dataset and use it to test that GroupKFold works for regression tasks"""
94+
iris_dict_data = load_iris(as_frame=True) # numpy arrays
95+
iris_data = iris_dict_data["frame"] # pandas dataframe data + target
96+
97+
rng = np.random.default_rng(42)
98+
iris_data["cluster"] = rng.integers(
99+
low=0, high=5, size=iris_data.shape[0]
100+
) # np.random.randint(0, 5, iris_data.shape[0])
101+
102+
automl = AutoML()
103+
X = iris_data[["sepal length (cm)", "sepal width (cm)", "petal length (cm)"]].to_numpy()
104+
y = iris_data["petal width (cm)"]
105+
X_train, X_test, y_train, y_test, groups_train, groups_test = train_test_split(
106+
X, y, iris_data["cluster"], random_state=42
107+
)
108+
automl_settings = {
109+
"max_iter": 5,
110+
"time_budget": -1,
111+
"metric": "r2",
112+
"task": "regression",
113+
"estimator_list": ["lgbm", "rf", "xgboost", "kneighbor"],
114+
"eval_method": "cv",
115+
"split_type": "uniform",
116+
"groups": groups_train,
117+
}
118+
automl.fit(X_train, y_train, **automl_settings)
119+
120+
91121
def test_stratified_groupkfold():
92122
from minio.error import ServerError
93123
from sklearn.model_selection import StratifiedGroupKFold
@@ -204,4 +234,4 @@ def get_n_splits(self, X=None, y=None, groups=None):
204234

205235

206236
if __name__ == "__main__":
207-
test_groups()
237+
test_groups_for_classification_task()

0 commit comments

Comments
 (0)