Skip to content

Commit eb28240

Browse files
authored
Merge pull request scikit-learn-contrib#18 from bittremieux/random
Allow explicit specification of the random state
2 parents bf84256 + cb0d751 commit eb28240

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

boruta/boruta_py.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from __future__ import print_function, division
1212
import numpy as np
1313
import scipy as sp
14-
from sklearn.utils import check_X_y
14+
from sklearn.utils import check_random_state, check_X_y
1515

1616
class BorutaPy(object):
1717
"""
@@ -99,6 +99,12 @@ class BorutaPy(object):
9999
max_iter : int, default = 100
100100
The number of maximum iterations to perform.
101101
102+
random_state : int, RandomState instance or None; default=None
103+
If int, random_state is the seed used by the random number generator;
104+
If RandomState instance, random_state is the random number generator;
105+
If None, the random number generator is the RandomState instance used
106+
by `np.random`.
107+
102108
verbose : int, default=0
103109
Controls verbosity of output:
104110
- 0: no output
@@ -166,13 +172,14 @@ class BorutaPy(object):
166172
"""
167173

168174
def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
169-
two_step=True, max_iter=100, verbose=0):
175+
two_step=True, max_iter=100, random_state=None, verbose=0):
170176
self.estimator = estimator
171177
self.n_estimators = n_estimators
172178
self.perc = perc
173179
self.alpha = alpha
174180
self.two_step = two_step
175181
self.max_iter = max_iter
182+
self.random_state = check_random_state(random_state)
176183
self.verbose = verbose
177184

178185
def fit(self, X, y):
@@ -268,8 +275,7 @@ def _fit(self, X, y):
268275
self.estimator.set_params(n_estimators=n_tree)
269276

270277
# make sure we start with a new tree in each iteration
271-
rnd_st = np.random.randint(1,1e6,1)[0]
272-
self.estimator.set_params(random_state=rnd_st)
278+
self.estimator.set_params(random_state=self.random_state)
273279

274280
# add shadow attributes, shuffle them and train estimator, get imps
275281
cur_imp = self._add_shadows_get_imps(X, y, dec_reg)
@@ -380,7 +386,7 @@ def _get_imp(self, X, y):
380386
return imp
381387

382388
def _get_shuffle(self, seq):
383-
np.random.shuffle(seq)
389+
self.random_state.shuffle(seq)
384390
return seq
385391

386392
def _add_shadows_get_imps(self, X, y, dec_reg):

0 commit comments

Comments
 (0)