@@ -149,10 +149,10 @@ def load_dataset_config(case: str) -> DictConfig:
149149 return ds_cfg
150150
151151
152- def load_model_config (model_name : str ) -> DictConfig :
152+ def load_model_config (model_type : str ) -> DictConfig :
153153 config_dir = os .path .join (ROOT_DIR , "tests" , "test_configs" , "model" )
154154 with initialize_config_dir (config_dir = config_dir , version_base = None ):
155- model_cfg = compose (config_name = model_name , overrides = [])
155+ model_cfg = compose (config_name = model_type , overrides = [])
156156 return model_cfg
157157
158158
@@ -227,7 +227,7 @@ def dummy_trainer_diffusion_1d(
227227 merged_cfg = OmegaConf .merge (merged_cfg , {"trainer" : trainer_cfg_diffusion })
228228 merged_cfg .run_dir = str (tmp_path )
229229 trainer = Trainer (
230- model_name = "diffusion_ts" , dataset = normalized_dataset_1d , cfg = merged_cfg
230+ model_type = "diffusion_ts" , dataset = normalized_dataset_1d , cfg = merged_cfg
231231 )
232232 return trainer
233233
@@ -240,7 +240,7 @@ def dummy_trainer_acgan_1d(
240240 merged_cfg = OmegaConf .merge (full_cfg_1d , {"model" : model_cfg_acgan })
241241 merged_cfg = OmegaConf .merge (merged_cfg , {"trainer" : trainer_cfg_acgan })
242242 merged_cfg .run_dir = str (tmp_path )
243- trainer = Trainer (model_name = "acgan" , dataset = normalized_dataset_1d , cfg = merged_cfg )
243+ trainer = Trainer (model_type = "acgan" , dataset = normalized_dataset_1d , cfg = merged_cfg )
244244 return trainer
245245
246246
@@ -257,7 +257,7 @@ def dummy_trainer_diffusion_2d(
257257 merged_cfg = OmegaConf .merge (merged_cfg , {"trainer" : trainer_cfg_diffusion })
258258 merged_cfg .run_dir = str (tmp_path )
259259 trainer = Trainer (
260- model_name = "diffusion_ts" , dataset = normalized_dataset_2d , cfg = merged_cfg
260+ model_type = "diffusion_ts" , dataset = normalized_dataset_2d , cfg = merged_cfg
261261 )
262262 return trainer
263263
@@ -270,5 +270,5 @@ def dummy_trainer_acgan_2d(
270270 merged_cfg = OmegaConf .merge (full_cfg_2d , {"model" : model_cfg_acgan })
271271 merged_cfg = OmegaConf .merge (merged_cfg , {"trainer" : trainer_cfg_acgan })
272272 merged_cfg .run_dir = str (tmp_path )
273- trainer = Trainer (model_name = "acgan" , dataset = normalized_dataset_2d , cfg = merged_cfg )
273+ trainer = Trainer (model_type = "acgan" , dataset = normalized_dataset_2d , cfg = merged_cfg )
274274 return trainer
0 commit comments