Skip to content

Commit 9c4aa39

Browse files
committed
moved random state from init to fit, added example files so that the example in the docstring is self-contained
1 parent bcaa9bf commit 9c4aa39

File tree

3 files changed

+2016
-17
lines changed

3 files changed

+2016
-17
lines changed

boruta/boruta_py.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -137,32 +137,29 @@ class BorutaPy(BaseEstimator, TransformerMixin):
137137
138138
Examples
139139
--------
140-
141-
import pandas as pd
142-
from sklearn.ensemble import RandomForestClassifier
143-
from boruta_py import BorutaPy
144-
140+
145141
# load X and y
146142
# NOTE BorutaPy accepts numpy arrays only, hence the .values attribute
147-
X = pd.read_csv('my_X_table.csv', index_col=0).values
148-
y = pd.read_csv('my_y_vector.csv', index_col=0).values
149-
143+
X = pd.read_csv('examples/test_X.csv', index_col=0).values
144+
y = pd.read_csv('examples/test_y.csv', header=None, index_col=0).values
145+
y = y.ravel()
146+
150147
# define random forest classifier, with utilising all cores and
151148
# sampling in proportion to y labels
152149
rf = RandomForestClassifier(n_jobs=-1, class_weight='auto', max_depth=5)
153-
150+
154151
# define Boruta feature selection method
155-
feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2)
156-
157-
# find all relevant features
152+
feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2, random_state=1)
153+
154+
# find all relevant features - 5 features should be selected
158155
feat_selector.fit(X, y)
159-
160-
# check selected features
156+
157+
# check selected features - first 5 features are selected
161158
feat_selector.support_
162-
159+
163160
# check ranking of features
164161
feat_selector.ranking_
165-
162+
166163
# call transform() on X to filter it down to selected features
167164
X_filtered = feat_selector.transform(X)
168165
@@ -181,7 +178,7 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
181178
self.alpha = alpha
182179
self.two_step = two_step
183180
self.max_iter = max_iter
184-
self.random_state = check_random_state(random_state)
181+
self.random_state = random_state
185182
self.verbose = verbose
186183

187184
def fit(self, X, y):
@@ -248,6 +245,7 @@ def fit_transform(self, X, y, weak=False):
248245
def _fit(self, X, y):
249246
# check input params
250247
self._check_params(X, y)
248+
self.random_state = check_random_state(self.random_state)
251249
# setup variables for Boruta
252250
n_sample, n_feat = X.shape
253251
_iter = 1

0 commit comments

Comments
 (0)