Skip to content

Commit e7317ac

Browse files
author
Michael Fuest
committed
updates tests
1 parent 0ac80c3 commit e7317ac

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

cents/data_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ def __init__(
6363
self.set_dataset_spec(
6464
cfg.dataset, self._read_ctx_codes(cfg.dataset.name)
6565
)
66+
self.model_type = cfg.model.name
6667
elif model_name is not None:
6768
self.model_type = get_model_type_from_hf_name(model_name)
6869
self.cfg = cfg or self._default_cfg()

tests/evaluator/test_evaluator.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ def cfg(tmp_path):
5959
"eval_vis": True,
6060
"eval_pv_shift": False,
6161
"eval_context_sparse": False,
62+
"eval_disentanglement": False,
6263
},
6364
}
6465
)
@@ -81,7 +82,7 @@ def test_compute_metrics_sets_all_keys(evaluator):
8182
"context": [0, 1, 0],
8283
}
8384
)
84-
evaluator.compute_metrics(real, syn, df)
85+
evaluator.compute_quality_metrics(real, syn, df)
8586
for key in ("DTW", "MMD", "Context_FID", "Disc_Score", "Pred_Score"):
8687
assert key in evaluator.current_results["metrics"], f"{key} missing"
8788

@@ -103,6 +104,9 @@ class DummyModel:
103104
def to(self, device):
104105
return self
105106

107+
def eval(self):
108+
return self
109+
106110
def generate(self, ctx):
107111
batch_size = next(iter(ctx.values())).shape[0]
108112
return torch.zeros((batch_size, 1, 1))

tests/model/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def test_trainer_fit_diffusion_1d(dummy_trainer_diffusion_1d):
1212
def test_trainer_get_data_generator_diffusion_1d(dummy_trainer_diffusion_1d):
1313
trainer = dummy_trainer_diffusion_1d
1414
dg = trainer.get_data_generator()
15-
assert dg.model_name == "diffusion_ts"
15+
assert dg.model_type == "diffusion_ts"
1616
assert dg.model is not None
1717

1818

@@ -25,7 +25,7 @@ def test_trainer_fit_acgan_1d(dummy_trainer_acgan_1d):
2525
def test_trainer_get_data_generator_acgan_1d(dummy_trainer_acgan_1d):
2626
trainer = dummy_trainer_acgan_1d
2727
dg = trainer.get_data_generator()
28-
assert dg.model_name == "acgan"
28+
assert dg.model_type == "acgan"
2929
assert dg.model is not None
3030

3131

@@ -38,7 +38,7 @@ def test_trainer_fit_diffusion_2d(dummy_trainer_diffusion_2d):
3838
def test_trainer_get_data_generator_diffusion_2d(dummy_trainer_diffusion_2d):
3939
trainer = dummy_trainer_diffusion_2d
4040
dg = trainer.get_data_generator()
41-
assert dg.model_name == "diffusion_ts"
41+
assert dg.model_type == "diffusion_ts"
4242
assert dg.model is not None
4343

4444

@@ -51,5 +51,5 @@ def test_trainer_fit_acgan_2d(dummy_trainer_acgan_2d):
5151
def test_trainer_get_data_generator_acgan_2d(dummy_trainer_acgan_2d):
5252
trainer = dummy_trainer_acgan_2d
5353
dg = trainer.get_data_generator()
54-
assert dg.model_name == "acgan"
54+
assert dg.model_type == "acgan"
5555
assert dg.model is not None
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
model_name: ${model.name}
22
eval_pv_shift: False
3+
eval_disentanglement: True
34
eval_metrics: True
45
eval_vis: True
56
eval_context_sparse: False

0 commit comments

Comments
 (0)