Can't automatically resume a job: ckpt_path="hpc" throws ValueError from the start #20347

@F-Barto

Description

Summary

When setting up a PyTorch Lightning job on a SLURM cluster so that it resumes from where it left off after hitting the wall-time, the ckpt_path="hpc" option causes an error if no HPC checkpoint exists yet. This prevents the initial training run from starting.

Expected Behavior

  • The job should be able to resume from an HPC checkpoint, if one exists, when the following are used in combination:

    • SLURMEnvironment(auto_requeue=True, requeue_signal=signal.SIGUSR1)
    • trainer.fit(model, datamodule=dm, ckpt_path="hpc")
    • #SBATCH --signal=SIGUSR1@30
  • If no HPC checkpoint exists (e.g., on the first run), the job should start training from scratch without throwing an error. Currently it throws one:

`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`

Current Behavior

Using ckpt_path=None allows the job to start but doesn't resume from the HPC checkpoint when one is created.

If I use trainer.fit(model, datamodule=dm, ckpt_path=None), SIGUSR1 is correctly caught and the checkpoint hpc_ckpt_1.ckpt is correctly created. However, the checkpoint is not used on the requeued run, which is expected since we left ckpt_path=None.

requeing job 245038...
Requeued SLURM job: 245038
srun: Job step aborted: Waiting up to 62 seconds for job step to finish.
slurmstepd: error: *** JOB 245038 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***
slurmstepd: error: *** STEP 245038.0 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***

Using ckpt_path="hpc" throws an error if no HPC checkpoint is found, preventing the initial training run.

From what I understood, the logic of looking for and loading the HPC checkpoint should be handled by setting ckpt_path="hpc".
However, as can be seen in https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py#L193C1-L199C46
if an HPC checkpoint is not found, it throws an error and stops:

`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`
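For reference, the check in that file looks roughly like this (paraphrased from the linked lines, not verbatim; exact names and surrounding code may differ):

if ckpt_path == "hpc":
    if not self._hpc_resume_path:
        # No hpc_ckpt_*.ckpt was found, so the very first run fails here
        raise ValueError(
            '`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
            ' Please pass an exact checkpoint path to `.fit(ckpt_path=...)`'
        )
    ckpt_path = self._hpc_resume_path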

The issue is that, for the very first training run, there is of course no HPC checkpoint yet, because no training has happened.
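For now, the only workaround I can think of is to check for an HPC checkpoint myself and only pass ckpt_path="hpc" when one actually exists. A minimal sketch (the hpc_ckpt_*.ckpt glob and the assumption that the files land under the trainer's default_root_dir are mine, based on the hpc_ckpt_1.ckpt created in my runs):

import glob
import os

def pick_ckpt_path(log_dir: str):
    # On requeue, an HPC checkpoint named hpc_ckpt_<n>.ckpt is written
    # (I see hpc_ckpt_1.ckpt in my runs); assumed to sit under log_dir,
    # which is also used as default_root_dir below.
    hpc_ckpts = glob.glob(os.path.join(log_dir, "hpc_ckpt_*.ckpt"))
    # Resume from the HPC checkpoint if one exists, otherwise start from scratch.
    return "hpc" if hpc_ckpts else None

# usage in main():
# trainer.fit(model, datamodule=dm, ckpt_path=pick_ckpt_path(log_dir))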

Relevant issues

#16639

What version are you seeing the problem on?

v2.4

How to reproduce the bug

dummy_model.py

import os
import torch
import lightning as L
from torch.utils.data import Dataset
from lightning.pytorch.callbacks import ModelCheckpoint
import argparse
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data.distributed import DistributedSampler
import pickle
import signal
from lightning.pytorch.plugins.environments import SLURMEnvironment
import time

class DummyDataset(Dataset):
    def __init__(self, size=100000):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))
        self.current_index = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        print(f"Accessing index: {idx}, {self.data[idx]}")
        return self.data[idx], self.labels[idx]

class DPAwareDataLoader(StatefulDataLoader, Stateful):
    def __init__(self, dataset: Dataset, batch_size: int, sampler=None, **kwargs):
        super().__init__(dataset, batch_size=batch_size, sampler=sampler, **kwargs)
        self._rank_id = f"dp_rank_{sampler.rank if sampler else 0}"
        print(self._rank_id, " initialized")

    def state_dict(self):
        print(f"self._rank_id: ", f"{super().state_dict()}")
        return {self._rank_id: super().state_dict()}

    def load_state_dict(self, state_dict):
        if not state_dict:
            return
        if self._rank_id not in state_dict:
            print(f"DataLoader state is empty for dp rank {self._dp_rank}, expected key {self._rank_id}")
            return
        print(f"DataLoader state loading for dp rank {self._dp_rank}")
        super().load_state_dict(state_dict[self._rank_id])

class DummyDataModule(L.LightningDataModule, Stateful):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = None
        self.dataloader = None

    def setup(self, stage=None):
        self.train_dataset = DummyDataset()

        # DistributedSampler automatically retrieves world_size and rank
        # from the current distributed group.
        #
        # ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
        #
        # In PyTorch Lightning:
        # - The distributed environment is initialized by the Trainer.
        # - This sets up the process group with the correct world_size and rank.
        # - DistributedSampler then uses these values automatically.
        #
        # By not specifying num_replicas and rank, we allow DistributedSampler
        # to adapt to the current distributed setup, making our code more flexible.
        # This works seamlessly with PyTorch Lightning's managed distributed training.
        #
        # Note: This automatic retrieval only works correctly if the distributed
        # environment has been initialized, which Lightning ensures before calling setup().
        self.sampler = DistributedSampler(
            self.train_dataset,
            shuffle=False  # Ensure deterministic data order across processes for testing purposes
        )

    def train_dataloader(self):
        if self.dataloader is None:
            self.dataloader = DPAwareDataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.sampler,
                num_workers=2
            )
        return self.dataloader

    def state_dict(self):
        return {
            "dataloader_state": self.dataloader.state_dict() if self.dataloader else None,
        }

    def load_state_dict(self, state_dict):
        if self.dataloader and state_dict["dataloader_state"]:
            self.dataloader.load_state_dict(state_dict["dataloader_state"])

class DummyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)
        self.loss = torch.nn.CrossEntropyLoss()
        self.example_count = 0
        self.custom_global_step = 0

    def training_step(self, batch, batch_idx):
        time.sleep(10)
        x, y = batch
        y_hat = self.layer(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)

        self.example_count += len(x)
        self.custom_global_step = self.global_step

        if self.example_count % 100 == 0:
            print(f"GPU {self.global_rank}: Processed {self.example_count} examples, Global Step: {self.custom_global_step}")

        return loss

    def on_save_checkpoint(self, checkpoint):
        checkpoint['example_count'] = self.example_count
        checkpoint['custom_global_step'] = self.custom_global_step

    def on_load_checkpoint(self, checkpoint):
        self.example_count = checkpoint['example_count']
        self.custom_global_step = checkpoint['custom_global_step']

    def on_train_start(self):
        print(f"GPU {self.global_rank}: Starting/Resuming training. Example count: {self.example_count}, Global Step: {self.custom_global_step}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

def main():
    parser = argparse.ArgumentParser(description="Train a dummy model with a unique run name.")
    parser.add_argument("run_name", type=str, help="Unique name for this training run")
    args = parser.parse_args()
    print("="*30, args.run_name, "="*30)

    log_dir = os.path.join("logs", args.run_name)
    os.makedirs(log_dir, exist_ok=True)

    model = DummyModel()
    dm = DummyDataModule(batch_size=4)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, 'checkpoints'),
        filename='model-{epoch:02d}-{train_loss:.2f}',
        save_top_k=1,
        verbose=True,
        monitor='train_loss',
        mode='min',
        every_n_epochs=1,
        save_last=True
    )

    trainer = L.Trainer(
        max_epochs=100,
        devices=4,
        accelerator='gpu',
        strategy='ddp',
        callbacks=[checkpoint_callback],
        plugins=[SLURMEnvironment(auto_requeue=True, requeue_signal=signal.SIGUSR1)],
        default_root_dir=log_dir,
        use_distributed_sampler=False,
    )

    trainer.fit(model, datamodule=dm, ckpt_path="last")
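    # Note: the fit call above resumes from the regular "last" checkpoint;
    # switching it to ckpt_path="hpc" reproduces the ValueError described
    # above on the very first run, when no HPC checkpoint exists yet.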


if __name__ == '__main__':
    main()

dummy_slurm.sh

#!/bin/bash
#SBATCH --job-name=auto_requeue_test
#SBATCH -C h100
#SBATCH -A ycy@h100
#SBATCH --nodes=1
#SBATCH --qos=qos_gpu_h100-dev
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --time=00:03:00
#SBATCH --signal=SIGUSR1@30  # Send signal 30 seconds before time limit

# Load any necessary modules or activate your environment here
# For example:
module purge
module load arch/h100
module load pytorch-gpu/py3/2.4.0
export PYTHONUSERBASE=$WORK/python_envs/worldmodel

echo "Starting job at $(date)"

# Generate a unique run name using the current date and time
RUN_NAME="run_${SLURM_JOB_ID}"

# Run the Python script with the unique run name
srun python dummy_model.py "$RUN_NAME"

echo "Job ended or requeued at $(date)"

Environment

Current environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0
- Python version: 3.11.9
- How you installed Lightning(`conda`, `pip`, source): pip

cc @lantiga
