Description
Is your feature request related to a problem? Please describe.
One thing I found interesting is that you did not implement a fit method for the tree models. Still, I am trying to implement the fit method on my own, so I inherit the XGBoostClassifier model to implement it myself:
class XGBARTClassifier(XGBoostClassifier):
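Roughly like this (just a sketch of my intent; XGBARTClassifier and its fit body are my own code, not part of ART):

from art.estimators.classification import XGBoostClassifier

class XGBARTClassifier(XGBoostClassifier):
    """My own subclass that adds training on top of ART's XGBoost wrapper."""

    def fit(self, x, y, **kwargs):
        # Delegate training to the wrapped xgboost.XGBClassifier
        self._model.fit(x, y, **kwargs)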
However, one thing I found that might be problematic is that if I pass in an unfitted xgboost.XGBClassifier, then the wrapper tells me:
AttributeError: 'XGBClassifier' object has no attribute 'n_classes_'
I found that the problem is caused by this method:
def _get_nb_classes(self, nb_classes: Optional[int]) -> int:
    """
    Return the number of output classes.

    :return: Number of classes in the data.
    """
    from xgboost import Booster, XGBClassifier

    if isinstance(self._model, Booster):
        try:
            return int(len(self._model.get_dump(dump_format="json")) / self._model.n_estimators)  # type: ignore
        except AttributeError:
            if nb_classes is not None:
                return nb_classes
            raise NotImplementedError(
                "Number of classes cannot be determined automatically. "
                + "Please manually set argument nb_classes in XGBoostClassifier."
            ) from AttributeError

    if isinstance(self._model, XGBClassifier):
        return self._model.n_classes_

    return -1
For the XGBClassifier I passed in, if the model is unfitted it does not have the attribute n_classes_, so I cannot pass an unfitted model through to proceed.
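To make the failure concrete, this is roughly how I hit it (a minimal sketch; the exact constructor arguments of ART's XGBoostClassifier may differ):

from xgboost import XGBClassifier
from art.estimators.classification import XGBoostClassifier

model = XGBClassifier()  # unfitted, so n_classes_ does not exist yet
# Wrapping the unfitted model fails while the wrapper tries to infer the class count:
# AttributeError: 'XGBClassifier' object has no attribute 'n_classes_'
classifier = XGBoostClassifier(model=model)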
Describe the solution you'd like
This issue might be minor, but I think it would be better to check whether the XGBClassifier is fitted before returning n_classes_, so that an unfitted model can be passed in.
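For example, the XGBClassifier branch of _get_nb_classes could fall back to the nb_classes argument the same way the Booster branch already does (just a sketch of the idea, not a tested patch):

if isinstance(self._model, XGBClassifier):
    try:
        return self._model.n_classes_  # only set once the model has been fitted
    except AttributeError:
        if nb_classes is not None:
            return nb_classes
        raise NotImplementedError(
            "Number of classes cannot be determined automatically. "
            + "Please manually set argument nb_classes in XGBoostClassifier."
        ) from AttributeError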
Thank you