Skip to content

Commit 6c5f787

Browse files
authored
Hack EEG to be scikit-learn compatible estimator
1 parent 7b2d05b commit 6c5f787

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

imblearn/ensemble/easy_ensemble_generalization.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,22 @@
66

77
import numpy as np
88

9-
from sklearn.base import ClassifierMixin, clone
9+
from abc import ABCMeta, abstractmethod
10+
from sklearn.base import ClassifierMixin, MetaEstimatorMixin, clone
1011
from sklearn.ensemble import VotingClassifier
1112
from sklearn.ensemble.base import BaseEnsemble, _set_random_states
13+
from sklearn.externals import six
1214
from sklearn.tree import DecisionTreeClassifier
1315
from sklearn.utils import check_random_state
1416
from sklearn.utils.validation import check_is_fitted
1517

18+
1619
from ..pipeline import Pipeline
1720
from ..under_sampling import RandomUnderSampler
1821

19-
MAX_INT = np.iinfo(np.int32).max
2022

23+
class _EasyEnsembleGeneralization(ClassifierMixin, BaseEnsemble):
2124

22-
class EasyEnsembleGeneralization(BaseEnsemble, ClassifierMixin):
2325
"""This classifier generalize the Easy Ensemble algorithm for imbalanced
2426
datasets.
2527
@@ -92,7 +94,7 @@ def __init__(self,
9294

9395
def _validate_estimator(self):
9496
"""Check the estimator and set the base_estimator_ attribute."""
95-
super(EasyEnsembleGeneralization, self)._validate_estimator(
97+
super(_EasyEnsembleGeneralization, self)._validate_estimator(
9698
default=DecisionTreeClassifier())
9799

98100
def _validate_sampler(self):
@@ -197,3 +199,11 @@ def predict_proba(self, X):
197199
"""
198200
check_is_fitted(self, "_voting")
199201
return self._voting.predict_proba(X)
202+
203+
204+
# XXX make EasyEnsembleGeneralization to pass sklearn compatibility test
205+
bases = _EasyEnsembleGeneralization.__mro__
206+
bases = tuple(x for x in bases if x != MetaEstimatorMixin)
207+
print(bases)
208+
EasyEnsembleGeneralization = type(
209+
'EasyEnsembleGeneralization', bases, dict(_EasyEnsembleGeneralization.__dict__))

0 commit comments

Comments
 (0)