Skip to content

Commit 6668af6

Browse files
committed
updated model names
1 parent 5a6ceac commit 6668af6

File tree

1 file changed

+12
-11
lines changed

1 file changed

+12
-11
lines changed

tests/workflows/test_octo_t2e.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pandas as pd
77
import pytest
88

9+
from octopus.models import ModelName
910
from octopus.modules import Octo
1011
from octopus.study import OctoTimeToEvent
1112

@@ -93,7 +94,7 @@ def test_octo_task_configuration(self):
9394
task_id=0,
9495
depends_on=None,
9596
description="step_1",
96-
models=["CatBoostCoxSurvival"],
97+
models=[ModelName.CatBoostCoxSurvival],
9798
n_trials=12,
9899
max_features=6,
99100
ensemble_selection=True,
@@ -108,9 +109,9 @@ def test_octo_task_configuration(self):
108109
assert octo_task.max_features == 6
109110
assert octo_task.ensemble_selection is True
110111
assert octo_task.ensel_n_save_trials == 10
111-
assert octo_task.models == ["CatBoostCoxSurvival"]
112+
assert octo_task.models == [ModelName.CatBoostCoxSurvival]
112113

113-
@pytest.mark.parametrize("model_name", ["CatBoostCoxSurvival", "XGBoostCoxSurvival"])
114+
@pytest.mark.parametrize("model_name", [ModelName.CatBoostCoxSurvival, ModelName.XGBoostCoxSurvival])
114115
def test_single_model_configuration(self, model_name):
115116
"""Test configuration with each survival model individually."""
116117
octo_task = Octo(
@@ -136,22 +137,22 @@ def test_multi_model_configuration(self):
136137
task_id=0,
137138
depends_on=None,
138139
description="step_1",
139-
models=["CatBoostCoxSurvival", "XGBoostCoxSurvival"],
140+
models=[ModelName.CatBoostCoxSurvival, ModelName.XGBoostCoxSurvival],
140141
n_trials=12,
141142
max_features=6,
142143
ensemble_selection=True,
143144
ensel_n_save_trials=10,
144145
)
145146

146-
assert {str(m) for m in octo_task.models} == {"CatBoostCoxSurvival", "XGBoostCoxSurvival"}
147+
assert {str(m) for m in octo_task.models} == {ModelName.CatBoostCoxSurvival, ModelName.XGBoostCoxSurvival}
147148

148149
def test_ensemble_selection_configuration(self):
149150
"""Test ensemble selection configuration."""
150151
octo_task = Octo(
151152
task_id=0,
152153
depends_on=None,
153154
description="step_1",
154-
models=["CatBoostCoxSurvival"],
155+
models=[ModelName.CatBoostCoxSurvival],
155156
n_trials=12,
156157
max_features=6,
157158
ensemble_selection=True,
@@ -167,7 +168,7 @@ def test_hyperparameter_optimization_configuration(self):
167168
task_id=0,
168169
depends_on=None,
169170
description="step_1",
170-
models=["CatBoostCoxSurvival"],
171+
models=[ModelName.CatBoostCoxSurvival],
171172
n_trials=12,
172173
max_features=6,
173174
ensemble_selection=True,
@@ -177,7 +178,7 @@ def test_hyperparameter_optimization_configuration(self):
177178
penalty_factor=1.5,
178179
)
179180

180-
assert "CatBoostCoxSurvival" in octo_task.models
181+
assert ModelName.CatBoostCoxSurvival in octo_task.models
181182
assert octo_task.n_trials == 12
182183
assert octo_task.max_features == 6
183184
assert octo_task.ensemble_selection is True
@@ -211,7 +212,7 @@ def test_octo_timetoevent_actual_execution(self, survival_dataset):
211212
task_id=0,
212213
depends_on=None,
213214
description="step_1",
214-
models=["CatBoostCoxSurvival"],
215+
models=[ModelName.CatBoostCoxSurvival],
215216
n_trials=12,
216217
max_features=6,
217218
ensemble_selection=True,
@@ -258,7 +259,7 @@ def test_full_configuration_parameters(self):
258259
task_id=0,
259260
depends_on=None,
260261
description="step_1",
261-
models=["CatBoostCoxSurvival"],
262+
models=[ModelName.CatBoostCoxSurvival],
262263
n_trials=12,
263264
max_features=6,
264265
ensemble_selection=True,
@@ -279,7 +280,7 @@ def test_full_configuration_parameters(self):
279280
assert octo_task.task_id == 0
280281
assert octo_task.depends_on is None
281282
assert octo_task.description == "step_1"
282-
assert octo_task.models == ["CatBoostCoxSurvival"]
283+
assert octo_task.models == [ModelName.CatBoostCoxSurvival]
283284
assert octo_task.model_seed == 0
284285
assert octo_task.n_jobs == 1
285286
assert octo_task.max_outl == 0

0 commit comments

Comments
 (0)