Description
Summary
When setting up a SLURM job so it can resume from where it left off after hitting its wall-time limit, using the ckpt_path="hpc" option in PyTorch Lightning causes an error if no HPC checkpoint exists yet. This prevents the initial training run from starting.
Expected Behavior
- The job should be able to resume from an HPC checkpoint if one exists, when using the following in combination:
  SLURMEnvironment(auto_requeue=True, requeue_signal=signal.SIGUSR1)
  trainer.fit(model, datamodule=dm, ckpt_path="hpc")
  #SBATCH --signal=SIGUSR1@30
- If no HPC checkpoint exists (e.g., on the first run), the job should start training from scratch without throwing an error. Currently it throws:
  `.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`
Current Behavior
Using ckpt_path=None allows the job to start, but it does not resume from the HPC checkpoint once one is created.
If I use trainer.fit(model, datamodule=dm, ckpt_path=None), the SIGUSR1 signal is correctly caught and the checkpoint hpc_ckpt_1.ckpt is correctly created. However, the checkpoint is not used on the requeued run, which is expected because we left ckpt_path=None.
requeing job 245038...
Requeued SLURM job: 245038
srun: Job step aborted: Waiting up to 62 seconds for job step to finish.
slurmstepd: error: *** JOB 245038 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***
slurmstepd: error: *** STEP 245038.0 ON jzxh061 CANCELLED AT 2024-10-17T15:45:56 DUE TO JOB REQUEUE ***
Using ckpt_path="hpc"
throws an error if no HPC checkpoint is found, preventing the initial training run.
The logic of looking for and loading the hpc checkpoint from what I understood should be handled by setting ckpt_path="hpc"
However, as can be seen in https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py#L193C1-L199C46
if an hpc ckpt is not found it throws an error and stops:
`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found. Please pass an exact checkpoint path to `.fit(ckpt_path=...)`
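For reference, the check in that connector is roughly the following (paraphrased from the linked source, not verbatim; attribute names may differ between versions):

if ckpt_path == "hpc":
    # _hpc_resume_path points at the newest hpc_ckpt_<n>.ckpt found in default_root_dir, or None
    if not self._hpc_resume_path:
        raise ValueError(
            '`.fit(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
            " Please pass an exact checkpoint path to `.fit(ckpt_path=...)`"
        )
    ckpt_path = self._hpc_resume_path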
The issue is that on the very first run there is, of course, no HPC checkpoint yet, because no training has been started.
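A possible workaround in the main() of the repro below is to only pass "hpc" when such a checkpoint already exists. This is a minimal sketch: resolve_ckpt_path is a hypothetical helper, and it assumes Lightning keeps naming HPC checkpoints hpc_ckpt_<n>.ckpt inside default_root_dir (log_dir in the repro):

import glob
import os

def resolve_ckpt_path(root_dir):
    # Hypothetical helper: fall back to starting from scratch when no
    # HPC checkpoint (hpc_ckpt_<n>.ckpt) has been written to root_dir yet.
    if glob.glob(os.path.join(root_dir, "hpc_ckpt_*.ckpt")):
        return "hpc"
    return None

trainer.fit(model, datamodule=dm, ckpt_path=resolve_ckpt_path(log_dir))

It would be nicer, though, if ckpt_path="hpc" simply fell back to a fresh start (or to "last") when no HPC checkpoint exists.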
Relevant issues
What version are you seeing the problem on?
v2.4
How to reproduce the bug
dummy_model.py
import os
import torch
import lightning as L
from torch.utils.data import Dataset
from lightning.pytorch.callbacks import ModelCheckpoint
import argparse
from torchdata.stateful_dataloader import StatefulDataLoader
from torch.distributed.checkpoint.stateful import Stateful
from torch.utils.data.distributed import DistributedSampler
import pickle
import signal
from lightning.pytorch.plugins.environments import SLURMEnvironment
import time
class DummyDataset(Dataset):
    def __init__(self, size=100000):
        self.size = size
        self.data = torch.randn(size, 10)
        self.labels = torch.randint(0, 2, (size,))
        self.current_index = 0

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        print(f"Accessing index: {idx}, {self.data[idx]}")
        return self.data[idx], self.labels[idx]
class DPAwareDataLoader(StatefulDataLoader, Stateful):
    def __init__(self, dataset: Dataset, batch_size: int, sampler=None, **kwargs):
        super().__init__(dataset, batch_size=batch_size, sampler=sampler, **kwargs)
        self._rank_id = f"dp_rank_{sampler.rank if sampler else 0}"
        print(self._rank_id, " initialized")

    def state_dict(self):
        print(f"{self._rank_id}: {super().state_dict()}")
        return {self._rank_id: super().state_dict()}

    def load_state_dict(self, state_dict):
        if not state_dict:
            return
        if self._rank_id not in state_dict:
            print(f"DataLoader state is empty for {self._rank_id}")
            return
        print(f"DataLoader state loading for {self._rank_id}")
        super().load_state_dict(state_dict[self._rank_id])
class DummyDataModule(L.LightningDataModule, Stateful):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size
        self.train_dataset = None
        self.dataloader = None

    def setup(self, stage=None):
        self.train_dataset = DummyDataset()
        # DistributedSampler automatically retrieves world_size and rank
        # from the current distributed group.
        #
        # ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
        #
        # In PyTorch Lightning:
        # - The distributed environment is initialized by the Trainer.
        # - This sets up the process group with the correct world_size and rank.
        # - DistributedSampler then uses these values automatically.
        #
        # By not specifying num_replicas and rank, we allow DistributedSampler
        # to adapt to the current distributed setup, making our code more flexible.
        # This works seamlessly with PyTorch Lightning's managed distributed training.
        #
        # Note: This automatic retrieval only works correctly if the distributed
        # environment has been initialized, which Lightning ensures before calling setup().
        self.sampler = DistributedSampler(
            self.train_dataset,
            shuffle=False  # Ensure deterministic data order across processes for testing purposes
        )

    def train_dataloader(self):
        if self.dataloader is None:
            self.dataloader = DPAwareDataLoader(
                self.train_dataset,
                batch_size=self.batch_size,
                sampler=self.sampler,
                num_workers=2
            )
        return self.dataloader

    def state_dict(self):
        return {
            "dataloader_state": self.dataloader.state_dict() if self.dataloader else None,
        }

    def load_state_dict(self, state_dict):
        if self.dataloader and state_dict["dataloader_state"]:
            self.dataloader.load_state_dict(state_dict["dataloader_state"])
class DummyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(10, 2)
        self.loss = torch.nn.CrossEntropyLoss()
        self.example_count = 0
        self.custom_global_step = 0

    def training_step(self, batch, batch_idx):
        time.sleep(10)
        x, y = batch
        y_hat = self.layer(x)
        loss = self.loss(y_hat, y)
        self.log('train_loss', loss)
        self.example_count += len(x)
        self.custom_global_step = self.global_step
        if self.example_count % 100 == 0:
            print(f"GPU {self.global_rank}: Processed {self.example_count} examples, Global Step: {self.custom_global_step}")
        return loss

    def on_save_checkpoint(self, checkpoint):
        checkpoint['example_count'] = self.example_count
        checkpoint['custom_global_step'] = self.custom_global_step

    def on_load_checkpoint(self, checkpoint):
        self.example_count = checkpoint['example_count']
        self.custom_global_step = checkpoint['custom_global_step']

    def on_train_start(self):
        print(f"GPU {self.global_rank}: Starting/Resuming training. Example count: {self.example_count}, Global Step: {self.custom_global_step}")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
def main():
    parser = argparse.ArgumentParser(description="Train a dummy model with a unique run name.")
    parser.add_argument("run_name", type=str, help="Unique name for this training run")
    args = parser.parse_args()
    print("=" * 30, args.run_name, "=" * 30)

    log_dir = os.path.join("logs", args.run_name)
    os.makedirs(log_dir, exist_ok=True)

    model = DummyModel()
    dm = DummyDataModule(batch_size=4)

    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, 'checkpoints'),
        filename='model-{epoch:02d}-{train_loss:.2f}',
        save_top_k=1,
        verbose=True,
        monitor='train_loss',
        mode='min',
        every_n_epochs=1,
        save_last=True
    )

    trainer = L.Trainer(
        max_epochs=100,
        devices=4,
        accelerator='gpu',
        strategy='ddp',
        callbacks=[checkpoint_callback],
        plugins=[SLURMEnvironment(auto_requeue=True, requeue_signal=signal.SIGUSR1)],
        default_root_dir=log_dir,
        use_distributed_sampler=False,
    )
    trainer.fit(model, datamodule=dm, ckpt_path="last")


if __name__ == '__main__':
    main()
dummy_slurm.sh
#!/bin/bash
#SBATCH --job-name=auto_requeue_test
#SBATCH -C h100
#SBATCH -A ycy@h100
#SBATCH --nodes=1
#SBATCH --qos=qos_gpu_h100-dev
#SBATCH --ntasks-per-node=4
#SBATCH --gpus-per-node=4
#SBATCH --time=00:03:00
#SBATCH --signal=SIGUSR1@30 # Send signal 30 seconds before time limit
# Load any necessary modules or activate your environment here
# For example:
module purge
module load arch/h100
module load pytorch-gpu/py3/2.4.0
export PYTHONUSERBASE=$WORK/python_envs/worldmodel
echo "Starting job at $(date)"
# Generate a unique run name using the current date and time
RUN_NAME="run_${SLURM_JOB_ID}"
# Run the Python script with the unique run name
srun python dummy_model.py "$RUN_NAME"
echo "Job ended or requeued at $(date)"
Environment
Current environment
- PyTorch Lightning Version: 2.4.0
- PyTorch Version: 2.4.0
- Python version: 3.11.9
- How you installed Lightning(`conda`, `pip`, source): pip
cc @lantiga