Skip to content

Commit 5a947ee

Browse files
Merge pull request #27 from loft-br/fix/preds-using-early-stop
fix: predictions use the best tree when early stopping is enabled
2 parents 6e18b55 + 614d1c0 commit 5a947ee

File tree

5 files changed

+49
-12
lines changed

5 files changed

+49
-12
lines changed

tests/test_survival_curves.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,23 @@ def test_survival_curve(model):
6767
assert_survival_curve(xgbse, X_test, preds, cindex)
6868

6969

70+
@pytest.mark.parametrize(
71+
"model", [XGBSEDebiasedBCE, XGBSEKaplanNeighbors, XGBSEStackedWeibull]
72+
)
73+
def test_survival_curve_without_early_stopping(model):
74+
xgbse = model()
75+
76+
xgbse.fit(
77+
X_train,
78+
y_train,
79+
)
80+
81+
preds = xgbse.predict(X_test)
82+
cindex = concordance_index(y_test, preds)
83+
84+
assert_survival_curve(xgbse, X_test, preds, cindex)
85+
86+
7087
def test_survival_curve_tree():
7188
xgbse = XGBSEKaplanTree()
7289

xgbse/_base.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,18 @@ def get_neighbors(
5858
index = self.tree
5959
else:
6060
index_matrix = xgb.DMatrix(index_data)
61-
index_leaves = self.bst.predict(index_matrix, pred_leaf=True)
61+
index_leaves = self.bst.predict(
62+
index_matrix, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
63+
)
6264

6365
if len(index_leaves.shape) == 1:
6466
index_leaves = index_leaves.reshape(-1, 1)
6567
index = BallTree(index_leaves, metric="hamming")
6668

6769
query_matrix = xgb.DMatrix(query_data)
68-
query_leaves = self.bst.predict(query_matrix, pred_leaf=True)
70+
query_leaves = self.bst.predict(
71+
query_matrix, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
72+
)
6973

7074
if len(query_leaves.shape) == 1:
7175
query_leaves = query_leaves.reshape(-1, 1)

xgbse/_debiased_bce.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,9 @@ def fit(
228228
self.feature_importances_ = self.bst.get_score()
229229
# predicting and encoding leaves
230230
self.encoder = OneHotEncoder()
231-
leaves = self.bst.predict(dtrain, pred_leaf=True)
231+
leaves = self.bst.predict(
232+
dtrain, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
233+
)
232234
leaves_encoded = self.encoder.fit_transform(leaves)
233235

234236
# convert targets for using with logistic regression
@@ -244,7 +246,9 @@ def fit(
244246
if index_id is None:
245247
index_id = X.index.copy()
246248

247-
index_leaves = self.bst.predict(dtrain, pred_leaf=True)
249+
index_leaves = self.bst.predict(
250+
dtrain, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
251+
)
248252
self.tree = BallTree(index_leaves, metric="hamming")
249253

250254
self.index_id = index_id
@@ -369,7 +373,9 @@ def predict(self, X, return_interval_probs=False):
369373
d_matrix = xgb.DMatrix(X)
370374

371375
# getting leaves and extracting neighbors
372-
leaves = self.bst.predict(d_matrix, pred_leaf=True)
376+
leaves = self.bst.predict(
377+
d_matrix, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
378+
)
373379
leaves_encoded = self.encoder.transform(leaves)
374380

375381
# predicting from logistic regression artifacts

xgbse/_kaplan_neighbors.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,9 @@ def fit(
175175
self.feature_importances_ = self.bst.get_score()
176176

177177
# creating nearest neighbor index
178-
leaves = self.bst.predict(dtrain, pred_leaf=True)
178+
leaves = self.bst.predict(
179+
dtrain, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
180+
)
179181

180182
self.tree = BallTree(leaves, metric="hamming", leaf_size=40)
181183

@@ -229,7 +231,9 @@ def predict(
229231
d_matrix = xgb.DMatrix(X)
230232

231233
# getting leaves and extracting neighbors
232-
leaves = self.bst.predict(d_matrix, pred_leaf=True)
234+
leaves = self.bst.predict(
235+
d_matrix, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
236+
)
233237

234238
if self.radius:
235239
assert self.radius > 0, "Radius must be positive"
@@ -394,7 +398,9 @@ def fit(
394398
self.feature_importances_ = self.bst.get_score()
395399

396400
# getting leaves
397-
leaves = self.bst.predict(dtrain, pred_leaf=True)
401+
leaves = self.bst.predict(
402+
dtrain, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
403+
)
398404

399405
# organizing elements per leaf
400406
leaf_neighs = (
@@ -462,7 +468,9 @@ def predict(self, X, return_ci=False, return_interval_probs=False):
462468
d_matrix = xgb.DMatrix(X)
463469

464470
# getting leaves and extracting neighbors
465-
leaves = self.bst.predict(d_matrix, pred_leaf=True)
471+
leaves = self.bst.predict(
472+
d_matrix, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
473+
)
466474

467475
# searching for kaplan meier curves in leaves
468476
preds_df = self._train_survival.loc[leaves].reset_index(drop=True)

xgbse/_stacked_weibull.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def fit(
171171
self.feature_importances_ = self.bst.get_score()
172172

173173
# predicting risk from XGBoost
174-
train_risk = self.bst.predict(dtrain)
174+
train_risk = self.bst.predict(dtrain, ntree_limit=self.bst.best_ntree_limit)
175175

176176
# replacing 0 by minimum positive value in df
177177
# so Weibull can be fitted
@@ -192,7 +192,9 @@ def fit(
192192
if index_id is None:
193193
index_id = X.index.copy()
194194

195-
index_leaves = self.bst.predict(dtrain, pred_leaf=True)
195+
index_leaves = self.bst.predict(
196+
dtrain, pred_leaf=True, ntree_limit=self.bst.best_ntree_limit
197+
)
196198
self.tree = BallTree(index_leaves, metric="hamming")
197199

198200
self.index_id = index_id
@@ -222,7 +224,7 @@ def predict(self, X, return_interval_probs=False):
222224
d_matrix = xgb.DMatrix(X)
223225

224226
# getting leaves and extracting neighbors
225-
risk = self.bst.predict(d_matrix)
227+
risk = self.bst.predict(d_matrix, ntree_limit=self.bst.best_ntree_limit)
226228
weibull_score_df = pd.DataFrame({"risk": risk})
227229

228230
# predicting from logistic regression artifacts

0 commit comments

Comments (0)