Skip to content

Commit 24d39f5

Browse files
author
Michael Fuest
committed
more refactoring and training updates
1 parent bcefa87 commit 24d39f5

File tree

17 files changed

+295
-32
lines changed

17 files changed

+295
-32
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,4 @@ tutorials/outputs
123123
checkpoints
124124
bfg.jar
125125
lightning_logs/*
126+
wandb/*

cents/config/dataset/pecanstreet.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ time_series_columns: ["grid", "solar"]
1212
data_columns: ["dataid","local_15min","car1","grid","solar"]
1313
metadata_columns: ["dataid","building_type","solar","car1","city","state","total_square_footage","house_construction_year"]
1414
user_group: all # non_pv_users, all, pv_users
15-
user_id: null
1615
numeric_context_bins: 5
1716

1817
context_vars: # for each desired context variable, add the name and number of categories
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
model_name: ${model.name}
22
eval_pv_shift: False
33
eval_metrics: True
4-
eval_vis: True
5-
eval_context_sparse: False
4+
eval_context_sparse: True
65
save_results: False
6+
eval_disentanglement: True
77
save_dir: ${run_dir}/eval

cents/config/model/diffusion_ts.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@ reg_weight: null
2424
gradient_accumulate_every: 2
2525
ema_decay: 0.99
2626
ema_update_interval: 10
27-
use_ema_sampling: True
27+
use_ema_sampling: False

cents/config/trainer/acgan.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ batch_size: 1024
77
sampling_batch_size: 4096
88
gradient_accumulate_every: 1
99
log_every_n_steps: 1
10+
eval_after_training: False
1011

1112
checkpoint:
1213
save_last: False

cents/config/trainer/diffusion_ts.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ log_every_n_steps: 1
77
batch_size: 1024
88
max_epochs: 5000
99
base_lr: 1e-4
10+
eval_after_training: False
1011

1112
checkpoint:
1213
save_last: False

cents/config/trainer/normalizer.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ n_epochs: 2000
88
batch_size: 4096
99
lr: 3e-4
1010
save_cycle: 5000
11+
eval_after_training: False
1112

1213
checkpoint:
1314
save_last: False

cents/data_generator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,9 @@ def load_from_checkpoint(
178178
if ckpt_path.suffix == ".ckpt":
179179
self.model = (
180180
ModelCls.load_from_checkpoint(
181-
checkpoint_path=ckpt_path, map_location=device
181+
checkpoint_path=ckpt_path,
182+
map_location=device,
183+
strict=False,
182184
)
183185
.to(device)
184186
.eval()

cents/datasets/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def check_inverse_transform(
6666
mse_list.append(mse)
6767

6868
avg_mse = np.mean(mse_list)
69-
print(f"Average MSE over all rows: {avg_mse}")
69+
print(f"[Cents] Average MSE over all rows: {avg_mse}")
7070
return avg_mse
7171

7272

cents/eval/eval.py

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,18 @@
1111
import numpy as np
1212
import pandas as pd
1313
import torch
14-
import wandb
1514
from omegaconf import DictConfig, OmegaConf
1615

16+
import wandb
1717
from cents.eval.discriminative_score import discriminative_score_metrics
1818
from cents.eval.eval_metrics import (
1919
Context_FID,
2020
calculate_mmd,
21+
compute_mig,
22+
compute_sap,
2123
dynamic_time_warping_dist,
2224
)
25+
from cents.eval.eval_utils import flatten_log_dict
2326
from cents.eval.predictive_score import predictive_score_metrics
2427
from cents.models.acgan import ACGAN
2528
from cents.models.diffusion_ts import Diffusion_TS
@@ -33,8 +36,7 @@ class Evaluator:
3336
A class for evaluating generative models on time series data.
3437
3538
This class handles the evaluation process, including metric computation,
36-
visualization generation, and results storage. It can evaluate models on
37-
either the entire dataset or specific users.
39+
visualization generation, and results storage.
3840
3941
Attributes:
4042
cfg (DictConfig): Configuration for the evaluation process
@@ -85,33 +87,26 @@ def __init__(
8587

8688
def evaluate_model(
8789
self,
88-
user_id: Optional[int] = None,
8990
model: Optional[Any] = None,
9091
) -> Dict:
9192
"""
9293
Evaluate the model and store results.
9394
9495
Args:
95-
user_id (Optional[int]): The ID of the user to evaluate. If None, evaluate on the entire dataset.
9696
model (Optional[Any]): The model to evaluate. If None, will load or train a model.
9797
9898
Returns:
9999
Dict: Dictionary containing the evaluation results
100100
"""
101-
if user_id is not None:
102-
dataset = self.real_dataset.create_user_dataset(user_id)
103-
else:
104-
dataset = self.real_dataset
101+
dataset = self.real_dataset
105102

106103
if not model:
107104
model = self.get_trained_model(dataset)
108105

109106
model.to(self.device)
107+
model.eval()
110108

111-
if user_id is not None:
112-
logger.info(f"[Cents] Starting evaluation for user {user_id}")
113-
else:
114-
logger.info("[Cents] Starting evaluation for all users")
109+
logger.info("[Cents] Starting evaluation")
115110
logger.info("----------------------")
116111

117112
self.run_evaluation(dataset, model)
@@ -120,7 +115,7 @@ def evaluate_model(
120115
self.save_results()
121116

122117
if self.cfg.get("wandb", {}).get("enabled", False) and wandb.run is not None:
123-
wandb.log(self.current_results["metrics"])
118+
wandb.log(flatten_log_dict(self.current_results["metrics"]))
124119

125120
return self.current_results
126121

@@ -172,7 +167,7 @@ def load_results(self, timestamp: Optional[str] = None) -> Dict:
172167

173168
return {"metrics": metrics, "metadata": metadata}
174169

175-
def compute_metrics(
170+
def compute_quality_metrics(
176171
self,
177172
real_data: np.ndarray,
178173
syn_data: np.ndarray,
@@ -213,8 +208,6 @@ def compute_metrics(
213208
metrics["Pred_Score"] = pred_score
214209
logger.info(f"[Cents] Pred Score completed")
215210

216-
self.current_results["metrics"] = metrics
217-
218211
if mask is not None:
219212
logger.info("[Cents] Starting Rare-Subset Metrics")
220213
rare_metrics = {}
@@ -249,6 +242,42 @@ def compute_metrics(
249242
logger.info("[Cents] Done computing Rare-Subset Metrics.")
250243
metrics["rare_subset"] = rare_metrics
251244

245+
self.current_results["metrics"] = metrics
246+
247+
def compute_disentanglement_metrics(
248+
self,
249+
context_vars: Dict[str, torch.Tensor],
250+
model: Any,
251+
) -> None:
252+
"""
253+
Compute disentanglement metrics and store them in current_results.
254+
255+
Args:
256+
context_vars (Dict[str, torch.Tensor]): Dictionary of context variables
257+
model (Any): The model to evaluate
258+
"""
259+
logger.info("[Cents] --- Starting Disentanglement Metrics ---")
260+
261+
with torch.no_grad():
262+
h, _ = model.context_module(context_vars) # (N, D)
263+
264+
emb_np = h.cpu().numpy()
265+
ctx_np = {k: v.cpu().numpy() for k, v in context_vars.items()}
266+
267+
mig, mig_detail = compute_mig(emb_np, ctx_np)
268+
sap, sap_detail = compute_sap(emb_np, ctx_np)
269+
270+
self.current_results["metrics"].setdefault("disentanglement", {})
271+
self.current_results["metrics"]["disentanglement"].update(
272+
{
273+
"MIG": {"mean": mig, **mig_detail},
274+
"SAP": {"mean": sap, **sap_detail},
275+
}
276+
)
277+
278+
logger.info("[Cents] MIG completed")
279+
logger.info("[Cents] SAP completed")
280+
252281
def get_trained_model(self, dataset: Any) -> Any:
253282
model_dict = {
254283
"acgan": ACGAN,
@@ -326,6 +355,9 @@ def evaluate_subset(
326355
):
327356
rare_mask = real_data_subset["is_rare"].values
328357

329-
self.compute_metrics(
358+
self.compute_quality_metrics(
330359
real_data_array, syn_data_array, real_data_inv, rare_mask
331360
)
361+
362+
if self.cfg.evaluator.eval_disentanglement:
363+
self.compute_disentanglement_metrics(context_vars, model)

0 commit comments

Comments (0)