
Continuous memory growth causes the program to crash with Lightning, but a plain PyTorch script works fine #20824

@jialiangZ


Bug description

When I train my VAE with Lightning, memory usage keeps growing until the program crashes, but when I use a simple plain-PyTorch training script, everything works fine. Here is my Lightning training script:

import time
import argparse
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

import lightning as L
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from dataset import MultiVariateDataset
from models.IceVAE import IceVAE
from configs import ICE_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES
from utils.metrics import *


L.seed_everything(42)

parser = argparse.ArgumentParser()
parser.add_argument(
    "--task",
    type=str,
    default="25km_525",
    choices=list(ICE_CONFIGS.keys()),
    help="The task to run.",
)
parser.add_argument(
    "--size",
    type=str,
    default="448*304",
    choices=list(SIZE_CONFIGS.keys()),
    help="The area (width*height) of the data.",
)
parser.add_argument(
    "--ckpt_dir",
    type=str,
    default="/home/ubuntu/Oscar/IceDiffusion/checkpoints/vae",
    help="The path to the checkpoint directory.",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="Specify the GPU device IDs, e.g., '0,1,2' for using GPU 0, 1, 2 (default: '0')",
)

args = parser.parse_args()

config = ICE_CONFIGS[args.task]
gpu_ids = [int(gpu_id) for gpu_id in args.gpus.split(",")]

# Datasets and Dataloaders
train_dataset = MultiVariateDataset(
    config.full_data_path,
    config.input_length,
    config.pred_length,
    19790101,
    20231231,
    config.max_values_path,
    config.min_values_path,
)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.num_workers,
)


class MyLightningModule(L.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.model = IceVAE(
            in_channels=config.num_channels,
            out_channels=config.num_channels,
            dim=config.dim,
            z_dim=config.z_dim,
            dim_mult=config.dim_mult,
            num_res_blocks=config.num_res_blocks,
            attn_scales=config.attn_scales,
            temperal_downsample=config.temperal_downsample,
            dropout=config.dropout,
        )
        self.save_hyperparameters(config)

    def forward(self, inputs):
        x_recon, mu, log_var = self.model(inputs)
        return x_recon, mu, log_var

    def _calculate_metrics(self, x_recon, inputs, mu, log_var):
        # L1 reconstruction loss
        l1_loss = F.l1_loss(x_recon, inputs)

        # KL divergence loss
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        kl_loss = kl_loss / inputs.numel()

        loss = l1_loss + 1e-5 * kl_loss

        metrics = {"loss": loss}
        return metrics

    def training_step(self, batch, batch_idx):
        inputs, targets = batch
        x_recon, mu, log_var = self.model(inputs)
        metrics = self._calculate_metrics(x_recon, inputs, mu, log_var)
        self.log_dict(metrics, prog_bar=True, logger=False, sync_dist=True)
        return metrics["loss"]

    def configure_optimizers(self):
        optimizer = AdamW(self.model.parameters(), lr=self.config.lr)
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.config.lr,
            epochs=self.config.num_epochs,
            steps_per_epoch=len(train_dataloader),
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",  # Update the learning rate after every optimizer step
            },
        }


# Initialize model
model = MyLightningModule(config)

logger = CSVLogger(
    save_dir=config.log_path,
    name=f"{args.task}",
    version=f"{args.task}_{time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())}",
)

callbacks = [
    EarlyStopping(monitor="loss", patience=config.patience),
    ModelCheckpoint(
        monitor="loss",
        dirpath=f"{args.ckpt_dir}/{args.task}",
        filename=f"{args.task}_{time.strftime('%Y_%m_%d_%H_%M_%S', time.localtime())}",
    ),
]

trainer = Trainer(
    accelerator="cuda",
    strategy="ddp",
    devices=gpu_ids,
    precision="16-mixed",
    logger=logger,
    callbacks=callbacks,
    max_epochs=config.num_epochs,
)

# Train model
trainer.fit(model, train_dataloader)
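
To make the growth visible per step, a memory-logging callback can be attached to the Trainer. The following is a minimal sketch, assuming psutil is installed; MemoryMonitor is a hypothetical helper written for this report, not a Lightning built-in.

from lightning.pytorch.callbacks import Callback
import psutil
import torch


class MemoryMonitor(Callback):
    """Print host RSS and CUDA allocated memory every N training steps."""

    def __init__(self, log_every_n_steps=50):
        self.log_every_n_steps = log_every_n_steps

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step % self.log_every_n_steps != 0:
            return
        # Resident set size of the training process, in GiB
        rss_gb = psutil.Process().memory_info().rss / 1024**3
        msg = f"step={trainer.global_step} host_rss={rss_gb:.2f}GiB"
        if torch.cuda.is_available():
            msg += f" cuda_alloc={torch.cuda.memory_allocated() / 1024**3:.2f}GiB"
        print(msg)

Adding MemoryMonitor() to the callbacks list above shows whether it is host RAM or GPU memory that grows with global_step.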

Here is the simple training script that works fine:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.IceVAE import IceVAE
from dataset.MultiVariateDataset import MultiVariateDataset
from configs.IceDiff_25km_525 import IceDiff_25km_525 as config


# Hyperparameters
batch_size = 8
num_epochs = 50
learning_rate = 1e-4


# Datasets and Dataloaders
train_dataset = MultiVariateDataset(
    config.full_data_path,
    config.input_length,
    config.pred_length,
    19790101,
    20231231,
    config.max_values_path,
    config.min_values_path,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)

# Initialize the model
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = IceVAE(
    in_channels=6,
    out_channels=6,
    dim=16,
    z_dim=16,
    dim_mult=[1, 2, 4, 4],
    num_res_blocks=2,
    attn_scales=[],
    temperal_downsample=[False, True, True],
    dropout=0.0,
).to(device)

# Optimizer and reconstruction loss function
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
reconstruction_loss_fn = nn.MSELoss()


def compute_vae_loss(recon_x, x, mu, log_var):
    # Reconstruction loss
    recon_loss = reconstruction_loss_fn(recon_x, x)
    # KL divergence loss
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    return recon_loss + kl_loss


# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch, targets in progress_bar:
        batch = batch.to(device)  # Shape: [B, C, T, H, W]

        # Forward pass
        recon_batch, mu, log_var = model(batch)

        # Compute loss
        loss = compute_vae_loss(recon_batch, batch, mu, log_var)

        # Backward pass and optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update running loss and progress bar
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    # Save model checkpoint
    torch.save(model.state_dict(), f"icevae_epoch_{epoch+1}.pth")

print("训练完成!")

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

Run the Lightning training script above; memory usage keeps growing during training until the process crashes. The plain-PyTorch script trains the same model without this issue.
Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):

More info

No response

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x