Skip to content

Commit 4467d7c

Browse files
author
Michael Fuest
committed
updated tests
1 parent f1e90ac commit 4467d7c

File tree

2 files changed

+20
-19
lines changed

2 files changed

+20
-19
lines changed

poetry.lock

Lines changed: 14 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/conftest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -149,10 +149,10 @@ def load_dataset_config(case: str) -> DictConfig:
149149
return ds_cfg
150150

151151

152-
def load_model_config(model_name: str) -> DictConfig:
152+
def load_model_config(model_type: str) -> DictConfig:
153153
config_dir = os.path.join(ROOT_DIR, "tests", "test_configs", "model")
154154
with initialize_config_dir(config_dir=config_dir, version_base=None):
155-
model_cfg = compose(config_name=model_name, overrides=[])
155+
model_cfg = compose(config_name=model_type, overrides=[])
156156
return model_cfg
157157

158158

@@ -227,7 +227,7 @@ def dummy_trainer_diffusion_1d(
227227
merged_cfg = OmegaConf.merge(merged_cfg, {"trainer": trainer_cfg_diffusion})
228228
merged_cfg.run_dir = str(tmp_path)
229229
trainer = Trainer(
230-
model_name="diffusion_ts", dataset=normalized_dataset_1d, cfg=merged_cfg
230+
model_type="diffusion_ts", dataset=normalized_dataset_1d, cfg=merged_cfg
231231
)
232232
return trainer
233233

@@ -240,7 +240,7 @@ def dummy_trainer_acgan_1d(
240240
merged_cfg = OmegaConf.merge(full_cfg_1d, {"model": model_cfg_acgan})
241241
merged_cfg = OmegaConf.merge(merged_cfg, {"trainer": trainer_cfg_acgan})
242242
merged_cfg.run_dir = str(tmp_path)
243-
trainer = Trainer(model_name="acgan", dataset=normalized_dataset_1d, cfg=merged_cfg)
243+
trainer = Trainer(model_type="acgan", dataset=normalized_dataset_1d, cfg=merged_cfg)
244244
return trainer
245245

246246

@@ -257,7 +257,7 @@ def dummy_trainer_diffusion_2d(
257257
merged_cfg = OmegaConf.merge(merged_cfg, {"trainer": trainer_cfg_diffusion})
258258
merged_cfg.run_dir = str(tmp_path)
259259
trainer = Trainer(
260-
model_name="diffusion_ts", dataset=normalized_dataset_2d, cfg=merged_cfg
260+
model_type="diffusion_ts", dataset=normalized_dataset_2d, cfg=merged_cfg
261261
)
262262
return trainer
263263

@@ -270,5 +270,5 @@ def dummy_trainer_acgan_2d(
270270
merged_cfg = OmegaConf.merge(full_cfg_2d, {"model": model_cfg_acgan})
271271
merged_cfg = OmegaConf.merge(merged_cfg, {"trainer": trainer_cfg_acgan})
272272
merged_cfg.run_dir = str(tmp_path)
273-
trainer = Trainer(model_name="acgan", dataset=normalized_dataset_2d, cfg=merged_cfg)
273+
trainer = Trainer(model_type="acgan", dataset=normalized_dataset_2d, cfg=merged_cfg)
274274
return trainer

0 commit comments

Comments
 (0)