Skip to content

Commit 4633b47

Browse files
authored
[doc] Display survival demos in sphinx doc. [skip ci] (dmlc#8328)
1 parent 3ef1703 commit 4633b47

File tree

8 files changed

+31
-15
lines changed

8 files changed

+31
-15
lines changed

demo/aft_survival/README.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Survival Analysis Walkthrough
2+
=============================
3+
4+
This is a collection of examples for using the XGBoost Python package for training
5+
survival models. For an introduction, see :doc:`/tutorials/aft_survival_analysis`

demo/aft_survival/aft_survival_demo.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""
2-
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model
2+
Demo for survival analysis (regression).
3+
========================================
4+
5+
Demo for survival analysis (regression). using Accelerated Failure Time (AFT) model.
36
"""
7+
48
import os
59
from sklearn.model_selection import ShuffleSplit
610
import pandas as pd

demo/aft_survival/aft_survival_demo_with_optuna.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""
2-
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model, using Optuna
3-
to tune hyperparameters
2+
Demo for survival analysis (regression) with Optuna.
3+
====================================================
4+
5+
Demo for survival analysis (regression) using Accelerated Failure Time (AFT) model,
6+
using Optuna to tune hyperparameters
7+
48
"""
59
from sklearn.model_selection import ShuffleSplit
610
import pandas as pd
@@ -45,7 +49,7 @@ def objective(trial):
4549
params.update(base_params)
4650
pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'valid-aft-nloglik')
4751
bst = xgb.train(params, dtrain, num_boost_round=10000,
48-
evals=[(dtrain, 'train'), (dvalid, 'valid')],
52+
evals=[(dtrain, 'train'), (dvalid, 'valid')],
4953
early_stopping_rounds=50, verbose_eval=False, callbacks=[pruning_callback])
5054
if bst.best_iteration >= 25:
5155
return bst.best_score
@@ -63,7 +67,7 @@ def objective(trial):
6367
# Re-run training with the best hyperparameter combination
6468
print('Re-running the best trial... params = {}'.format(params))
6569
bst = xgb.train(params, dtrain, num_boost_round=10000,
66-
evals=[(dtrain, 'train'), (dvalid, 'valid')],
70+
evals=[(dtrain, 'train'), (dvalid, 'valid')],
6771
early_stopping_rounds=50)
6872

6973
# Run prediction on the validation set

demo/aft_survival/aft_survival_viz_demo.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
Visual demo for survival analysis (regression) with Accelerated Failure Time (AFT) model.
3+
=========================================================================================
34
4-
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble model
5-
starts out as a flat line and evolves into a step function in order to account for all ranged
6-
labels.
5+
This demo uses 1D toy data and visualizes how XGBoost fits a tree ensemble. The ensemble
6+
model starts out as a flat line and evolves into a step function in order to account for
7+
all ranged labels.
78
"""
89
import numpy as np
910
import xgboost as xgb
@@ -57,7 +58,7 @@ def plot_intermediate_model_callback(env):
5758
# the corresponding predicted label (y_pred)
5859
acc = np.sum(np.logical_and(y_pred >= y_lower, y_pred <= y_upper)/len(X) * 100)
5960
accuracy_history.append(acc)
60-
61+
6162
# Plot ranged labels as well as predictions by the model
6263
plt.subplot(5, 3, env.iteration + 1)
6364
plot_censored_labels(X, y_lower, y_upper)

doc/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@
9393

9494
sphinx_gallery_conf = {
9595
# path to your example scripts
96-
"examples_dirs": ["../demo/guide-python", "../demo/dask"],
96+
"examples_dirs": ["../demo/guide-python", "../demo/dask", "../demo/aft_survival"],
9797
# path to where to save gallery generated output
98-
"gallery_dirs": ["python/examples", "python/dask-examples"],
98+
"gallery_dirs": ["python/examples", "python/dask-examples", "python/survival-examples"],
9999
"matplotlib_animations": True,
100100
}
101101

doc/python/.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
examples
2-
dask-examples
2+
dask-examples
3+
survival-examples

doc/python/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ Contents
1515
model
1616
examples/index
1717
dask-examples/index
18+
survival-examples/index

doc/tutorials/aft_survival_analysis.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a
9898
# 4-by-2 Data matrix
9999
X = np.array([[1, -1], [-1, 1], [0, 1], [1, 0]])
100100
dtrain = xgb.DMatrix(X)
101-
101+
102102
# Associate ranged labels with the data matrix.
103103
# This example shows each kind of censored labels.
104104
# uncensored right left interval
@@ -109,7 +109,7 @@ Collect the lower bound numbers in one array (let's call it ``y_lower_bound``) a
109109
110110
.. code-block:: r
111111
:caption: R
112-
112+
113113
library(xgboost)
114114
115115
# 4-by-2 Data matrix
@@ -165,4 +165,4 @@ Currently, you can choose from three probability distributions for ``aft_loss_di
165165
``extreme`` :math:`e^z e^{-\exp{z}}`
166166
========================= ===========================================
167167

168-
Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`.
168+
Note that it is not yet possible to set the ranged label using the scikit-learn interface (e.g. :class:`xgboost.XGBRegressor`). For now, you should use :class:`xgboost.train` with :class:`xgboost.DMatrix`. For a collection of Python examples, see :doc:`/python/survival-examples/index`

0 commit comments

Comments
 (0)