@@ -137,32 +137,29 @@ class BorutaPy(BaseEstimator, TransformerMixin):
137
137
138
138
Examples
139
139
--------
140
-
141
- import pandas as pd
142
- from sklearn.ensemble import RandomForestClassifier
143
- from boruta_py import BorutaPy
144
-
140
+
145
141
# load X and y
146
142
# NOTE: BorutaPy accepts numpy arrays only, hence the .values attribute
147
- X = pd.read_csv('my_X_table.csv', index_col=0).values
148
- y = pd.read_csv('my_y_vector.csv', index_col=0).values
149
-
143
+ X = pd.read_csv('examples/test_X.csv', index_col=0).values
144
+ y = pd.read_csv('examples/test_y.csv', header=None, index_col=0).values
145
+ y = y.ravel()
146
+
150
147
# define random forest classifier, utilising all cores and
151
148
# sampling in proportion to y labels
152
149
rf = RandomForestClassifier(n_jobs=-1, class_weight='auto', max_depth=5)
153
-
150
+
154
151
# define Boruta feature selection method
155
- feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2)
156
-
157
- # find all relevant features
152
+ feat_selector = BorutaPy(rf, n_estimators='auto', verbose=2, random_state=1)
153
+
154
+ # find all relevant features - 5 features should be selected
158
155
feat_selector.fit(X, y)
159
-
160
- # check selected features
156
+
157
+ # check selected features - first 5 features are selected
161
158
feat_selector.support_
162
-
159
+
163
160
# check ranking of features
164
161
feat_selector.ranking_
165
-
162
+
166
163
# call transform() on X to filter it down to selected features
167
164
X_filtered = feat_selector.transform(X)
168
165
@@ -181,7 +178,7 @@ def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
181
178
self .alpha = alpha
182
179
self .two_step = two_step
183
180
self .max_iter = max_iter
184
- self .random_state = check_random_state ( random_state )
181
+ self .random_state = random_state
185
182
self .verbose = verbose
186
183
187
184
def fit (self , X , y ):
@@ -248,6 +245,7 @@ def fit_transform(self, X, y, weak=False):
248
245
def _fit (self , X , y ):
249
246
# check input params
250
247
self ._check_params (X , y )
248
+ self .random_state = check_random_state (self .random_state )
251
249
# setup variables for Boruta
252
250
n_sample , n_feat = X .shape
253
251
_iter = 1
0 commit comments