22from typing import Dict , List , Optional
33
44import pytorch_lightning as pl
5+ import wandb
56from hydra import compose , initialize_config_dir
67from omegaconf import DictConfig , OmegaConf
78from pytorch_lightning .callbacks import Callback , ModelCheckpoint
89from pytorch_lightning .loggers import WandbLogger
910
10- import wandb
1111from cents .data_generator import DataGenerator
1212from cents .datasets .timeseries_dataset import TimeSeriesDataset
1313from cents .eval .eval import Evaluator
@@ -34,7 +34,7 @@ class Trainer:
3434
3535 def __init__ (
3636 self ,
37- model_name : str ,
37+ model_type : str ,
3838 dataset : Optional [TimeSeriesDataset ] = None ,
3939 cfg : Optional [DictConfig ] = None ,
4040 overrides : Optional [List [str ]] = None ,
@@ -43,26 +43,26 @@ def __init__(
4343 Initialize the Trainer.
4444
4545 Args:
46- model_name : Key of the model ("acgan", "diffusion_ts", or "normalizer").
46+ model_type : Key of the model ("acgan", "diffusion_ts", or "normalizer").
4747 dataset: Dataset object required for generative models; optional for normalizer.
4848 cfg: Full OmegaConf DictConfig; if None, composed via Hydra.
4949 overrides: List of Hydra override strings.
5050
5151 Raises:
52- ValueError: If model_name is unknown or dataset requirements are not met.
52+ ValueError: If model_type is unknown or dataset requirements are not met.
5353 """
5454 try :
55- get_model_cls (model_name )
55+ get_model_cls (model_type )
5656 except ValueError :
57- raise ValueError (f"Unknown model '{ model_name } '" )
57+ raise ValueError (f"Unknown model '{ model_type } '" )
5858
59- if model_name != "normalizer" and dataset is None :
60- raise ValueError (f"Model '{ model_name } ' requires a TimeSeriesDataset." )
59+ if model_type != "normalizer" and dataset is None :
60+ raise ValueError (f"Model '{ model_type } ' requires a TimeSeriesDataset." )
6161
62- if model_name == "normalizer" and dataset is None :
62+ if model_type == "normalizer" and dataset is None :
6363 raise ValueError ("Normalizer training needs the raw dataset object." )
6464
65- self .model_key = model_name
65+ self .model_key = model_type
6666 self .dataset = dataset
6767 self .cfg = cfg or self ._compose_cfg (overrides or [])
6868
0 commit comments