
Unexplained behaviour in gradient accumulation vs. a DDP setting - why are the gradients different? #20882

Bug description

Can someone explain the discrepancy I’m noticing with the snippet of code below on a 4-GPU machine? (I can replicate it on a 2-GPU machine as well.)

import pytorch_lightning as ptl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import time

# Configs
BATCH_SIZE = 1
NUM_LABELS = 3
INPUT_DIM = 4
HIDDEN_DIM = 4
SEQ_LEN = 4
VOCAB_SIZE = 6
SEED = 42
NUM_DEVICES = 4
NUM_SAMPLES = NUM_DEVICES


# reshape helper
def reshape_with_expand(
    last_hidden_state: torch.Tensor, num_labels: int
) -> torch.Tensor:
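    # (batch, hidden) -> (batch, num_labels, hidden) -> (batch * num_labels, hidden)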
    return last_hidden_state.unsqueeze(1).expand(-1, num_labels, -1).flatten(0, 1)


# dummy dataset
class DummyDataset(Dataset):
    def __init__(self, num_samples=NUM_SAMPLES):
        self.num_samples = num_samples
        self.pixel_data = torch.randn(num_samples, INPUT_DIM)
        self.label_data = torch.randint(
            0, VOCAB_SIZE, (num_samples, NUM_LABELS, SEQ_LEN)
        )

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.pixel_data[idx], self.label_data[idx]


class PlToyModel(ptl.LightningModule):
    def __init__(self):
        super().__init__()
        self.final_gradients = {}

        self.encoder = nn.Linear(INPUT_DIM, HIDDEN_DIM)
        self.decoder = nn.Linear(HIDDEN_DIM, VOCAB_SIZE)

    def on_after_backward(self):
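        # this hook runs after every backward pass; only snapshot on the epoch's last batch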
        if self.trainer.is_last_batch:
            print(f"[Rank {self.global_rank}] Capturing final gradients.", flush=True)
            for name, param in self.named_parameters():
                if param.grad is not None:
                    self.final_gradients[name] = param.grad.clone()

    def training_step(self, batch, batch_idx):
        print(f"GPU {self.global_rank}: batch_idx: {batch_idx}")
        pixel_values, labels = batch
        encoder_output = self.encoder(pixel_values)
        reshaped_encoder_output = reshape_with_expand(encoder_output, NUM_LABELS)
        input_to_decoder = reshaped_encoder_output.unsqueeze(1).repeat(1, SEQ_LEN, 1)
        logits = self.decoder(input_to_decoder)
        loss = F.cross_entropy(logits.view(-1, VOCAB_SIZE), labels.view(-1))
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    shared_dataset = DummyDataset()

    ptl.seed_everything(SEED, workers=True)

    model_accumulate = PlToyModel()
    loader_accumulate = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )

    trainer_accumulate = ptl.Trainer(
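        # Run 1: one device, accumulating gradients over all NUM_SAMPLES micro-batches before stepping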
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=1,
        accumulate_grad_batches=NUM_SAMPLES,
    )
    print("---Starting 'accumulate'---")
    trainer_accumulate.fit(model_accumulate, loader_accumulate)

    accumulated_grads = model_accumulate.final_gradients

    # second run
    ptl.seed_everything(SEED, workers=True)

    model_ddp = PlToyModel()

    loader_ddp = DataLoader(
        shared_dataset,
        batch_size=BATCH_SIZE,
        num_workers=2,
    )

    trainer_ddp = ptl.Trainer(
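        # Run 2: NUM_DEVICES DDP processes; Lightning's default DistributedSampler gives each rank one of the NUM_SAMPLES samples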
        accelerator="gpu",
        strategy="ddp",
        max_epochs=1,
        enable_checkpointing=False,
        logger=False,
        devices=NUM_DEVICES,
    )
    print("---Starting 'ddp'---")
    trainer_ddp.fit(model_ddp, loader_ddp)
    ddp_grads = model_ddp.final_gradients

    time.sleep(1)

    if trainer_ddp.is_global_zero:
        print("\n" + "=" * 50)
        print("---Final gradient norm comparison ---")
        print("=" * 50)
        print(f"{'Parameter':<20} | {'Accumulated Norm':<16} | {'DDP Norm':<15}")
        print("-" * 50)
        for name in accumulated_grads:
            accumulated_norm = accumulated_grads[name].norm().item()
            ddp_norm = ddp_grads[name].norm().item()
            print(f"{name:<20} | {accumulated_norm:<15.4f} | {ddp_norm:<15.4f}")
        print("=" * 50)

Why aren’t the gradients the same? In both cases, shouldn’t the final gradient be (g1 + g2 + g3 + g4) / 4, where g_i is the gradient from the i-th sample?
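
To make the arithmetic I have in mind concrete, here is a minimal plain-PyTorch sketch (no Lightning, no DDP; the linear model and random data are stand-ins for the toy model above) showing that accumulating over N micro-batches and averaging per-sample gradients give the same result:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy = nn.Linear(4, 6)           # stand-in model
xs = torch.randn(4, 4)          # one "sample" per micro-batch / rank
ys = torch.randint(0, 6, (4,))

# mean of per-sample gradients: (g1 + g2 + g3 + g4) / 4
per_sample = []
for x, y in zip(xs, ys):
    toy.zero_grad()
    F.cross_entropy(toy(x.unsqueeze(0)), y.unsqueeze(0)).backward()
    per_sample.append(toy.weight.grad.clone())
mean_grad = torch.stack(per_sample).mean(dim=0)

# gradient accumulation: backward on loss / N per micro-batch, no zero_grad in between
toy.zero_grad()
for x, y in zip(xs, ys):
    (F.cross_entropy(toy(x.unsqueeze(0)), y.unsqueeze(0)) / len(xs)).backward()

print(torch.allclose(toy.weight.grad, mean_grad, atol=1e-6))  # True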

Also, why don’t I see something like GPU 1: batch_idx: 1 printed in the DDP case? I have saved this snippet in a file called snippet.py and am running it with python snippet.py.

Any help would be appreciated!

What version are you seeing the problem on?

master

Reproduced in studio

No response

How to reproduce the bug

Run the snippet above on a multi-GPU machine with python snippet.py.

Error messages and logs

No response

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):

More info

No response

Metadata

Labels: bug (Something isn't working), needs triage (Waiting to be triaged by maintainers), ver: 2.5.x
