@@ -42,14 +42,29 @@ def __init__(
4242
4343 self .collaboration_id = collaborations [0 ]["id" ]
4444
def get_active_node_organizations(self) -> List[int]:
    """
    Get the organization ids of the nodes that are currently online.

    The node list is requested from the vantage6 client with `is_online=True`; only the
    entries in the "data" field of the response are considered.

    Returns: a list with the organization id of every node in the response

    """
    response = self._v6client.node.list(is_online=True)

    # TODO: Add pagination support
    online_nodes = response["data"]

    organization_ids = []
    for node in online_nodes:
        organization_ids.append(node["organization"]["id"])
    return organization_ids
5157
5258 def get_column_names (self , ** kwargs ):
59+ """
60+ Get the column names of the dataset at all active nodes.
61+
62+ Args:
63+ **kwargs:
64+
65+ Returns:
66+
67+ """
5368 active_nodes = self .get_active_node_organizations ()
5469 self ._logger .debug (f"There are currently { len (active_nodes )} active nodes" )
5570
@@ -84,7 +99,7 @@ def fit(
8499 database: If the nodes have multiple datasources, indicate the label of the datasource
85100 you would like to use. Otherwise the default will be used.
86101
87- Returns:
102+ Returns: a `Task` object containing info about the task.
88103
89104 """
90105 input_params = {
@@ -107,14 +122,36 @@ def cross_validate(self,
107122 feature_nodes ,
108123 outcome_node ,
109124 precision = _DEFAULT_PRECISION ,
125+ n_splits = 10 ,
110126 database = "default" ):
127+ """
128+ Run cox proportional hazard analysis on the entire dataset using cross-validation. Uses 10
129+ fold by default.
130+
131+ Args:
132+ feature_columns: a list of column names that you want to use as features
133+ outcome_time_column: the column name of the outcome time
134+ right_censor_column: the column name of the binary value that indicates if an event
135+ happened.
136+ feature_nodes: A list of node ids from the datasources that contain the feature columns
137+ outcome_node: The node id of the datasource that contains the outcome
138+ precision: precision of the verticox algorithm. The smaller the number, the more
139+ precise the result. Smaller precision will take longer to compute though. The default is
140+ 1e-5
141+ n_splits: The number of folds to use for cross-validation. Default is 10.
142+ database: If the nodes have multiple datasources, indicate the label of the datasource
143+ you would like to use. Otherwise the default will be used.
144+
145+ Returns: a `Task` object containing info about the task.
146+ """
111147 input_params = {
112148 "feature_columns" : feature_columns ,
113149 "event_times_column" : outcome_time_column ,
114150 "event_happened_column" : right_censor_column ,
115151 "datanode_ids" : feature_nodes ,
116152 "central_node_id" : outcome_node ,
117153 "convergence_precision" : precision ,
154+ "n_splits" : n_splits ,
118155 }
119156
120157 return self ._run_task (
@@ -166,6 +203,10 @@ def _run_task(
166203
167204@dataclass
168205class FitResult :
206+ """
207+ FitResult contains the result of a fit task. It contains the coefficients and the baseline
208+ hazard function.
209+ """
169210 coefs : Dict [str , float ]
170211 baseline_hazard : HazardFunction
171212
@@ -191,6 +232,10 @@ def plot(self):
191232
192233@dataclass
193234class CrossValResult :
235+ """
236+ CrossValResult contains the result of a cross-validation task. It contains the c-indices,
237+ coefficients and baseline hazard functions for each fold.
238+ """
194239 c_indices : List [float ]
195240 coefs : List [Dict [str , float ]]
196241 baseline_hazards : List [HazardFunction ]
@@ -217,20 +262,27 @@ def plot(self):
217262
218263
class Task:
    """
    Task is a wrapper around the vantage6 task object.

    It keeps the raw task data, the client that is used to retrieve the results, and the
    task id taken from the "id" field of the task data.
    """

    # Annotations are written as strings so they are not evaluated when the class is
    # defined: `FitResult | CrossValResult` between classes needs python >= 3.10 when
    # evaluated eagerly.
    def __init__(self, client: "Client", task_data):
        """
        Args:
            client: the client used to wait for the results of the task
            task_data: the raw task data; it needs to contain an "id" field

        Raises:
            KeyError: if `task_data` has no "id" field
        """
        self._raw_data = task_data
        self.client = client
        self.task_id = task_data["id"]

    def get_results(self) -> "FitResult | CrossValResult":
        """
        Get the results of the task. This will block until the task is finished.

        Returns: the "data" field of the response of the client, passed through
            `_parse_results`.

        """
        results = self.client.wait_for_results(self.task_id)
        return self._parse_results(results["data"])

    @staticmethod
    def _parse_results(results) -> "FitResult | CrossValResult":
        """
        Parse the raw results of a task.

        NOTE(review): currently the results are returned unchanged; the return annotation
        describes the intended type, confirm against the subclasses/callers.
        """
        return results
235287
236288
0 commit comments