Skip to content

Commit a246d67

Browse files
committed
refactored boruta_py.py
1 parent 647c192 commit a246d67

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

boruta/boruta_py.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ class BorutaPy(BaseEstimator, TransformerMixin):
173173
Journal of Statistical Software, Vol. 36, Issue 11, Sep 2010
174174
"""
175175

176-
def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
176+
def __init__(self, estimator, n_estimators=1000, perc=100, alpha=0.05,
177177
two_step=True, max_iter=100, random_state=None, verbose=0):
178178
self.estimator = estimator
179179
self.n_estimators = n_estimators
@@ -349,7 +349,6 @@ def _fit(self, X, y):
349349
self._print_results(dec_reg, _iter, 1)
350350
return self
351351

352-
353352
def _transform(self, X, weak=False):
354353
# sanity check
355354
try:
@@ -410,7 +409,7 @@ def _add_shadows_get_imps(self, X, y, dec_reg):
410409
imp_real = np.zeros(X.shape[1])
411410
imp_real[:] = np.nan
412411
imp_real[x_cur_ind] = imp[:x_cur_w]
413-
return (imp_real, imp_sha)
412+
return imp_real, imp_sha
414413

415414
def _assign_hits(self, hit_reg, cur_imp, imp_sha_max):
416415
# register hits for feautres that did better than the best of shadows
@@ -437,14 +436,14 @@ def _do_tests(self, dec_reg, hit_reg, _iter):
437436
to_reject2 = to_reject_ps <= self.alpha / float(_iter)
438437

439438
# combine the two multi corrections, and get indexes
440-
to_accept = to_accept * to_accept2
441-
to_reject = to_reject * to_reject2
439+
to_accept *= to_accept2
440+
to_reject *= to_reject2
442441
else:
443442
# as in th original Boruta, we simply do bonferroni correction
444443
# with the total n_feat in each iteration
445444
to_accept = to_accept_ps <= self.alpha / float(len(dec_reg))
446445
to_reject = to_reject_ps <= self.alpha / float(len(dec_reg))
447-
446+
448447
# find features which are 0 and have been rejected or accepted
449448
to_accept = np.where((dec_reg[active_features] == 0) * to_accept)[0]
450449
to_reject = np.where((dec_reg[active_features] == 0) * to_reject)[0]
@@ -477,16 +476,16 @@ def _fdrcorrection(self, pvals, alpha=0.05):
477476
pvals_sortind = np.argsort(pvals)
478477
pvals_sorted = np.take(pvals, pvals_sortind)
479478
nobs = len(pvals_sorted)
480-
ecdffactor = np.arange(1,nobs+1)/float(nobs)
479+
ecdffactor = np.arange(1, nobs + 1) / float(nobs)
481480

482-
reject = pvals_sorted <= ecdffactor*alpha
481+
reject = pvals_sorted <= ecdffactor * alpha
483482
if reject.any():
484483
rejectmax = max(np.nonzero(reject)[0])
485484
reject[:rejectmax] = True
486485

487486
pvals_corrected_raw = pvals_sorted / ecdffactor
488487
pvals_corrected = np.minimum.accumulate(pvals_corrected_raw[::-1])[::-1]
489-
pvals_corrected[pvals_corrected>1] = 1
488+
pvals_corrected[pvals_corrected > 1] = 1
490489
# reorder p-values and rejection mask to original order of pvals
491490
pvals_corrected_ = np.empty_like(pvals_corrected)
492491
pvals_corrected_[pvals_sortind] = pvals_corrected
@@ -523,16 +522,16 @@ def _print_results(self, dec_reg, _iter, flag):
523522
# still in feature selection
524523
if flag == 0:
525524
n_tentative = np.where(dec_reg == 0)[0].shape[0]
526-
content = map(str,[n_iter,n_confirmed,n_tentative,n_rejected])
525+
content = map(str, [n_iter, n_confirmed, n_tentative, n_rejected])
527526
if self.verbose == 1:
528527
output = cols[0] + n_iter
529528
elif self.verbose > 1:
530-
output = '\n'.join([x[0]+'\t'+x[1] for x in zip(cols,content)])
529+
output = '\n'.join([x[0] + '\t' + x[1] for x in zip(cols, content)])
531530

532531
# Boruta finished running and tentatives have been filtered
533532
else:
534533
n_tentative = np.sum(self.support_weak_)
535534
content = map(str, [n_iter, n_confirmed, n_tentative, n_rejected])
536-
result = '\n'.join([x[0] +'\t' + x[1] for x in zip(cols, content)])
535+
result = '\n'.join([x[0] + '\t' + x[1] for x in zip(cols, content)])
537536
output = "\n\nBorutaPy finished running.\n\n" + result
538-
print(output)
537+
print(output)

0 commit comments

Comments
 (0)