66import pandas as pd
77import pytest
88
9+ from octopus .models import ModelName
910from octopus .modules import Octo
1011from octopus .study import OctoTimeToEvent
1112
@@ -93,7 +94,7 @@ def test_octo_task_configuration(self):
9394 task_id = 0 ,
9495 depends_on = None ,
9596 description = "step_1" ,
96- models = [" CatBoostCoxSurvival" ],
97+ models = [ModelName . CatBoostCoxSurvival ],
9798 n_trials = 12 ,
9899 max_features = 6 ,
99100 ensemble_selection = True ,
@@ -108,9 +109,9 @@ def test_octo_task_configuration(self):
108109 assert octo_task .max_features == 6
109110 assert octo_task .ensemble_selection is True
110111 assert octo_task .ensel_n_save_trials == 10
111- assert octo_task .models == [" CatBoostCoxSurvival" ]
112+ assert octo_task .models == [ModelName . CatBoostCoxSurvival ]
112113
113- @pytest .mark .parametrize ("model_name" , [" CatBoostCoxSurvival" , " XGBoostCoxSurvival" ])
114+ @pytest .mark .parametrize ("model_name" , [ModelName . CatBoostCoxSurvival , ModelName . XGBoostCoxSurvival ])
114115 def test_single_model_configuration (self , model_name ):
115116 """Test configuration with each survival model individually."""
116117 octo_task = Octo (
@@ -136,22 +137,22 @@ def test_multi_model_configuration(self):
136137 task_id = 0 ,
137138 depends_on = None ,
138139 description = "step_1" ,
139- models = [" CatBoostCoxSurvival" , " XGBoostCoxSurvival" ],
140+ models = [ModelName . CatBoostCoxSurvival , ModelName . XGBoostCoxSurvival ],
140141 n_trials = 12 ,
141142 max_features = 6 ,
142143 ensemble_selection = True ,
143144 ensel_n_save_trials = 10 ,
144145 )
145146
146- assert {str (m ) for m in octo_task .models } == {" CatBoostCoxSurvival" , " XGBoostCoxSurvival" }
147+ assert {str (m ) for m in octo_task .models } == {ModelName . CatBoostCoxSurvival , ModelName . XGBoostCoxSurvival }
147148
148149 def test_ensemble_selection_configuration (self ):
149150 """Test ensemble selection configuration."""
150151 octo_task = Octo (
151152 task_id = 0 ,
152153 depends_on = None ,
153154 description = "step_1" ,
154- models = [" CatBoostCoxSurvival" ],
155+ models = [ModelName . CatBoostCoxSurvival ],
155156 n_trials = 12 ,
156157 max_features = 6 ,
157158 ensemble_selection = True ,
@@ -167,7 +168,7 @@ def test_hyperparameter_optimization_configuration(self):
167168 task_id = 0 ,
168169 depends_on = None ,
169170 description = "step_1" ,
170- models = [" CatBoostCoxSurvival" ],
171+ models = [ModelName . CatBoostCoxSurvival ],
171172 n_trials = 12 ,
172173 max_features = 6 ,
173174 ensemble_selection = True ,
@@ -177,7 +178,7 @@ def test_hyperparameter_optimization_configuration(self):
177178 penalty_factor = 1.5 ,
178179 )
179180
180- assert " CatBoostCoxSurvival" in octo_task .models
181+ assert ModelName . CatBoostCoxSurvival in octo_task .models
181182 assert octo_task .n_trials == 12
182183 assert octo_task .max_features == 6
183184 assert octo_task .ensemble_selection is True
@@ -211,7 +212,7 @@ def test_octo_timetoevent_actual_execution(self, survival_dataset):
211212 task_id = 0 ,
212213 depends_on = None ,
213214 description = "step_1" ,
214- models = [" CatBoostCoxSurvival" ],
215+ models = [ModelName . CatBoostCoxSurvival ],
215216 n_trials = 12 ,
216217 max_features = 6 ,
217218 ensemble_selection = True ,
@@ -258,7 +259,7 @@ def test_full_configuration_parameters(self):
258259 task_id = 0 ,
259260 depends_on = None ,
260261 description = "step_1" ,
261- models = [" CatBoostCoxSurvival" ],
262+ models = [ModelName . CatBoostCoxSurvival ],
262263 n_trials = 12 ,
263264 max_features = 6 ,
264265 ensemble_selection = True ,
@@ -279,7 +280,7 @@ def test_full_configuration_parameters(self):
279280 assert octo_task .task_id == 0
280281 assert octo_task .depends_on is None
281282 assert octo_task .description == "step_1"
282- assert octo_task .models == [" CatBoostCoxSurvival" ]
283+ assert octo_task .models == [ModelName . CatBoostCoxSurvival ]
283284 assert octo_task .model_seed == 0
284285 assert octo_task .n_jobs == 1
285286 assert octo_task .max_outl == 0
0 commit comments