1111import numpy as np
1212import pandas as pd
1313import torch
14- import wandb
1514from omegaconf import DictConfig , OmegaConf
1615
16+ import wandb
1717from cents .eval .discriminative_score import discriminative_score_metrics
1818from cents .eval .eval_metrics import (
1919 Context_FID ,
2020 calculate_mmd ,
21+ compute_mig ,
22+ compute_sap ,
2123 dynamic_time_warping_dist ,
2224)
25+ from cents .eval .eval_utils import flatten_log_dict
2326from cents .eval .predictive_score import predictive_score_metrics
2427from cents .models .acgan import ACGAN
2528from cents .models .diffusion_ts import Diffusion_TS
@@ -33,8 +36,7 @@ class Evaluator:
3336 A class for evaluating generative models on time series data.
3437
3538 This class handles the evaluation process, including metric computation,
36- visualization generation, and results storage. It can evaluate models on
37- either the entire dataset or specific users.
39+ visualization generation, and results storage.
3840
3941 Attributes:
4042 cfg (DictConfig): Configuration for the evaluation process
@@ -85,33 +87,26 @@ def __init__(
8587
8688 def evaluate_model (
8789 self ,
88- user_id : Optional [int ] = None ,
8990 model : Optional [Any ] = None ,
9091 ) -> Dict :
9192 """
9293 Evaluate the model and store results.
9394
9495 Args:
95- user_id (Optional[int]): The ID of the user to evaluate. If None, evaluate on the entire dataset.
9696 model (Optional[Any]): The model to evaluate. If None, will load or train a model.
9797
9898 Returns:
9999 Dict: Dictionary containing the evaluation results
100100 """
101- if user_id is not None :
102- dataset = self .real_dataset .create_user_dataset (user_id )
103- else :
104- dataset = self .real_dataset
101+ dataset = self .real_dataset
105102
106103 if not model :
107104 model = self .get_trained_model (dataset )
108105
109106 model .to (self .device )
107+ model .eval ()
110108
111- if user_id is not None :
112- logger .info (f"[Cents] Starting evaluation for user { user_id } " )
113- else :
114- logger .info ("[Cents] Starting evaluation for all users" )
109+ logger .info ("[Cents] Starting evaluation" )
115110 logger .info ("----------------------" )
116111
117112 self .run_evaluation (dataset , model )
@@ -120,7 +115,7 @@ def evaluate_model(
120115 self .save_results ()
121116
122117 if self .cfg .get ("wandb" , {}).get ("enabled" , False ) and wandb .run is not None :
123- wandb .log (self .current_results ["metrics" ])
118+ wandb .log (flatten_log_dict ( self .current_results ["metrics" ]) )
124119
125120 return self .current_results
126121
@@ -172,7 +167,7 @@ def load_results(self, timestamp: Optional[str] = None) -> Dict:
172167
173168 return {"metrics" : metrics , "metadata" : metadata }
174169
175- def compute_metrics (
170+ def compute_quality_metrics (
176171 self ,
177172 real_data : np .ndarray ,
178173 syn_data : np .ndarray ,
@@ -213,8 +208,6 @@ def compute_metrics(
213208 metrics ["Pred_Score" ] = pred_score
214209 logger .info (f"[Cents] Pred Score completed" )
215210
216- self .current_results ["metrics" ] = metrics
217-
218211 if mask is not None :
219212 logger .info ("[Cents] Starting Rare-Subset Metrics" )
220213 rare_metrics = {}
@@ -249,6 +242,42 @@ def compute_metrics(
249242 logger .info ("[Cents] Done computing Rare-Subset Metrics." )
250243 metrics ["rare_subset" ] = rare_metrics
251244
245+ self .current_results ["metrics" ] = metrics
246+
def compute_disentanglement_metrics(
    self,
    context_vars: Dict[str, torch.Tensor],
    model: Any,
) -> None:
    """
    Compute disentanglement metrics and store them in current_results.

    The context embedding is taken from the first element returned by
    ``model.context_module(context_vars)`` (evaluated with gradients
    disabled) and scored against the context variables with MIG and SAP.
    The scores are merged into
    ``self.current_results["metrics"]["disentanglement"]``; any metrics
    already stored under other keys are left untouched, and the
    ``"metrics"`` entry is created if it does not exist yet so this method
    can also be called before the quality metrics have been computed.

    Args:
        context_vars (Dict[str, torch.Tensor]): Dictionary of context variables
        model (Any): The model to evaluate. It must expose a callable
            ``context_module`` returning a 2-tuple whose first element is
            the embedding tensor.
    """
    logger.info("[Cents] --- Starting Disentanglement Metrics ---")

    with torch.no_grad():
        h, _ = model.context_module(context_vars)  # (N, D)

    # detach() is a no-op for tensors without grad, but keeps the NumPy
    # conversion from raising if a context tensor requires grad.
    emb_np = h.detach().cpu().numpy()
    ctx_np = {k: v.detach().cpu().numpy() for k, v in context_vars.items()}

    # Log each metric right after it finishes so the progress messages stay
    # accurate even if the following metric fails.
    mig, mig_detail = compute_mig(emb_np, ctx_np)
    logger.info("[Cents] MIG completed")

    sap, sap_detail = compute_sap(emb_np, ctx_np)
    logger.info("[Cents] SAP completed")

    # Create the containers on demand instead of assuming that
    # compute_quality_metrics has already populated "metrics".
    metrics = self.current_results.setdefault("metrics", {})
    metrics.setdefault("disentanglement", {}).update(
        {
            "MIG": {"mean": mig, **mig_detail},
            "SAP": {"mean": sap, **sap_detail},
        }
    )
280+
252281 def get_trained_model (self , dataset : Any ) -> Any :
253282 model_dict = {
254283 "acgan" : ACGAN ,
@@ -326,6 +355,9 @@ def evaluate_subset(
326355 ):
327356 rare_mask = real_data_subset ["is_rare" ].values
328357
329- self .compute_metrics (
358+ self .compute_quality_metrics (
330359 real_data_array , syn_data_array , real_data_inv , rare_mask
331360 )
361+
362+ if self .cfg .evaluator .eval_disentanglement :
363+ self .compute_disentanglement_metrics (context_vars , model )
0 commit comments