ART XGBoostClassifier Classifier Parameters Optimization Suggestions #1403

@vincentliuheyang

Is your feature request related to a problem? Please describe.
One thing I found interesting is that you did not implement a fit method for tree models. I am trying to implement the fit method on my own, so I inherit from XGBoostClassifier:

class XGBARTClassifier(XGBoostClassifier): ...

However, one thing that might be problematic is that if I pass in an unfitted xgboost.XGBClassifier, the wrapper raises:

AttributeError: 'XGBClassifier' object has no attribute 'n_classes_'

I found that the problem is caused by the structure of this method:

    def _get_nb_classes(self, nb_classes: Optional[int]) -> int:
        """
        Return the number of output classes.

        :return: Number of classes in the data.
        """
        from xgboost import Booster, XGBClassifier

        if isinstance(self._model, Booster):
            try:
                return int(len(self._model.get_dump(dump_format="json")) / self._model.n_estimators)  # type: ignore
            except AttributeError:
                if nb_classes is not None:
                    return nb_classes
                raise NotImplementedError(
                    "Number of classes cannot be determined automatically. "
                    + "Please manually set argument nb_classes in XGBoostClassifier."
                ) from AttributeError

        if isinstance(self._model, XGBClassifier):
            return self._model.n_classes_

        return -1

If the XGBClassifier I pass in is unfitted, it does not have the attribute n_classes_, so I cannot pass an unfitted model through the wrapper.

Describe the solution you'd like
This issue might be minor, but it would be better if you could check whether the XGBClassifier is fitted before returning n_classes_, so that an unfitted model can be passed in.
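Something along these lines is what I have in mind. This is only a rough sketch, not ART's actual code: `resolve_nb_classes` is a hypothetical standalone helper, and falling back to a user-supplied `nb_classes` is my assumption about how the unfitted case could be handled. It relies only on the fact that `n_classes_` is absent until the classifier has been fitted.

```python
from typing import Any, Optional


def resolve_nb_classes(model: Any, nb_classes: Optional[int] = None) -> int:
    """Hypothetical helper: return the number of classes without failing on an unfitted model."""
    # getattr with a default returns None instead of raising AttributeError
    # when the estimator has not been fitted yet.
    fitted_value = getattr(model, "n_classes_", None)
    if fitted_value is not None:
        return int(fitted_value)

    # Unfitted model: fall back to the value supplied by the caller, if any.
    if nb_classes is not None:
        return nb_classes

    raise ValueError(
        "Model is not fitted and nb_classes was not provided; "
        "the number of classes cannot be determined."
    )
```

With a check like this, an unfitted XGBClassifier could be wrapped as long as nb_classes is given, and a fitted one would behave as it does today.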

Thank you
