|
11 | 11 | from __future__ import print_function, division
|
12 | 12 | import numpy as np
|
13 | 13 | import scipy as sp
|
14 |
| -from sklearn.utils import check_X_y |
| 14 | +from sklearn.utils import check_random_state, check_X_y |
15 | 15 |
|
16 | 16 | class BorutaPy(object):
|
17 | 17 | """
|
@@ -99,6 +99,12 @@ class BorutaPy(object):
|
99 | 99 | max_iter : int, default = 100
|
100 | 100 | The maximum number of iterations to perform.
|
101 | 101 |
|
| 102 | + random_state : int, RandomState instance or None, default=None |
| 103 | + If int, random_state is the seed used by the random number generator; |
| 104 | + If RandomState instance, random_state is the random number generator; |
| 105 | + If None, the random number generator is the RandomState instance used |
| 106 | + by `np.random`. |
| 107 | +
|
102 | 108 | verbose : int, default=0
|
103 | 109 | Controls verbosity of output:
|
104 | 110 | - 0: no output
|
@@ -166,13 +172,14 @@ class BorutaPy(object):
|
166 | 172 | """
|
167 | 173 |
|
168 | 174 | def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
|
169 |
| - two_step=True, max_iter=100, verbose=0): |
| 175 | + two_step=True, max_iter=100, random_state=None, verbose=0): |
170 | 176 | self.estimator = estimator
|
171 | 177 | self.n_estimators = n_estimators
|
172 | 178 | self.perc = perc
|
173 | 179 | self.alpha = alpha
|
174 | 180 | self.two_step = two_step
|
175 | 181 | self.max_iter = max_iter
|
| 182 | + self.random_state = check_random_state(random_state) |
176 | 183 | self.verbose = verbose
|
177 | 184 |
|
178 | 185 | def fit(self, X, y):
|
@@ -268,8 +275,7 @@ def _fit(self, X, y):
|
268 | 275 | self.estimator.set_params(n_estimators=n_tree)
|
269 | 276 |
|
270 | 277 | # make sure we start with a new tree in each iteration
|
271 |
| - rnd_st = np.random.randint(1,1e6,1)[0] |
272 |
| - self.estimator.set_params(random_state=rnd_st) |
| 278 | + self.estimator.set_params(random_state=self.random_state) |
273 | 279 |
|
274 | 280 | # add shadow attributes, shuffle them and train estimator, get imps
|
275 | 281 | cur_imp = self._add_shadows_get_imps(X, y, dec_reg)
|
@@ -380,7 +386,7 @@ def _get_imp(self, X, y):
|
380 | 386 | return imp
|
381 | 387 |
|
382 | 388 | def _get_shuffle(self, seq):
|
383 |
| - np.random.shuffle(seq) |
| 389 | + self.random_state.shuffle(seq) |
384 | 390 | return seq
|
385 | 391 |
|
386 | 392 | def _add_shadows_get_imps(self, X, y, dec_reg):
|
|
0 commit comments