|
5 | 5 | # License: MIT
|
6 | 6 |
|
7 | 7 | import numpy as np
|
8 |
| - |
9 | 8 | from sklearn.base import ClassifierMixin, clone
|
10 |
| -from sklearn.ensemble import BaseEnsemble, VotingClassifier |
| 9 | +from sklearn.ensemble import VotingClassifier |
| 10 | +from sklearn.ensemble.base import BaseEnsemble, _set_random_states |
11 | 11 | from sklearn.tree import DecisionTreeClassifier
|
12 | 12 | from sklearn.utils import check_random_state
|
13 | 13 | from sklearn.utils.validation import check_is_fitted
|
14 | 14 |
|
15 | 15 | from ..pipeline import Pipeline
|
16 |
| -from ..under_sampling import RandomUnderSampler as ROS |
17 |
| - |
| 16 | +from ..under_sampling import RandomUnderSampler |
18 | 17 |
|
19 | 18 | MAX_INT = np.iinfo(np.int32).max
|
20 | 19 |
|
@@ -101,7 +100,7 @@ def _validate_sampler(self):
|
101 | 100 | if self.base_sampler is not None:
|
102 | 101 | self.base_sampler_ = self.base_sampler
|
103 | 102 | else:
|
104 |
| - self.base_sampler_ = ROS() |
| 103 | + self.base_sampler_ = RandomUnderSampler() |
105 | 104 |
|
106 | 105 | if self.base_sampler_ is None:
|
107 | 106 | raise ValueError("base_sampler cannot be None")
|
@@ -136,27 +135,18 @@ def fit(self, X, y, sample_weight=None):
|
136 | 135 | self._validate_sampler()
|
137 | 136 |
|
138 | 137 | random_state = check_random_state(self.random_state)
|
139 |
| - estimator_seeds = random_state.randint(MAX_INT, size=self.n_estimators) |
140 |
| - sampler_seeds = random_state.randint(MAX_INT, size=self.n_estimators) |
141 | 138 |
|
142 | 139 | if not hasattr(self.base_sampler, 'random_state'):
|
143 | 140 | ValueError('Base sampler must have a random_state parameter')
|
144 | 141 |
|
145 |
| - pipelines = [] |
146 |
| - seeds = zip(estimator_seeds, sampler_seeds) |
147 |
| - |
148 |
| - for i, (estimator_seed, sampler_seed) in enumerate(seeds): |
149 |
| - |
150 |
| - sampler = clone(self.base_sampler_) |
151 |
| - sampler.set_params(random_state=sampler_seed) |
| 142 | + steps = [('sampler', self.base_sampler_), |
| 143 | + ('estimator', self.base_estimator_)] |
| 144 | + pipeline_template = Pipeline(steps) |
152 | 145 |
|
153 |
| - if hasattr(self.base_estimator_, 'random_state'): |
154 |
| - estimator = clone(self.base_estimator_) |
155 |
| - estimator.set_params(random_state=estimator_seed) |
156 |
| - else: |
157 |
| - estimator = clone(self.base_estimator_) |
158 |
| - steps = [('sampler', sampler), ('estimator', estimator)] |
159 |
| - pipeline = Pipeline(steps) |
| 146 | + pipelines = [] |
| 147 | + for i in enumerate(range(self.n_estimators)): |
| 148 | + pipeline = clone(pipeline_template) |
| 149 | + _set_random_states(pipeline, random_state) |
160 | 150 | pipelines.append(pipeline)
|
161 | 151 |
|
162 | 152 | ensemble_members = [[str(i), pipeline]
|
|
0 commit comments