Description
Bug description
Can someone explain the discrepancy I’m noticing with this snippet of code on a 4-GPU machine? (I can replicate it on a 2-GPU machine as well.)
```python
import pytorch_lightning as ptl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import time

# Configs
BATCH_SIZE = 1
NUM_LABELS = 3
INPUT_DIM = 4
HIDDEN_DIM = 4
SEQ_LEN = 4
VOCAB_SIZE = 6
SEED = 42
NUM_DEVICES = 4
NUM_SAMPLES = NUM_DEVICES


# reshape functions
def reshape_with_expand(
    last_hidden_state: torch.Tensor, num_labels: int
) -> torch.Tensor:
    return last_hidden_state.unsqueeze(1).expand(-1, num_labels, -1).flatten(0, 1)


# dummy dataset
class DummyDataset(Dataset):
    def __init__(self, num_samples=NUM_SAMPLES):
        self.num_samples = num_samples
        self.pixel_data = torch.randn(num_samples, INPUT_DIM)
        self.label_data = torch.randint(
            0, VOCAB_SIZE, (num_samples, NUM_LABELS, SEQ_LEN)
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.pixel_data[idx], self.label_data[idx]


class PlToyModel(ptl.LightningModule):
    def __init__(self):
        super().__init__()
        self.final_gradients = {}
        self.encoder = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.decoder = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def on_after_backward(self):
        if self.trainer.is_last_batch:
            print(f"[Rank {self.global_rank}] Capturing final gradients.", flush=True)
            for name, param in self.named_parameters():
                if param.grad is not None:
                    self.final_gradients[name] = param.grad.clone()

    def training_step(self, batch, batch_idx):
        print(f"GPU {self.global_rank}: batch_idx: {batch_idx}")
        pixel_values, labels = batch
        encoder_output = self.encoder(pixel_values)
        reshaped_encoder_output = reshape_with_expand(encoder_output, NUM_LABELS)
        input_to_decoder = reshaped_encoder_output.unsqueeze(1).repeat(1, SEQ_LEN, 1)
        logits = self.decoder(input_to_decoder)
        loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), labels.view(-1))
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    shared_dataset = DummyDataset()

    ptl.seed_everything(SEED, workers=True)
    model_accumulate = PlToyModel()
    loader_accumulate = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )
    trainer_accumulate = ptl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=1,
        accumulate_grad_batches=NUM_SAMPLES,
    )
    print("---Starting 'accumulate'---")
    trainer_accumulate.fit(model_accumulate, loader_accumulate)
    accumulated_grads = model_accumulate.final_gradients

    # second run
    ptl.seed_everything(SEED, workers=True)
    model_ddp = PlToyModel()
    loader_ddp = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )
    trainer_ddp = ptl.Trainer(
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=NUM_DEVICES,
    )
    print("---Starting 'ddp'---")
    trainer_ddp.fit(model_ddp, loader_ddp)
    ddp_grads = model_ddp.final_gradients

    time.sleep(1)
    if trainer_ddp.is_global_zero:
        print("\n" + "=" * 50)
        print("---Final gradient norm comparison ---")
        print("=" * 50)
        print(f"{'Parameter':<20} | {'Accumulated Norm':<15} | {'DDP norms':<15}")
        print("-" * 50)
        for name in accumulated_grads:
            accumulated_norm = accumulated_grads[name].norm().item()
            ddp_norm = ddp_grads[name].norm().item()
            print(f"{name:<20} | {accumulated_norm:<15.4f} | {ddp_norm:<15.4f}")
        print("=" * 50)
```
Why aren’t the gradients the same? In both cases, shouldn’t the final gradient be (g1 + g2 + g3 + g4) / 4, where g_i is the gradient from the i-th sample?
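To make that expectation concrete, here is the plain-PyTorch reference I have in mind (a rough sketch that reuses `DummyDataset`, `PlToyModel`, `reshape_with_expand`, and the constants from the snippet above; since the dataset and weights are re-initialized here, it only illustrates the averaging I expect, not the exact numbers from either run):

```python
import torch
import torch.nn.functional as F

# NOTE: assumes DummyDataset, PlToyModel, reshape_with_expand, and the constants
# (SEED, NUM_LABELS, SEQ_LEN, VOCAB_SIZE) from the snippet above are in scope.
torch.manual_seed(SEED)
dataset = DummyDataset()
model = PlToyModel()

model.zero_grad()
for i in range(len(dataset)):
    pixel_values, labels = dataset[i]
    pixel_values = pixel_values.unsqueeze(0)  # add batch dim -> (1, INPUT_DIM)
    labels = labels.unsqueeze(0)              # -> (1, NUM_LABELS, SEQ_LEN)

    # same forward pass as training_step, but without the Trainer
    encoder_output = model.encoder(pixel_values)
    reshaped = reshape_with_expand(encoder_output, NUM_LABELS)
    logits = model.decoder(reshaped.unsqueeze(1).repeat(1, SEQ_LEN, 1))
    loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), labels.view(-1))

    # dividing by the number of samples makes the accumulated gradient the mean
    # of the per-sample gradients, i.e. (g1 + g2 + g3 + g4) / 4
    (loss / len(dataset)).backward()

for name, param in model.named_parameters():
    print(f"{name:<20} | {param.grad.norm().item():.4f}")
```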
Also, why don’t I see something like `GPU 1: batch_idx: 1` being printed in the `ddp` case? I have saved this snippet in a file called `snippet.py` and I am running `python snippet.py`.
Any help would be appreciated!
What version are you seeing the problem on?
master
Reproduced in studio
No response
How to reproduce the bug
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source):
More info
No response