# train_vae.py
"""Training entry point for the VAE.

Instantiates the datamodule, model, callbacks, loggers, and trainer from a
Hydra config, runs training and/or testing, and saves the resulting metrics
to CSV.
"""
import os
import shutil
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Union

import dotenv
import hydra
import pandas as pd
import torch
from omegaconf import DictConfig
from pytorch_lightning import (
    Callback,
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers.logger import Logger as LightningLoggerBase

# Make the project root importable so that `core` resolves when the script
# is run directly.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core import utils
from core.callbacks import NeptuneModelLogger

# Surface full Hydra stack traces instead of the condensed summary.
os.environ["HYDRA_FULL_ERROR"] = "1"

# Load environment variables from a `.env` file if it exists; the search
# starts from the working directory and walks up through parent folders.
dotenv.load_dotenv(override=True)

# Trade a little matmul precision for throughput on recent GPUs.
torch.set_float32_matmul_precision("high")

log = utils.get_logger(__name__)


def save_metrics_to_csv(metrics_dict, save_dir, filename="metrics.csv"):
    """Write a flat metrics dict as a single-row CSV file under `save_dir`."""
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame([metrics_dict])
    csv_path = save_dir / filename
    df.to_csv(csv_path, index=False)
    log.info(f"Metrics saved to {csv_path}")


def train(config: DictConfig) -> Union[float, Dict[str, float]]:
    """Contains the training pipeline.

    Instantiates all PyTorch Lightning objects from the config, runs
    training and/or testing, and saves the resulting metrics to CSV.

    Args:
        config (DictConfig): Configuration composed by Hydra.

    Returns:
        Union[float, Dict[str, float]]: The metric named by
        `optimized_metric` (for hyperparameter optimization) if it is set
        in the config, otherwise the full metrics dict.
    """
    # Set seed for random number generators in pytorch, numpy and python.random
    if "seed" in config:
        log.info(f"Global seed: {config.seed}")
        seed_everything(config.seed, workers=True)

    # Init Lightning datamodule
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(
        config.datamodule, _convert_="partial"
    )
    # Init Lightning model
    log.info(f"Instantiating model <{config.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(config.model)

    # Init Lightning callbacks
    lr_monitor = LearningRateMonitor(logging_interval="step")
    callbacks: List[Callback] = [lr_monitor]
    if "callbacks" in config:
        for _, cb_conf in config["callbacks"].items():
            if "_target_" in cb_conf:
                log.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))
    # Init Lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config["logger"].items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

    # Init Lightning trainer
    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(
        config.trainer,
        callbacks=callbacks,
        logger=logger,
    )
    if config.mode == "train":
        # Send some parameters from config to all lightning loggers
        log.info("Logging hyperparameters!")
        utils.log_hyperparameters(
            config=config,
            model=model,
            trainer=trainer,
        )

        # Optional tuning hook (disabled):
        # if config.trainer.auto_lr_find or config.trainer.auto_scale_batch_size:
        #     log.info("Starting tuning!")
        #     trainer.tune(model=model, datamodule=datamodule)

        # Train the model
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule)

        # Test with the best checkpoint, falling back to the current weights
        # if no checkpoint was saved.
        log.info("Starting testing!")
        ckpt_path = trainer.checkpoint_callback.best_model_path
        if ckpt_path == "":
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

        # Keep a stable copy of the best checkpoint as `best.ckpt`.
        if ckpt_path:
            default_root_dir = Path(trainer.default_root_dir)
            best_ckpt_path = default_root_dir / "best.ckpt"
            shutil.copyfile(ckpt_path, best_ckpt_path)
            log.info(f"Copied best model to: {best_ckpt_path}")
        else:
            log.warning("No best model found to save as 'best.ckpt'")
    elif config.mode == "test":
        # Testing mode: evaluate a pretrained checkpoint
        pretrained_ckpt_path = config.get("pretrained_ckpt_path")
        assert (
            pretrained_ckpt_path is not None
        ), "Pretrained ckpt path must be specified in test mode!"
        log.info(f"Loading pretrained model from checkpoint: {pretrained_ckpt_path}")
        model_class = type(model)  # Class of the model instantiated from config
        model = model_class.load_from_checkpoint(pretrained_ckpt_path, strict=False)

        # Run testing (trainer.test already disables gradients, so the
        # explicit no_grad context is just belt-and-braces)
        log.info("Starting testing!")
        with torch.no_grad():
            trainer.test(model=model, datamodule=datamodule)
    else:
        raise ValueError(f"Unknown mode '{config.mode}'. Use 'train' or 'test'.")

    # Collect the metrics accumulated by the trainer and save them to CSV,
    # both under a stable name and under a timestamped one.
    test_metrics = trainer.callback_metrics
    log.info(f"Test metrics: {test_metrics}")
    metrics_dict = {
        k: v.item() if hasattr(v, "item") else v for k, v in test_metrics.items()
    }
    save_dir = Path(config.paths.output_dir) / "metrics"
    save_metrics_to_csv(metrics_dict, save_dir)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_metrics_to_csv(metrics_dict, save_dir, f"metrics_{timestamp}.csv")

    log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")

    # Return the metric used for hyperparameter optimization if one is
    # configured, otherwise the full metrics dict.
    optimized_metric = config.get("optimized_metric")
    if optimized_metric:
        return metrics_dict[optimized_metric]
    return metrics_dict


@hydra.main(version_base="1.3", config_path="./VAE/config/", config_name="VAE.yaml")
def main(config: DictConfig):
    # Apply optional config utilities, then optionally pretty-print the
    # resolved config before handing off to the training pipeline.
    utils.extras(config)
    if config.get("print_config"):
        utils.print_config(config, resolve=True)
    return train(config)


if __name__ == "__main__":
    main()
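
# Example invocations (a sketch, assuming the config layout described near
# the top of `train`; the override names mirror keys read in this file):
#
#   python train_vae.py mode=train seed=42
#   python train_vae.py mode=test pretrained_ckpt_path=/path/to/best.ckpt
#
# Hydra composes `./VAE/config/VAE.yaml`, applies any command-line overrides,
# and then calls `train(config)`.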