-
Notifications
You must be signed in to change notification settings - Fork 4k
[python-package] scikit-learn fit() methods: add eval_X, eval_y, deprecate eval_set #6857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5f27a57
79665bb
00022f7
d79f545
f8d2d13
69371c6
1e8ac7f
6f0a56d
11cf7b1
3b810a1
5fba410
b4c299d
656380b
6539816
1e622ea
d7e0fff
314ace1
e6790bc
6aa50ee
cd5aa82
4da786a
ed448b5
feafe22
2bab015
2924d68
e4ad4e4
e7422ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |||||||||
| """Scikit-learn wrapper interface for LightGBM.""" | ||||||||||
|
|
||||||||||
| import copy | ||||||||||
| import warnings | ||||||||||
| from inspect import signature | ||||||||||
| from pathlib import Path | ||||||||||
| from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union | ||||||||||
|
|
@@ -13,6 +14,7 @@ | |||||||||
| _MULTICLASS_OBJECTIVES, | ||||||||||
| Booster, | ||||||||||
| Dataset, | ||||||||||
| LGBMDeprecationWarning, | ||||||||||
| LightGBMError, | ||||||||||
| _choose_param_value, | ||||||||||
| _ConfigAliases, | ||||||||||
|
|
@@ -341,7 +343,9 @@ def __call__( | |||||||||
| For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups, | ||||||||||
| where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc. | ||||||||||
| eval_set : list or None, optional (default=None) | ||||||||||
| A list of (X, y) tuple pairs to use as validation sets. | ||||||||||
| .. deprecated:: 4.7.0 | ||||||||||
| A list of (X, y) tuple pairs to use as validation sets. | ||||||||||
| Use ``eval_X`` and ``eval_y`` instead. | ||||||||||
| eval_names : list of str, or None, optional (default=None) | ||||||||||
| Names of eval_set. | ||||||||||
| eval_sample_weight : {eval_sample_weight_shape} | ||||||||||
lorentzenchr marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
|
@@ -376,6 +380,10 @@ def __call__( | |||||||||
| See Callbacks in Python API for more information. | ||||||||||
| init_model : str, pathlib.Path, Booster, LGBMModel or None, optional (default=None) | ||||||||||
| Filename of LightGBM model, Booster instance or LGBMModel instance used for continue training. | ||||||||||
| eval_X : {X_shape}, or tuple of such inputs, or None, optional (default=None) | ||||||||||
jameslamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| Feature matrix or tuple thereof, e.g. ``(X_val0, X_val1)``, to use as validation sets. | ||||||||||
| eval_y : {y_shape}, or tuple of such inputs, or None, optional (default=None) | ||||||||||
| Target values or tuple thereof, e.g. ``(y_val0, y_val1)``, to use as validation sets. | ||||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
|
|
@@ -485,6 +493,42 @@ def _extract_evaluation_meta_data( | |||||||||
| raise TypeError(f"{name} should be dict or list") | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def _validate_eval_set_Xy( | ||||||||||
| *, | ||||||||||
| eval_set: Optional[List[_LGBM_ScikitValidSet]], | ||||||||||
| eval_X: Optional[Union[_LGBM_ScikitMatrixLike, Tuple[_LGBM_ScikitMatrixLike]]], | ||||||||||
| eval_y: Optional[Union[_LGBM_LabelType, Tuple[_LGBM_LabelType]]], | ||||||||||
| ) -> Optional[List[_LGBM_ScikitValidSet]]: | ||||||||||
| """Validate eval args. | ||||||||||
|
|
||||||||||
| Returns | ||||||||||
| ------- | ||||||||||
| eval_set | ||||||||||
| """ | ||||||||||
| if eval_set is not None: | ||||||||||
| msg = "The argument 'eval_set' is deprecated, use 'eval_X' and 'eval_y' instead." | ||||||||||
| warnings.warn(msg, category=LGBMDeprecationWarning, stacklevel=2) | ||||||||||
| if eval_X is not None or eval_y is not None: | ||||||||||
| raise ValueError("Specify either 'eval_set' or 'eval_X' and 'eval_y', but not both.") | ||||||||||
| if isinstance(eval_set, tuple): | ||||||||||
| return [eval_set] | ||||||||||
| else: | ||||||||||
| return eval_set | ||||||||||
|
Comment on lines
+513
to
+516
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (changed by me in d7e0fff) Although providing something like LightGBM/python-package/lightgbm/sklearn.py Line 914 in 544d439
... it has been supported in LightGBM/python-package/lightgbm/sklearn.py Lines 990 to 992 in 544d439
Adding this line preserves that behavior. Changing the existing behavior when |
||||||||||
| if (eval_X is None) != (eval_y is None): | ||||||||||
| raise ValueError("You must specify eval_X and eval_y, not just one of them.") | ||||||||||
| if eval_set is None and eval_X is not None: | ||||||||||
| if isinstance(eval_X, tuple) != isinstance(eval_y, tuple): | ||||||||||
| raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.") | ||||||||||
| if isinstance(eval_X, tuple) and isinstance(eval_y, tuple): | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (changed by me in d7e0fff) This |
||||||||||
| if len(eval_X) != len(eval_y): | ||||||||||
| raise ValueError("If eval_X is a tuple, y_val must be a tuple of same length, and vice versa.") | ||||||||||
| if isinstance(eval_X, tuple) and isinstance(eval_y, tuple): | ||||||||||
| eval_set = list(zip(eval_X, eval_y)) | ||||||||||
| else: | ||||||||||
| eval_set = [(eval_X, eval_y)] | ||||||||||
| return eval_set | ||||||||||
jameslamb marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
|
|
||||||||||
|
|
||||||||||
| class LGBMModel(_LGBMModelBase): | ||||||||||
| """Implementation of the scikit-learn API for LightGBM.""" | ||||||||||
|
|
||||||||||
|
|
@@ -932,6 +976,9 @@ def fit( | |||||||||
| categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto", | ||||||||||
| callbacks: Optional[List[Callable]] = None, | ||||||||||
| init_model: Optional[Union[str, Path, Booster, "LGBMModel"]] = None, | ||||||||||
| *, | ||||||||||
| eval_X: Optional[Union[_LGBM_ScikitMatrixLike, Tuple[_LGBM_ScikitMatrixLike]]] = None, | ||||||||||
| eval_y: Optional[Union[_LGBM_LabelType, Tuple[_LGBM_LabelType]]] = None, | ||||||||||
| ) -> "LGBMModel": | ||||||||||
| """Docstring is set after definition, using a template.""" | ||||||||||
| params = self._process_params(stage="fit") | ||||||||||
|
|
@@ -1000,9 +1047,15 @@ def fit( | |||||||||
| ) | ||||||||||
|
|
||||||||||
| valid_sets: List[Dataset] = [] | ||||||||||
| eval_set = _validate_eval_set_Xy(eval_set=eval_set, eval_X=eval_X, eval_y=eval_y) | ||||||||||
| if eval_set is not None: | ||||||||||
| if isinstance(eval_set, tuple): | ||||||||||
| eval_set = [eval_set] | ||||||||||
| # check eval_group (only relevant for ranking tasks) | ||||||||||
| if eval_group is not None: | ||||||||||
| if len(eval_group) != len(eval_set): | ||||||||||
| raise ValueError( | ||||||||||
| f"Length of eval_group ({len(eval_group)}) not equal to length of eval_set ({len(eval_set)})" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| for i, valid_data in enumerate(eval_set): | ||||||||||
| # reduce cost for prediction training data | ||||||||||
| if valid_data[0] is X and valid_data[1] is y: | ||||||||||
|
|
@@ -1406,6 +1459,9 @@ def fit( # type: ignore[override] | |||||||||
| categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto", | ||||||||||
| callbacks: Optional[List[Callable]] = None, | ||||||||||
| init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None, | ||||||||||
| *, | ||||||||||
| eval_X: Optional[Union[_LGBM_ScikitMatrixLike, Tuple[_LGBM_ScikitMatrixLike]]] = None, | ||||||||||
| eval_y: Optional[Union[_LGBM_LabelType, Tuple[_LGBM_LabelType]]] = None, | ||||||||||
| ) -> "LGBMRegressor": | ||||||||||
| """Docstring is inherited from the LGBMModel.""" | ||||||||||
| super().fit( | ||||||||||
|
|
@@ -1414,6 +1470,8 @@ def fit( # type: ignore[override] | |||||||||
| sample_weight=sample_weight, | ||||||||||
| init_score=init_score, | ||||||||||
| eval_set=eval_set, | ||||||||||
| eval_X=eval_X, | ||||||||||
| eval_y=eval_y, | ||||||||||
| eval_names=eval_names, | ||||||||||
| eval_sample_weight=eval_sample_weight, | ||||||||||
| eval_init_score=eval_init_score, | ||||||||||
|
|
@@ -1521,6 +1579,9 @@ def fit( # type: ignore[override] | |||||||||
| categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto", | ||||||||||
| callbacks: Optional[List[Callable]] = None, | ||||||||||
| init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None, | ||||||||||
| *, | ||||||||||
| eval_X: Optional[Union[_LGBM_ScikitMatrixLike, Tuple[_LGBM_ScikitMatrixLike]]] = None, | ||||||||||
| eval_y: Optional[Union[_LGBM_LabelType, Tuple[_LGBM_LabelType]]] = None, | ||||||||||
| ) -> "LGBMClassifier": | ||||||||||
| """Docstring is inherited from the LGBMModel.""" | ||||||||||
| _LGBMAssertAllFinite(y) | ||||||||||
|
|
@@ -1578,6 +1639,8 @@ def fit( # type: ignore[override] | |||||||||
| init_score=init_score, | ||||||||||
| eval_set=valid_sets, | ||||||||||
| eval_names=eval_names, | ||||||||||
| eval_X=eval_X, | ||||||||||
| eval_y=eval_y, | ||||||||||
| eval_sample_weight=eval_sample_weight, | ||||||||||
| eval_class_weight=eval_class_weight, | ||||||||||
| eval_init_score=eval_init_score, | ||||||||||
|
|
@@ -1773,27 +1836,17 @@ def fit( # type: ignore[override] | |||||||||
| categorical_feature: _LGBM_CategoricalFeatureConfiguration = "auto", | ||||||||||
| callbacks: Optional[List[Callable]] = None, | ||||||||||
| init_model: Optional[Union[str, Path, Booster, LGBMModel]] = None, | ||||||||||
| *, | ||||||||||
| eval_X: Optional[Union[_LGBM_ScikitMatrixLike, Tuple[_LGBM_ScikitMatrixLike]]] = None, | ||||||||||
| eval_y: Optional[Union[_LGBM_LabelType, Tuple[_LGBM_LabelType]]] = None, | ||||||||||
| ) -> "LGBMRanker": | ||||||||||
| """Docstring is inherited from the LGBMModel.""" | ||||||||||
| # check group data | ||||||||||
| if group is None: | ||||||||||
| raise ValueError("Should set group for ranking task") | ||||||||||
|
|
||||||||||
| if eval_set is not None: | ||||||||||
| if eval_group is None: | ||||||||||
| raise ValueError("Eval_group cannot be None when eval_set is not None") | ||||||||||
| if len(eval_group) != len(eval_set): | ||||||||||
| raise ValueError("Length of eval_group should be equal to eval_set") | ||||||||||
| if ( | ||||||||||
| isinstance(eval_group, dict) | ||||||||||
| and any(i not in eval_group or eval_group[i] is None for i in range(len(eval_group))) | ||||||||||
| or isinstance(eval_group, list) | ||||||||||
| and any(group is None for group in eval_group) | ||||||||||
| ): | ||||||||||
| raise ValueError( | ||||||||||
| "Should set group for all eval datasets for ranking task; " | ||||||||||
| "if you use dict, the index should start from 0" | ||||||||||
| ) | ||||||||||
|
Comment on lines
-1782
to
-1796
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (changed by me in d7e0fff)
|
||||||||||
| if eval_group is None and (eval_set is not None or eval_X is not None or eval_y is not None): | ||||||||||
| raise ValueError("eval_group cannot be None if any of eval_set, eval_X, or eval_y are provided") | ||||||||||
|
|
||||||||||
| self._eval_at = eval_at | ||||||||||
| super().fit( | ||||||||||
|
|
@@ -1804,6 +1857,8 @@ def fit( # type: ignore[override] | |||||||||
| group=group, | ||||||||||
| eval_set=eval_set, | ||||||||||
| eval_names=eval_names, | ||||||||||
| eval_X=eval_X, | ||||||||||
| eval_y=eval_y, | ||||||||||
| eval_sample_weight=eval_sample_weight, | ||||||||||
| eval_init_score=eval_init_score, | ||||||||||
| eval_group=eval_group, | ||||||||||
|
|
||||||||||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(changed by me in d7e0fff)
If we're going to consider
eval_setdeprecated andeval_{X,y}the new recommended pattern, l think we should nudge users towards that by updating all of the documentation. I've done that here (examples/is the only place with such code).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think
eval_X=X_test, eval_y=y_testwithout wrapping into a tuple is fine, too.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For sure, either will work. I have a very weak preference for the tuple form in these docs, to make it a little clearer that providing multiple validation sets is supported.