Skip to content

Commit c5dde44

Browse files
fix tests (#73)
* update dependencies and fix errors
1 parent 56fa631 commit c5dde44

File tree

16 files changed

+121
-127
lines changed

16 files changed

+121
-127
lines changed

.github/workflows/auto-publish.yml

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,30 @@ jobs:
66
release:
77
runs-on: ubuntu-latest
88
steps:
9-
- uses: actions/checkout@v2
10-
- name: Set up Python 3.8
11-
uses: actions/setup-python@v2
12-
with:
13-
python-version: 3.8
14-
- name: Install Tools
15-
run: |
16-
python -m pip install --upgrade pip
17-
pip install setuptools wheel twine
18-
- name: Package and Upload
19-
env:
20-
STACKMANAGER_VERSION: ${{ github.event.release.tag_name }}
21-
TWINE_USERNAME: __token__
22-
TWINE_PASSWORD: ${{ secrets.PYPI_APIKEY }}
23-
run: |
24-
python setup.py sdist bdist_wheel
25-
twine upload dist/*
9+
- uses: actions/checkout@v2
10+
11+
- name: Set up Python 3.8
12+
uses: actions/setup-python@v2
13+
with:
14+
python-version: 3.8
15+
16+
- name: Install Poetry
17+
run: |
18+
curl -sSL https://install.python-poetry.org | python3 -
19+
echo "$HOME/.local/bin" >> $GITHUB_PATH
20+
21+
- name: Configure Poetry
22+
run: |
23+
poetry config pypi-token.pypi ${{ secrets.PYPI_APIKEY }}
24+
25+
- name: Install dependencies
26+
run: |
27+
poetry install
28+
29+
- name: Build package
30+
run: |
31+
poetry build
32+
33+
- name: Publish package
34+
run: |
35+
poetry publish

.github/workflows/testing.yml

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,29 @@ name: Code Checks
33
on:
44
push:
55
branches:
6-
- main
6+
- main
77
pull_request:
88
branches:
9-
- main
9+
- main
1010

1111
jobs:
1212
build:
1313
runs-on: ubuntu-latest
1414
strategy:
1515
matrix:
16-
python-version: [3.7]
17-
16+
python-version: ["3.9", "3.10", "3.11"]
1817
steps:
19-
- uses: actions/checkout@v2
20-
- name: Set up Python ${{ matrix.python-version }}
21-
uses: actions/setup-python@v1
22-
with:
23-
python-version: ${{ matrix.python-version }}
24-
- name: Install Testing Dependencies
25-
run: make install-dev
26-
- name: Automated checks
27-
run: make check
18+
- uses: actions/checkout@v4
19+
- uses: actions/setup-python@v5
20+
with:
21+
python-version: ${{ matrix.python-version }}
22+
- name: Install Poetry
23+
uses: snok/install-poetry@v1
24+
with:
25+
version: 1.5.1 # You can specify the Poetry version you want to use
26+
virtualenvs-create: true
27+
virtualenvs-in-project: true
28+
- name: Install dependencies
29+
run: poetry install --no-interaction --no-root
30+
- name: Automated checks
31+
run: poetry run make check

.pre-commit-config.yaml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,8 @@ repos:
99
- id: check-yaml
1010
- id: check-added-large-files
1111
args: ["--maxkb=2000"]
12-
- repo: https://gitlab.com/pycqa/flake8
13-
rev: 8f9b4931b9a28896fb43edccb23016a7540f5b82
12+
- repo: https://github.com/astral-sh/ruff-pre-commit
13+
rev: v0.5.0
1414
hooks:
15-
- id: flake8
16-
- repo: https://github.com/psf/black
17-
rev: 20.8b1
18-
hooks:
19-
- id: black
20-
language_version: python3.7
15+
- id: ruff
16+
- id: ruff-format

Makefile

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
black:
2-
black xgbse setup.py tests/ --check
1+
format:
2+
ruff format xgbse tests/ --check
33

4-
flake:
5-
flake8 xgbse setup.py tests/
4+
lint:
5+
ruff check xgbse tests/
66

77
test:
88
pytest --cov-report term-missing --cov=xgbse tests/
99

10-
check: black flake test clean
10+
check: format lint test clean
1111

1212
install:
1313
python -m pip install -e .

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ To cite this repository:
409409
author = {Davi Vieira and Gabriel Gimenez and Guilherme Marmerola and Vitor Estima},
410410
title = {XGBoost Survival Embeddings: improving statistical properties of XGBoost survival analysis implementation},
411411
url = {http://github.com/loft-br/xgboost-survival-embeddings},
412-
version = {0.2.3},
412+
version = {0.3.1},
413413
year = {2021},
414414
}
415415
```

examples/benchmarks/benchmark.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212

1313
def dataframe_to_xy(dataf, event_column, time_column):
14-
1514
e = dataf.loc[:, event_column]
1615
t = dataf.loc[:, time_column]
1716
return dataf.drop([event_column, time_column], axis=1), convert_to_structured(t, e)
@@ -65,7 +64,6 @@ def predict(self):
6564
pass
6665

6766
def test(self):
68-
6967
self.predict()
7068
try:
7169
c_index = concordance_index(
@@ -172,7 +170,6 @@ def __init__(
172170
)
173171

174172
def train(self):
175-
176173
start = time.time()
177174
params = {"objective": self.objective}
178175

@@ -226,14 +223,12 @@ def __init__(
226223
self.objective = objective
227224

228225
def train(self):
229-
230226
start = time.time()
231227

232228
if self.model.__class__.__name__ not in [
233229
"XGBSEKaplanTree",
234230
"XGBSEBootstrapEstimator",
235231
]:
236-
237232
self.model.fit(
238233
self.X_train,
239234
self.y_train,
@@ -244,7 +239,6 @@ def train(self):
244239
)
245240

246241
else:
247-
248242
self.model.fit(self.X_train, self.y_train, time_bins=self.time_bins)
249243

250244
self.training_time = time.time() - start
@@ -284,7 +278,6 @@ def __init__(
284278
)
285279

286280
def train(self):
287-
288281
T = self.train_dataset[self.time_column]
289282
E = self.train_dataset[self.event_column]
290283

mkdocs.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,12 @@ nav:
1919
- converters: modules/converters.md
2020
- metrics: modules/metrics.md
2121
- Examples:
22+
- Basic usage: examples/basic_usage.md
2223
- Confidence intervals: examples/confidence_interval.md
2324
- Extrapolation: examples/extrapolation_example.md
2425
- Benchmarks: benchmarks/benchmarks.md
2526
plugins:
26-
- mkdocstrings:
27-
watch:
28-
- xgbse
27+
- mkdocstrings
2928
- search
3029
copyright:
3130
theme:

pyproject.toml

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
[tool.poetry]
2+
name = "xgbse"
3+
version = "0.3.1"
4+
description = "Improving XGBoost survival analysis with embeddings and debiased estimators"
5+
authors = ["Loft Data Science Team <bandits@loft.com.br>"]
6+
readme = "README.md"
7+
packages = [{ include = "xgbse" }]
8+
repository = "https://github.com/loft-br/xgboost-survival-embeddings"
9+
10+
[tool.poetry.dependencies]
11+
python = ">=3.9"
12+
xgboost = "^2.1.0"
13+
numpy = "^1.26.4"
14+
scikit-learn = "^1.5.0"
15+
pandas = "^2.2.0"
16+
joblib = "^1.4.2"
17+
lifelines = "^0.29.0"
18+
19+
[tool.poetry.group.docs.dependencies]
20+
mkdocs = "^1.6.0"
21+
mkdocs-material = "^9.5.28"
22+
mkdocstrings = { version = ">=0.18", extras = ["python-legacy"] }
23+
24+
25+
[tool.poetry.group.dev.dependencies]
26+
pre-commit = "^3.7.1"
27+
pytest = "^8.2.2"
28+
pytest-cov = "^5.0.0"
29+
ruff = "^0.5.0"
30+
31+
[tool.poetry.group.benchmark.dependencies]
32+
pycox = "0.2.1"
33+
34+
[build-system]
35+
requires = ["poetry-core"]
36+
build-backend = "poetry.core.masonry.api"

setup.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

tests/test_feature_extractors.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,21 @@ def test_no_objective():
3131

3232
def test_predict_leaves_early_stop():
3333
xgbse = FeatureExtractor()
34+
early_stopping_rounds = 10
3435
xgbse.fit(
3536
X_train,
3637
y_train,
37-
num_boost_round=10000,
38+
num_boost_round=1000,
3839
validation_data=(X_valid, y_valid),
39-
early_stopping_rounds=10,
40+
early_stopping_rounds=early_stopping_rounds,
4041
verbose_eval=0,
4142
)
4243
prediction = xgbse.predict_leaves(X_test)
43-
assert prediction.shape == (
44-
X_test.shape[0],
45-
xgbse.bst.best_iteration + 1,
44+
assert prediction.shape[0] == X_test.shape[0]
45+
assert (
46+
xgbse.bst.best_iteration
47+
<= prediction.shape[1]
48+
<= xgbse.bst.best_iteration + 1 + early_stopping_rounds
4649
)
4750

4851

@@ -64,7 +67,7 @@ def test_predict_hazard_early_stop():
6467
xgbse.fit(
6568
X_train,
6669
y_train,
67-
num_boost_round=10000,
70+
num_boost_round=1000,
6871
validation_data=(X_valid, y_valid),
6972
early_stopping_rounds=10,
7073
verbose_eval=0,

0 commit comments

Comments
 (0)