29
29
30
30
# Do not use class names from scikit-learn directly. Re-define the classes in
31
31
# .compat to guarantee the behavior without scikit-learn
32
- from .compat import SKLEARN_INSTALLED , XGBClassifierBase , XGBModelBase , XGBRegressorBase
32
+ from .compat import (
33
+ SKLEARN_INSTALLED ,
34
+ XGBClassifierBase ,
35
+ XGBModelBase ,
36
+ XGBRegressorBase ,
37
+ _sklearn_Tags ,
38
+ _sklearn_version ,
39
+ )
33
40
from .config import config_context
34
41
from .core import (
35
42
Booster ,
45
52
from .training import train
46
53
47
54
48
- class XGBRankerMixIn : # pylint: disable=too-few-public-methods
55
+ class XGBRankerMixIn :
49
56
"""MixIn for ranking, defines the _estimator_type usually defined in scikit-learn
50
57
base classes.
51
58
@@ -69,7 +76,7 @@ def _can_use_qdm(tree_method: Optional[str]) -> bool:
69
76
return tree_method in ("hist" , "gpu_hist" , None , "auto" )
70
77
71
78
72
- class _SklObjWProto (Protocol ): # pylint: disable=too-few-public-methods
79
+ class _SklObjWProto (Protocol ):
73
80
def __call__ (
74
81
self ,
75
82
y_true : ArrayLike ,
@@ -782,11 +789,52 @@ def __init__(
782
789
783
790
def _more_tags(self) -> Dict[str, bool]:
    """Tags used for scikit-learn data validation.

    Returns
    -------
    tags :
        Always contains ``allow_nan``, ``no_validation`` and ``sparse`` set to
        ``True`` and ``categorical`` set to ``self.enable_categorical``.
        ``non_deterministic`` is added (as ``True``) only when ``self.kwargs``
        exists and its ``updater`` entry is ``"shotgun"``.

    """
    tags = {"allow_nan": True, "no_validation": True, "sparse": True}
    # `kwargs` may not exist on the instance, so probe for it before reading.
    if hasattr(self, "kwargs") and self.kwargs.get("updater") == "shotgun":
        tags["non_deterministic"] = True

    tags["categorical"] = self.enable_categorical
    return tags
798
+
799
@staticmethod
def _update_sklearn_tags_from_dict(
    *,
    tags: _sklearn_Tags,
    tags_dict: Dict[str, bool],
) -> _sklearn_Tags:
    """Update ``sklearn.utils.Tags`` inherited from ``scikit-learn`` base classes.

    ``scikit-learn`` 1.6 introduced a dataclass-based interface for estimator tags.
    ref: https://github.com/scikit-learn/scikit-learn/pull/29677

    This method copies the values from ``tags_dict`` (the dictionary form of the
    tags, as produced by ``_more_tags()``) onto that instance.

    Parameters
    ----------
    tags :
        Tags object to update. It is modified in place.
    tags_dict :
        Must contain ``no_validation``, ``allow_nan``, ``sparse`` and
        ``categorical``. ``non_deterministic`` is optional and defaults to
        ``False`` when absent.

    Returns
    -------
    tags :
        The same object that was passed in, after modification.

    Raises
    ------
    KeyError
        If one of the required keys is missing from ``tags_dict``.

    """
    tags.non_deterministic = tags_dict.get("non_deterministic", False)
    tags.no_validation = tags_dict["no_validation"]
    tags.input_tags.allow_nan = tags_dict["allow_nan"]
    tags.input_tags.sparse = tags_dict["sparse"]
    tags.input_tags.categorical = tags_dict["categorical"]
    return tags
789
820
821
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the estimator tags in the dataclass form used by scikit-learn>=1.6.

    The tags supplied by the parent class are taken as the starting point and
    then overwritten with the values from ``self._more_tags()``.

    Raises
    ------
    AttributeError
        If ``XGBModelBase`` does not define ``__sklearn_tags__``, which is the
        case for scikit-learn<1.6.

    """
    # XGBModelBase.__sklearn_tags__() cannot be called unconditionally,
    # because that method isn't defined for scikit-learn<1.6
    if not hasattr(XGBModelBase, "__sklearn_tags__"):
        err_msg = (
            "__sklearn_tags__() should not be called when using scikit-learn<1.6. "
            f"Detected version: {_sklearn_version}"
        )
        raise AttributeError(err_msg)

    # take whatever tags are provided by BaseEstimator, then modify
    # them with XGBoost-specific values
    return self._update_sklearn_tags_from_dict(
        tags=super().__sklearn_tags__(),  # pylint: disable=no-member
        tags_dict=self._more_tags(),
    )
837
+
790
838
def __sklearn_is_fitted__(self) -> bool:
    """Return whether a ``_Booster`` attribute is present on this estimator."""
    return hasattr(self, "_Booster")
792
840
@@ -841,13 +889,27 @@ def get_params(self, deep: bool = True) -> Dict[str, Any]:
841
889
"""Get parameters."""
842
890
# Based on: https://stackoverflow.com/questions/59248211
843
891
# The basic flow in `get_params` is:
844
- # 0. Return parameters in subclass first, by using inspect.
845
- # 1. Return parameters in `XGBModel` (the base class ).
892
+ # 0. Return parameters in subclass (self.__class__) first, by using inspect.
893
+ # 1. Return parameters in all parent classes (especially `XGBModel` ).
846
894
# 2. Return whatever in `**kwargs`.
847
895
# 3. Merge them.
896
+ #
897
+ # This needs to accommodate being called recursively in the following
898
+ # inheritance graphs (and similar for classification and ranking):
899
+ #
900
+ # XGBRFRegressor -> XGBRegressor -> XGBModel -> BaseEstimator
901
+ # XGBRegressor -> XGBModel -> BaseEstimator
902
+ # XGBModel -> BaseEstimator
903
+ #
848
904
params = super ().get_params (deep )
849
905
cp = copy .copy (self )
850
- cp .__class__ = cp .__class__ .__bases__ [0 ]
906
+ # If the immediate parent defines get_params(), use that.
907
+ if callable (getattr (cp .__class__ .__bases__ [0 ], "get_params" , None )):
908
+ cp .__class__ = cp .__class__ .__bases__ [0 ]
909
+ # Otherwise, skip it and assume the next class will have it.
910
+ # This is here primarily for cases where the first class in MRO is a scikit-learn mixin.
911
+ else :
912
+ cp .__class__ = cp .__class__ .__bases__ [1 ]
851
913
params .update (cp .__class__ .get_params (cp , deep ))
852
914
# if kwargs is a dict, update params accordingly
853
915
if hasattr (self , "kwargs" ) and isinstance (self .kwargs , dict ):
@@ -1431,7 +1493,7 @@ def _cls_predict_proba(n_classes: int, prediction: PredtT, vstack: Callable) ->
1431
1493
Number of boosting rounds.
1432
1494
""" ,
1433
1495
)
1434
- class XGBClassifier (XGBModel , XGBClassifierBase ):
1496
+ class XGBClassifier (XGBClassifierBase , XGBModel ):
1435
1497
# pylint: disable=missing-docstring,invalid-name,too-many-instance-attributes
1436
1498
@_deprecate_positional_args
1437
1499
def __init__ (
@@ -1447,6 +1509,12 @@ def _more_tags(self) -> Dict[str, bool]:
1447
1509
tags ["multilabel" ] = True
1448
1510
return tags
1449
1511
1512
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the parent class's tags with the classifier's multi-label flag set.

    ``classifier_tags.multi_label`` is taken from the ``multilabel`` entry of
    ``self._more_tags()``.

    """
    tags = super().__sklearn_tags__()
    tags_dict = self._more_tags()
    tags.classifier_tags.multi_label = tags_dict["multilabel"]
    return tags
1517
+
1450
1518
@_deprecate_positional_args
1451
1519
def fit (
1452
1520
self ,
@@ -1717,7 +1785,7 @@ def fit(
1717
1785
"Implementation of the scikit-learn API for XGBoost regression." ,
1718
1786
["estimators" , "model" , "objective" ],
1719
1787
)
1720
- class XGBRegressor (XGBModel , XGBRegressorBase ):
1788
+ class XGBRegressor (XGBRegressorBase , XGBModel ):
1721
1789
# pylint: disable=missing-docstring
1722
1790
@_deprecate_positional_args
1723
1791
def __init__ (
@@ -1731,6 +1799,13 @@ def _more_tags(self) -> Dict[str, bool]:
1731
1799
tags ["multioutput_only" ] = False
1732
1800
return tags
1733
1801
1802
def __sklearn_tags__(self) -> _sklearn_Tags:
    """Return the parent class's tags with the regressor's target tags set.

    ``target_tags.multi_output`` is taken from the ``multioutput`` entry of
    ``self._more_tags()``. ``target_tags.single_output`` is the negation of the
    ``multioutput_only`` entry.

    """
    tags = super().__sklearn_tags__()
    tags_dict = self._more_tags()
    tags.target_tags.multi_output = tags_dict["multioutput"]
    # The dict form states "multi-output only"; the dataclass form states the
    # opposite ("single output supported"), hence the negation.
    tags.target_tags.single_output = not tags_dict["multioutput_only"]
    return tags
1808
+
1734
1809
1735
1810
@xgboost_model_doc (
1736
1811
"scikit-learn API for XGBoost random forest regression." ,
@@ -1858,7 +1933,7 @@ def _get_qid(
1858
1933
`qid` can be a special column of input `X` instead of a separated parameter, see
1859
1934
:py:meth:`fit` for more info.""" ,
1860
1935
)
1861
- class XGBRanker (XGBModel , XGBRankerMixIn ):
1936
+ class XGBRanker (XGBRankerMixIn , XGBModel ):
1862
1937
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
1863
1938
@_deprecate_positional_args
1864
1939
def __init__ (self , * , objective : str = "rank:ndcg" , ** kwargs : Any ):
0 commit comments