Skip to content

Commit f1e90ac

Browse files
author
Michael Fuest
committed
updated dependencies and small fix
1 parent e7317ac commit f1e90ac

File tree

3 files changed

+11
-171
lines changed

3 files changed

+11
-171
lines changed

cents/trainer.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
from typing import Dict, List, Optional
33

44
import pytorch_lightning as pl
5+
import wandb
56
from hydra import compose, initialize_config_dir
67
from omegaconf import DictConfig, OmegaConf
78
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
89
from pytorch_lightning.loggers import WandbLogger
910

10-
import wandb
1111
from cents.data_generator import DataGenerator
1212
from cents.datasets.timeseries_dataset import TimeSeriesDataset
1313
from cents.eval.eval import Evaluator
@@ -34,7 +34,7 @@ class Trainer:
3434

3535
def __init__(
3636
self,
37-
model_name: str,
37+
model_type: str,
3838
dataset: Optional[TimeSeriesDataset] = None,
3939
cfg: Optional[DictConfig] = None,
4040
overrides: Optional[List[str]] = None,
@@ -43,26 +43,26 @@ def __init__(
4343
Initialize the Trainer.
4444
4545
Args:
46-
model_name: Key of the model ("acgan", "diffusion_ts", or "normalizer").
46+
model_type: Key of the model ("acgan", "diffusion_ts", or "normalizer").
4747
dataset: Dataset object required for generative models; optional for normalizer.
4848
cfg: Full OmegaConf DictConfig; if None, composed via Hydra.
4949
overrides: List of Hydra override strings.
5050
5151
Raises:
52-
ValueError: If model_name is unknown or dataset requirements are not met.
52+
ValueError: If model_type is unknown or dataset requirements are not met.
5353
"""
5454
try:
55-
get_model_cls(model_name)
55+
get_model_cls(model_type)
5656
except ValueError:
57-
raise ValueError(f"Unknown model '{model_name}'")
57+
raise ValueError(f"Unknown model '{model_type}'")
5858

59-
if model_name != "normalizer" and dataset is None:
60-
raise ValueError(f"Model '{model_name}' requires a TimeSeriesDataset.")
59+
if model_type != "normalizer" and dataset is None:
60+
raise ValueError(f"Model '{model_type}' requires a TimeSeriesDataset.")
6161

62-
if model_name == "normalizer" and dataset is None:
62+
if model_type == "normalizer" and dataset is None:
6363
raise ValueError("Normalizer training needs the raw dataset object.")
6464

65-
self.model_key = model_name
65+
self.model_key = model_type
6666
self.dataset = dataset
6767
self.cfg = cfg or self._compose_cfg(overrides or [])
6868

poetry.lock

Lines changed: 1 addition & 159 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ torch = "2.6.0"
2626
torchaudio = "2.6.0"
2727
torchvision = "^0.21.0"
2828
numpy = "^2.0.0"
29-
openai = "^1.57.3"
3029
pandas = "^2.2.3"
3130
matplotlib = "^3.9.4"
3231
scikit-learn = "^1.6.0"
@@ -49,7 +48,6 @@ hydra-core = "^1.3.2"
4948
pytorch-lightning = "^2.4.0"
5049
wandb = "^0.19.6"
5150
pytest-cov = "^6.0.0"
52-
boto3 = "^1.36.24"
5351
botocore = "^1.36.24"
5452
ipykernel = "^6.29.5"
5553
jupyter = "^1.1.1"

0 commit comments

Comments
 (0)