Skip to content

Commit 39cc84a

Browse files
authored
Merge pull request #240 from CARRIER-project/226-last-tweaks-documentation
226 last tweaks documentation
2 parents cf6d847 + 65e09f8 commit 39cc84a

File tree

8 files changed

+192
-56
lines changed

8 files changed

+192
-56
lines changed

demo.ipynb

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,6 @@
286286
"## Running cox proportional hazard analysis\n",
287287
"If you want to fit a model on the entire dataset you can run `VerticoxClient.fit`\n",
288288
"\n",
289-
"### Docstring\n",
290-
"Run cox proportional hazard analysis on the entire dataset.\n",
291-
"\n",
292289
"Args:\n",
293290
"\n",
294291
"- __feature_columns:__ a list of column names that you want to use as features\n",

docs/development.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,12 @@ analysis.
5757
```
5858

5959
## Further development
60+
If this code would be developed further, it would be good to add the following features:
61+
62+
### Sample size threshold
63+
In order to further prevent data leakage, a threshold should be added to prevent the
64+
analysis from being run if the sample size is below a certain threshold.
65+
66+
### Upgrade to latest vantage6 version
67+
The current version of vantage6 that is being used is 4.7.1. In order to make it compatible with
68+
the latest version, the code should comply with the new client api.

docs/images/vantage6.png

6.3 KB
Loading

mkdocs.yml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,28 @@ nav:
1313

1414
theme:
1515
name: material
16+
logo: images/vantage6.png
17+
palette:
18+
primary: white
1619
plugins:
1720
- mkdocstrings:
1821
handlers:
1922
python:
2023
paths: [python]
24+
docstring_style: google
2125
- search
2226

2327
markdown_extensions:
2428
- pymdownx.arithmatex:
2529
generic: true
2630
- pymdownx.blocks.caption
31+
- pymdownx.highlight:
32+
anchor_linenums: true
33+
line_spans: __span
34+
pygments_lang_class: true
35+
- pymdownx.inlinehilite
36+
- pymdownx.snippets
37+
- pymdownx.superfences
2738

2839
extra_javascript:
2940
- javascripts/mathjax.js

python/pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,4 +56,6 @@ build-backend = "setuptools.build_meta"
5656
pythonpath = "."
5757
addopts = "-v --log-level=INFO --log-file=test.log"
5858

59+
[tool.ruff]
60+
indent-width = 4
5961

python/tests/test_verticox_v6.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
#! /usr/bin/env python3
22

33
import json
4-
import numpy as np
54
import vantage6.client as v6client
65
from clize import run
7-
from verticox.client import FitResult
86

97
from test_constants import OUTCOME_TIME_COLUMN, OUTCOME, PRECISION
108
from verticox.client import VerticoxClient
@@ -17,7 +15,7 @@
1715

1816

1917
def run_verticox_v6(host, port, user, password, *, private_key=None, image: str=IMAGE,
20-
method="fit"):
18+
method="fit", precision: float = PRECISION):
2119

2220
client = v6client.Client(host, port, log_level="warning")
2321

@@ -63,31 +61,24 @@ def run_verticox_v6(host, port, user, password, *, private_key=None, image: str=
6361
OUTCOME,
6462
feature_nodes=feature_orgs,
6563
outcome_node=central_node,
66-
precision=PRECISION,
64+
precision=precision,
6765
database=DATABASE,
6866
)
6967
case "crossval":
7068
task = verticox_client.cross_validate(
7169
feature_columns,
7270
OUTCOME_TIME_COLUMN,
7371
OUTCOME,
74-
feature_nodes=datanodes,
72+
feature_nodes=feature_orgs,
7573
outcome_node=central_node,
76-
precision=PRECISION,
74+
precision=precision,
7775
database=DATABASE,
7876
)
7977

80-
results = task.get_results(timeout=TIMEOUT)
78+
results = task.get_results()
8179

8280
print("Results: ", results)
8381

84-
match results:
85-
case FitResult(coefs, baseline_hazard):
86-
for key, value in coefs.items():
87-
np.testing.assert_almost_equal(value, TARGET_COEFS[key], decimal=4)
88-
89-
print("Test passed")
90-
9182

9283
if __name__ == "__main__":
9384
run(run_verticox_v6)

python/verticox/client.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,29 @@ def __init__(
4242

4343
self.collaboration_id = collaborations[0]["id"]
4444

45-
def get_active_node_organizations(self):
45+
def get_active_node_organizations(self) -> List[int]:
46+
"""
47+
Get the organization ids of the active nodes in the collaboration.
48+
49+
Returns: a list of organization ids
50+
51+
"""
4652
nodes = self._v6client.node.list(is_online=True)
4753

4854
# TODO: Add pagination support
4955
nodes = nodes["data"]
5056
return [n["organization"]["id"] for n in nodes]
5157

5258
def get_column_names(self, **kwargs):
59+
"""
60+
Get the column names of the dataset at all active nodes.
61+
62+
Args:
63+
**kwargs:
64+
65+
Returns:
66+
67+
"""
5368
active_nodes = self.get_active_node_organizations()
5469
self._logger.debug(f"There are currently {len(active_nodes)} active nodes")
5570

@@ -84,7 +99,7 @@ def fit(
8499
database: If the nodes have multiple datasources, indicate the label of the datasource
85100
you would like to use. Otherwise the default will be used.
86101
87-
Returns:
102+
Returns: a `Task` object containing info about the task.
88103
89104
"""
90105
input_params = {
@@ -107,14 +122,36 @@ def cross_validate(self,
107122
feature_nodes,
108123
outcome_node,
109124
precision=_DEFAULT_PRECISION,
125+
n_splits = 10,
110126
database="default"):
127+
"""
128+
Run cox proportional hazard analysis on the entire dataset using cross-validation. Uses 10
129+
fold by default.
130+
131+
Args:
132+
feature_columns: a list of column names that you want to use as features
133+
outcome_time_column: the column name of the outcome time
134+
right_censor_column: the column name of the binary value that indicates if an event
135+
happened.
136+
feature_nodes: A list of node ids from the datasources that contain the feature columns
137+
outcome_node: The node id of the datasource that contains the outcome
138+
precision: precision of the verticox algorithm. The smaller the number, the more
139+
precise the result. Smaller precision will take longer to compute though. The default is
140+
1e-5
141+
n_splits: The number of folds to use for cross-validation. Default is 10.
142+
database: If the nodes have multiple datasources, indicate the label of the datasource
143+
you would like to use. Otherwise the default will be used.
144+
145+
Returns: a `Task` object containing info about the task.
146+
"""
111147
input_params = {
112148
"feature_columns": feature_columns,
113149
"event_times_column": outcome_time_column,
114150
"event_happened_column": right_censor_column,
115151
"datanode_ids": feature_nodes,
116152
"central_node_id": outcome_node,
117153
"convergence_precision": precision,
154+
"n_splits": n_splits,
118155
}
119156

120157
return self._run_task(
@@ -166,6 +203,10 @@ def _run_task(
166203

167204
@dataclass
168205
class FitResult:
206+
"""
207+
FitResult contains the result of a fit task. It contains the coefficients and the baseline
208+
hazard function.
209+
"""
169210
coefs: Dict[str, float]
170211
baseline_hazard: HazardFunction
171212

@@ -191,6 +232,10 @@ def plot(self):
191232

192233
@dataclass
193234
class CrossValResult:
235+
"""
236+
CrossValResult contains the result of a cross-validation task. It contains the c-indices,
237+
coefficients and baseline hazard functions for each fold.
238+
"""
194239
c_indices: List[float]
195240
coefs: List[Dict[str, float]]
196241
baseline_hazards: List[HazardFunction]
@@ -217,20 +262,27 @@ def plot(self):
217262

218263

219264
class Task:
220-
265+
"""
266+
Task is a wrapper around the vantage6 task object.
267+
"""
221268
def __init__(self, client: Client, task_data):
222269
self._raw_data = task_data
223270
self.client = client
224271
self.task_id = task_data["id"]
225272

226-
def get_results(self, timeout=_TIMEOUT):
273+
def get_results(self) -> PartialResult:
274+
"""
275+
Get the results of the task. This will block until the task is finished.
276+
277+
Returns:
278+
279+
"""
227280
results = self.client.wait_for_results(self.task_id)
228-
print(f"Results: {results}")
229281
return self._parse_results(results["data"])
230282

231283

232284
@staticmethod
233-
def _parse_results(results):
285+
def _parse_results(results) -> FitResult| CrossValResult:
234286
return results
235287

236288

0 commit comments

Comments
 (0)