Closed as not planned
Labels: bug, needs triage, ver: 2.5.x
Description
Bug description
When I trained my VAE with Lightning, it kept eating up GPU memory until it crashed, but when I used a simple plain-PyTorch training script, everything worked fine. Here is the Lightning training script:
import time
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import lightning as L
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from dataset import MultiVariateDataset
from models.IceVAE import IceVAE
from configs import ICE_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES
from utils.metrics import *
L.seed_everything(42)
parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    type=str,
    default="25km_525",
    choices=list(ICE_CONFIGS.keys()),
    help="The task to run.",
)
parser.add_argument(
    "--size",
    type=str,
    default="448*304",
    choices=list(SIZE_CONFIGS.keys()),
    help="The area (width*height) of the data.",
)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    default="/home/ubuntu/Oscar/IceDiffusion/checkpoints/vae",
    help="The path to the checkpoint directory.",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="Specify the GPU device IDs, e.g., '0,1,2' for using GPU 0, 1, 2 (default: '0')",
)
args = parser.parse_args()
config = ICE_CONFIGS[args.task]
gpu_ids = [int(gpu_id) for gpu_id in args.gpus.split(",")]
# Datasets and Dataloaders
train_dataset = MultiVariateDataset(
    config.full_data_path,
    config.input_length,
    config.pred_length,
    19790101,
    20231231,
    config.max_values_path,
    config.min_values_path,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)
class MyLightningModule(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = IceVAE(
            in_channels=config.num_channels,
            out_channels=config.num_channels,
            dim=config.dim,
            z_dim=config.z_dim,
            dim_mult=config.dim_mult,
            num_res_blocks=config.num_res_blocks,
            attn_scales=config.attn_scales,
            temperal_downsample=config.temperal_downsample,
            dropout=config.dropout,
        )
        self.save_hyperparameters(config)

    def forward(self, inputs):
        x_recon, mu, log_var = self.model(inputs)
        return x_recon, mu, log_var

    def _calculate_metrics(self, x_recon, inputs, mu, log_var):
        # L1 reconstruction loss
        l1_loss = F.l1_loss(x_recon, inputs)
        # KL divergence loss
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        kl_loss = kl_loss / inputs.numel()
        loss = l1_loss + 1e-5 * kl_loss
        metrics = {"loss": loss}
        return metrics

    def training_step(self, batch):
        inputs, targets = batch
        x_recon, mu, log_var = self.model(inputs)
        metrics = self._calculate_metrics(x_recon, inputs, mu, log_var)
        self.log_dict(metrics, prog_bar=True, logger=False, sync_dist=True)
        return metrics["loss"]

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.config.lr)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.config.lr,
            epochs=self.config.num_epochs,
            steps_per_epoch=len(train_dataloader),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # update the learning rate after every optimizer step
            },
        }
# Initialize model
model = MyLightningModule(config)
logger = CSVLogger(
    save_dir=config.log_path,
    name=f"{args.task}",
    version=f"{args.task}_{time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())}",
)
callbacks = [
    EarlyStopping(monitor="loss", patience=config.patience),
    ModelCheckpoint(
        monitor="loss",
        dirpath=f"{args.ckpt_dir}/{args.task}",
        filename=f"{args.task}_{time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())}",
    ),
]
trainer = Trainer(
    accelerator="cuda",
    strategy="ddp",
    devices=gpu_ids,
    precision="16-mixed",
    logger=logger,
    callbacks=callbacks,
    max_epochs=config.num_epochs,
)
# Train model
trainer.fit(model, train_dataloader)
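For reference, a minimal sketch of how the per-step CUDA memory could be tracked in the Lightning run (illustrative only, not something from the original report: CudaMemoryLogger is a made-up helper, and Lightning's built-in DeviceStatsMonitor callback would be an alternative):

# Hedged sketch: record allocated/reserved CUDA memory after every training
# batch to check whether usage keeps climbing across steps. Assumes each DDP
# process sees a CUDA device.
import torch
from lightning.pytorch.callbacks import Callback


class CudaMemoryLogger(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if torch.cuda.is_available():
            pl_module.log_dict(
                {
                    "mem_alloc_mib": torch.cuda.memory_allocated() / 2**20,
                    "mem_reserved_mib": torch.cuda.memory_reserved() / 2**20,
                },
                prog_bar=True,
                logger=False,
            )


# Usage: callbacks.append(CudaMemoryLogger()) before constructing the Trainer.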
Here is a simple training script:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from models.IceVAE import IceVAE
from dataset.MultiVariateDataset import MultiVariateDataset
from configs.IceDiff_25km_525 import IceDiff_25km_525 as config
# Configuration parameters
batch_size = 8
num_epochs = 50
learning_rate = 1e-4
# Datasets and Dataloaders
train_dataset = MultiVariateDataset(
    config.full_data_path,
    config.input_length,
    config.pred_length,
    19790101,
    20231231,
    config.max_values_path,
    config.min_values_path,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
# Initialize the model
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = IceVAE(
    in_channels=6,
    out_channels=6,
    dim=16,
    z_dim=16,
    dim_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_scales=[],
    temperal_downsample=[False, True, True],
    dropout=0.0,
).to(device)
# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
reconstruction_loss_fn = nn.MSELoss()
def compute_vae_loss(recon_x, x, mu, log_var):
    # Reconstruction loss
    recon_loss = reconstruction_loss_fn(recon_x, x)
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch, targets in progress_bar:
        batch = batch.to(device)  # Shape: [B, C, T, H, W]
        # Forward pass
        recon_batch, mu, log_var = model(batch)
        # Compute loss
        loss = compute_vae_loss(recon_batch, batch, mu, log_var)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the progress bar
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")
    # Save the model
    torch.save(model.state_dict(), f"icevae_epoch_{epoch+1}.pth")
print("Training complete!")
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
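If helpful, a quick way to fill in the fields above (a suggested snippet, not something from the original report) is to print the versions directly:

# Hedged helper for the environment template above: prints the versions and
# devices the template asks for.
import platform
import lightning
import torch

print("Lightning:", lightning.__version__)
print("PyTorch:", torch.__version__)
print("Python:", platform.python_version())
print("OS:", platform.platform())
print("CUDA:", torch.version.cuda)
print("cuDNN:", torch.backends.cudnn.version())
print("GPUs:", [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])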
More info
No response