Skip to content

Commit c6d90bc

Browse files
[python-package] support sub-classing scikit-learn estimators (#6783)
Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
1 parent 768f642 commit c6d90bc

File tree

5 files changed

+467
-11
lines changed

5 files changed

+467
-11
lines changed

docs/FAQ.rst

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,42 @@ We strongly recommend installation from the ``conda-forge`` channel and not from
377377
For some specific examples, see `this comment <https://github.com/microsoft/LightGBM/issues/4948#issuecomment-1013766397>`__.
378378

379379
In addition, as of ``lightgbm==4.4.0``, the ``conda-forge`` package automatically supports CUDA-based GPU acceleration.
380+
381+
5. How do I subclass ``scikit-learn`` estimators?
382+
-------------------------------------------------
383+
384+
For ``lightgbm <= 4.5.0``, copy all of the constructor arguments from the corresponding
385+
``lightgbm`` class into the constructor of your custom estimator.
386+
387+
For later versions, just ensure that the constructor of your custom estimator calls ``super().__init__()``.
388+
389+
Consider the example below, which implements a regressor that allows creation of truncated predictions.
390+
This pattern will work with ``lightgbm > 4.5.0``.
391+
392+
.. code-block:: python
393+
394+
import numpy as np
395+
from lightgbm import LGBMRegressor
396+
from sklearn.datasets import make_regression
397+
398+
class TruncatedRegressor(LGBMRegressor):
399+
400+
def __init__(self, **kwargs):
401+
super().__init__(**kwargs)
402+
403+
def predict(self, X, max_score: float = np.inf):
404+
preds = super().predict(X)
405+
np.clip(preds, a_min=None, a_max=max_score, out=preds)
406+
return preds
407+
408+
X, y = make_regression(n_samples=1_000, n_features=4)
409+
410+
reg_trunc = TruncatedRegressor().fit(X, y)
411+
412+
preds = reg_trunc.predict(X)
413+
print(f"mean: {preds.mean():.2f}, max: {preds.max():.2f}")
414+
# mean: -6.81, max: 345.10
415+
416+
preds_trunc = reg_trunc.predict(X, max_score=preds.mean())
417+
print(f"mean: {preds_trunc.mean():.2f}, max: {preds_trunc.max():.2f}")
418+
# mean: -56.50, max: -6.81

python-package/lightgbm/dask.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,7 @@ class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
11151115

11161116
def __init__(
11171117
self,
1118+
*,
11181119
boosting_type: str = "gbdt",
11191120
num_leaves: int = 31,
11201121
max_depth: int = -1,
@@ -1318,6 +1319,7 @@ class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
13181319

13191320
def __init__(
13201321
self,
1322+
*,
13211323
boosting_type: str = "gbdt",
13221324
num_leaves: int = 31,
13231325
max_depth: int = -1,
@@ -1485,6 +1487,7 @@ class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
14851487

14861488
def __init__(
14871489
self,
1490+
*,
14881491
boosting_type: str = "gbdt",
14891492
num_leaves: int = 31,
14901493
max_depth: int = -1,

python-package/lightgbm/sklearn.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ class LGBMModel(_LGBMModelBase):
488488

489489
def __init__(
490490
self,
491+
*,
491492
boosting_type: str = "gbdt",
492493
num_leaves: int = 31,
493494
max_depth: int = -1,
@@ -745,7 +746,35 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
745746
params : dict
746747
Parameter names mapped to their values.
747748
"""
749+
# Based on: https://github.com/dmlc/xgboost/blob/bd92b1c9c0db3e75ec3dfa513e1435d518bb535d/python-package/xgboost/sklearn.py#L941
750+
# which was based on: https://stackoverflow.com/questions/59248211
751+
#
752+
# `get_params()` flows like this:
753+
#
754+
# 0. Get parameters in subclass (self.__class__) first, by using inspect.
755+
# 1. Get parameters in all parent classes (especially `LGBMModel`).
756+
# 2. Get whatever was passed via `**kwargs`.
757+
# 3. Merge them.
758+
#
759+
# This needs to accommodate being called recursively in the following
760+
# inheritance graphs (and similar for classification and ranking):
761+
#
762+
# DaskLGBMRegressor -> LGBMRegressor -> LGBMModel -> BaseEstimator
763+
# (custom subclass) -> LGBMRegressor -> LGBMModel -> BaseEstimator
764+
# LGBMRegressor -> LGBMModel -> BaseEstimator
765+
# (custom subclass) -> LGBMModel -> BaseEstimator
766+
# LGBMModel -> BaseEstimator
767+
#
748768
params = super().get_params(deep=deep)
769+
cp = copy.copy(self)
770+
# If the immediate parent defines get_params(), use that.
771+
if callable(getattr(cp.__class__.__bases__[0], "get_params", None)):
772+
cp.__class__ = cp.__class__.__bases__[0]
773+
# Otherwise, skip it and assume the next class will have it.
774+
# This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
775+
else:
776+
cp.__class__ = cp.__class__.__bases__[1]
777+
params.update(cp.__class__.get_params(cp, deep))
749778
params.update(self._other_params)
750779
return params
751780

@@ -1285,6 +1314,57 @@ def feature_names_in_(self) -> None:
12851314
class LGBMRegressor(_LGBMRegressorBase, LGBMModel):
12861315
"""LightGBM regressor."""
12871316

1317+
# NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
1318+
# docs, help(), and tab completion.
1319+
def __init__(
1320+
self,
1321+
*,
1322+
boosting_type: str = "gbdt",
1323+
num_leaves: int = 31,
1324+
max_depth: int = -1,
1325+
learning_rate: float = 0.1,
1326+
n_estimators: int = 100,
1327+
subsample_for_bin: int = 200000,
1328+
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
1329+
class_weight: Optional[Union[Dict, str]] = None,
1330+
min_split_gain: float = 0.0,
1331+
min_child_weight: float = 1e-3,
1332+
min_child_samples: int = 20,
1333+
subsample: float = 1.0,
1334+
subsample_freq: int = 0,
1335+
colsample_bytree: float = 1.0,
1336+
reg_alpha: float = 0.0,
1337+
reg_lambda: float = 0.0,
1338+
random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
1339+
n_jobs: Optional[int] = None,
1340+
importance_type: str = "split",
1341+
**kwargs: Any,
1342+
) -> None:
1343+
super().__init__(
1344+
boosting_type=boosting_type,
1345+
num_leaves=num_leaves,
1346+
max_depth=max_depth,
1347+
learning_rate=learning_rate,
1348+
n_estimators=n_estimators,
1349+
subsample_for_bin=subsample_for_bin,
1350+
objective=objective,
1351+
class_weight=class_weight,
1352+
min_split_gain=min_split_gain,
1353+
min_child_weight=min_child_weight,
1354+
min_child_samples=min_child_samples,
1355+
subsample=subsample,
1356+
subsample_freq=subsample_freq,
1357+
colsample_bytree=colsample_bytree,
1358+
reg_alpha=reg_alpha,
1359+
reg_lambda=reg_lambda,
1360+
random_state=random_state,
1361+
n_jobs=n_jobs,
1362+
importance_type=importance_type,
1363+
**kwargs,
1364+
)
1365+
1366+
__init__.__doc__ = LGBMModel.__init__.__doc__
1367+
12881368
def _more_tags(self) -> Dict[str, Any]:
12891369
# handle the case where RegressorMixin possibly provides _more_tags()
12901370
if callable(getattr(_LGBMRegressorBase, "_more_tags", None)):
@@ -1344,6 +1424,57 @@ def fit( # type: ignore[override]
13441424
class LGBMClassifier(_LGBMClassifierBase, LGBMModel):
13451425
"""LightGBM classifier."""
13461426

1427+
# NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
1428+
# docs, help(), and tab completion.
1429+
def __init__(
1430+
self,
1431+
*,
1432+
boosting_type: str = "gbdt",
1433+
num_leaves: int = 31,
1434+
max_depth: int = -1,
1435+
learning_rate: float = 0.1,
1436+
n_estimators: int = 100,
1437+
subsample_for_bin: int = 200000,
1438+
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
1439+
class_weight: Optional[Union[Dict, str]] = None,
1440+
min_split_gain: float = 0.0,
1441+
min_child_weight: float = 1e-3,
1442+
min_child_samples: int = 20,
1443+
subsample: float = 1.0,
1444+
subsample_freq: int = 0,
1445+
colsample_bytree: float = 1.0,
1446+
reg_alpha: float = 0.0,
1447+
reg_lambda: float = 0.0,
1448+
random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
1449+
n_jobs: Optional[int] = None,
1450+
importance_type: str = "split",
1451+
**kwargs: Any,
1452+
) -> None:
1453+
super().__init__(
1454+
boosting_type=boosting_type,
1455+
num_leaves=num_leaves,
1456+
max_depth=max_depth,
1457+
learning_rate=learning_rate,
1458+
n_estimators=n_estimators,
1459+
subsample_for_bin=subsample_for_bin,
1460+
objective=objective,
1461+
class_weight=class_weight,
1462+
min_split_gain=min_split_gain,
1463+
min_child_weight=min_child_weight,
1464+
min_child_samples=min_child_samples,
1465+
subsample=subsample,
1466+
subsample_freq=subsample_freq,
1467+
colsample_bytree=colsample_bytree,
1468+
reg_alpha=reg_alpha,
1469+
reg_lambda=reg_lambda,
1470+
random_state=random_state,
1471+
n_jobs=n_jobs,
1472+
importance_type=importance_type,
1473+
**kwargs,
1474+
)
1475+
1476+
__init__.__doc__ = LGBMModel.__init__.__doc__
1477+
13471478
def _more_tags(self) -> Dict[str, Any]:
13481479
# handle the case where ClassifierMixin possibly provides _more_tags()
13491480
if callable(getattr(_LGBMClassifierBase, "_more_tags", None)):
@@ -1554,6 +1685,57 @@ class LGBMRanker(LGBMModel):
15541685
Please use this class mainly for training and applying ranking models in common sklearnish way.
15551686
"""
15561687

1688+
# NOTE: all args from LGBMModel.__init__() are intentionally repeated here for
1689+
# docs, help(), and tab completion.
1690+
def __init__(
1691+
self,
1692+
*,
1693+
boosting_type: str = "gbdt",
1694+
num_leaves: int = 31,
1695+
max_depth: int = -1,
1696+
learning_rate: float = 0.1,
1697+
n_estimators: int = 100,
1698+
subsample_for_bin: int = 200000,
1699+
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
1700+
class_weight: Optional[Union[Dict, str]] = None,
1701+
min_split_gain: float = 0.0,
1702+
min_child_weight: float = 1e-3,
1703+
min_child_samples: int = 20,
1704+
subsample: float = 1.0,
1705+
subsample_freq: int = 0,
1706+
colsample_bytree: float = 1.0,
1707+
reg_alpha: float = 0.0,
1708+
reg_lambda: float = 0.0,
1709+
random_state: Optional[Union[int, np.random.RandomState, np.random.Generator]] = None,
1710+
n_jobs: Optional[int] = None,
1711+
importance_type: str = "split",
1712+
**kwargs: Any,
1713+
) -> None:
1714+
super().__init__(
1715+
boosting_type=boosting_type,
1716+
num_leaves=num_leaves,
1717+
max_depth=max_depth,
1718+
learning_rate=learning_rate,
1719+
n_estimators=n_estimators,
1720+
subsample_for_bin=subsample_for_bin,
1721+
objective=objective,
1722+
class_weight=class_weight,
1723+
min_split_gain=min_split_gain,
1724+
min_child_weight=min_child_weight,
1725+
min_child_samples=min_child_samples,
1726+
subsample=subsample,
1727+
subsample_freq=subsample_freq,
1728+
colsample_bytree=colsample_bytree,
1729+
reg_alpha=reg_alpha,
1730+
reg_lambda=reg_lambda,
1731+
random_state=random_state,
1732+
n_jobs=n_jobs,
1733+
importance_type=importance_type,
1734+
**kwargs,
1735+
)
1736+
1737+
__init__.__doc__ = LGBMModel.__init__.__doc__
1738+
15571739
def fit( # type: ignore[override]
15581740
self,
15591741
X: _LGBM_ScikitMatrixLike,

tests/python_package_test/test_dask.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,26 +1373,42 @@ def test_machines_should_be_used_if_provided(task, cluster):
13731373

13741374

13751375
@pytest.mark.parametrize(
1376-
"classes",
1376+
"dask_est,sklearn_est",
13771377
[
13781378
(lgb.DaskLGBMClassifier, lgb.LGBMClassifier),
13791379
(lgb.DaskLGBMRegressor, lgb.LGBMRegressor),
13801380
(lgb.DaskLGBMRanker, lgb.LGBMRanker),
13811381
],
13821382
)
1383-
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(classes):
1384-
dask_spec = inspect.getfullargspec(classes[0])
1385-
sklearn_spec = inspect.getfullargspec(classes[1])
1383+
def test_dask_classes_and_sklearn_equivalents_have_identical_constructors_except_client_arg(dask_est, sklearn_est):
1384+
dask_spec = inspect.getfullargspec(dask_est)
1385+
sklearn_spec = inspect.getfullargspec(sklearn_est)
1386+
1387+
# should not allow for any varargs
13861388
assert dask_spec.varargs == sklearn_spec.varargs
1389+
assert dask_spec.varargs is None
1390+
1391+
# the only varkw should be **kwargs,
1392+
# for pass-through to parent classes' __init__()
13871393
assert dask_spec.varkw == sklearn_spec.varkw
1388-
assert dask_spec.kwonlyargs == sklearn_spec.kwonlyargs
1389-
assert dask_spec.kwonlydefaults == sklearn_spec.kwonlydefaults
1394+
assert dask_spec.varkw == "kwargs"
13901395

13911396
# "client" should be the only difference, and the final argument
1392-
assert dask_spec.args[:-1] == sklearn_spec.args
1393-
assert dask_spec.defaults[:-1] == sklearn_spec.defaults
1394-
assert dask_spec.args[-1] == "client"
1395-
assert dask_spec.defaults[-1] is None
1397+
assert dask_spec.kwonlyargs == [*sklearn_spec.kwonlyargs, "client"]
1398+
1399+
# default values for all constructor arguments should be identical
1400+
#
1401+
# NOTE: if LGBMClassifier / LGBMRanker / LGBMRegressor ever override
1402+
# any of LGBMModel's constructor arguments, this will need to be updated
1403+
assert dask_spec.kwonlydefaults == {**sklearn_spec.kwonlydefaults, "client": None}
1404+
1405+
# only positional argument should be 'self'
1406+
assert dask_spec.args == sklearn_spec.args
1407+
assert dask_spec.args == ["self"]
1408+
assert dask_spec.defaults is None
1409+
1410+
# get_params() should be identical, except for "client"
1411+
assert dask_est().get_params() == {**sklearn_est().get_params(), "client": None}
13961412

13971413

13981414
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)