Skip to content

Commit 478e4bc

Browse files
Merge pull request #20 from loft-br/feat/add-feature-importance
feat: make feature importance transparent
2 parents 95e83d4 + eddb567 commit 478e4bc

File tree

3 files changed

+8
-1
lines changed

3 files changed

+8
-1
lines changed

xgbse/_debiased_bce.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def __init__(
149149
self.lr_params = lr_params
150150
self.n_jobs = n_jobs
151151
self.persist_train = False
152+
self.feature_importances_ = None
152153

153154
def fit(
154155
self,
@@ -224,7 +225,7 @@ def fit(
224225
evals=evals,
225226
verbose_eval=verbose_eval,
226227
)
227-
228+
self.feature_importances_ = self.bst.get_score()
228229
# predicting and encoding leaves
229230
self.encoder = OneHotEncoder()
230231
leaves = self.bst.predict(dtrain, pred_leaf=True)

xgbse/_kaplan_neighbors.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(self, xgb_params=None, n_neighbors=30, radius=None):
9898
self.persist_train = False
9999
self.index_id = None
100100
self.radius = None
101+
self.feature_importances_ = None
101102

102103
def fit(
103104
self,
@@ -171,6 +172,7 @@ def fit(
171172
evals=evals,
172173
verbose_eval=verbose_eval,
173174
)
175+
self.feature_importances_ = self.bst.get_score()
174176

175177
# creating nearest neighbor index
176178
leaves = self.bst.predict(dtrain, pred_leaf=True)
@@ -338,6 +340,7 @@ def __init__(
338340
self.xgb_params = xgb_params
339341
self.persist_train = False
340342
self.index_id = None
343+
self.feature_importances_ = None
341344

342345
def fit(
343346
self,
@@ -388,6 +391,7 @@ def fit(
388391

389392
# training XGB
390393
self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=1, **xgb_kwargs)
394+
self.feature_importances_ = self.bst.get_score()
391395

392396
# getting leaves
393397
leaves = self.bst.predict(dtrain, pred_leaf=True)

xgbse/_stacked_weibull.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def __init__(
9494
self.xgb_params = xgb_params
9595
self.weibull_params = weibull_params
9696
self.persist_train = False
97+
self.feature_importances_ = None
9798

9899
def fit(
99100
self,
@@ -167,6 +168,7 @@ def fit(
167168
evals=evals,
168169
verbose_eval=verbose_eval,
169170
)
171+
self.feature_importances_ = self.bst.get_score()
170172

171173
# predicting risk from XGBoost
172174
train_risk = self.bst.predict(dtrain)

0 commit comments

Comments
 (0)