DDP Strategy Does Not Automatically Shard Batch Sizes Despite Documentation Claims #21023

@erlebach

Bug description

DDP Strategy Does Not Automatically Shard Batch Sizes Despite Documentation Claims

Issue Description

The Lightning documentation claims that DDP automatically shards both datasets and batch sizes, but experimental testing shows that only dataset sharding works automatically. Batch sizes are not automatically divided across GPUs, preventing users from achieving expected speedup from multi-GPU training.

Expected Behavior (According to Documentation)

According to the GPU training documentation:

"Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset."

This suggests that DDP should automatically:

  1. Shard the dataset (different data per GPU) ✅ Works
  2. Shard the batch size (smaller batches per GPU) ❌ Does NOT work

Actual Behavior

When using strategy="ddp" or DDPStrategy() with multiple GPUs:

  • Dataset sharding works: Each GPU gets different data (confirmed by different input/target values)
  • Gradient synchronization works: Weight matrices remain identical across GPUs (confirmed by weight norm checks)
  • Batch size sharding does NOT work: Each GPU processes the full batch size instead of batch_size / num_gpus (see the sketch below)
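
The sketch below is a standalone simulation, not a Lightning run, but it illustrates the mechanism behind the last point: for DDP, Lightning's default behavior (use_distributed_sampler=True) is to inject a torch.utils.data.distributed.DistributedSampler, which partitions the dataset indices across ranks but leaves the DataLoader's batch_size untouched, so each rank still draws full-size batches from its shard.

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

ds = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))

# Simulate 2 ranks in a single process instead of launching DDP workers
for rank in range(2):
    sampler = DistributedSampler(ds, num_replicas=2, rank=rank, shuffle=False)
    dl = DataLoader(ds, batch_size=64, sampler=sampler)
    x, _ = next(iter(dl))
    print(f"rank {rank}: {len(sampler)} samples in shard, batch shape {tuple(x.shape)}")
    # rank 0: 500 samples in shard, batch shape (64, 10)
    # rank 1: 500 samples in shard, batch shape (64, 10)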

Reproduction Steps

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DDPStrategy

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        
        if batch_idx == 0:
            print(f"[Rank {self.global_rank}] Batch shape: {x.shape}")
            print(f"[Rank {self.global_rank}] First element: {x[0, 0].item():.6f}")
        
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Create dataset and dataloader
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=0)

model = SimpleModel()
strategy = DDPStrategy(find_unused_parameters=False, static_graph=True)

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=strategy,
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
)

trainer.fit(model, dl)

Expected vs Actual Output

Expected (if batch sharding worked):

[Rank 0] Batch shape: torch.Size([32, 10])  # 64/2 = 32
[Rank 1] Batch shape: torch.Size([32, 10])  # 64/2 = 32

Actual output:

[Rank 0] Batch shape: torch.Size([64, 10])  # Full batch size
[Rank 1] Batch shape: torch.Size([64, 10])  # Full batch size

Impact

  1. No speedup from multi-GPU training: Each GPU processes the full batch, so per-GPU compute and memory are not reduced
  2. Memory pressure: Users hit OOM errors that could be avoided with proper batch sharding
  3. Misleading documentation: Users spend hours debugging why they're not getting expected performance
  4. Manual workaround required: Users must manually divide batch sizes by number of GPUs

Workaround

Users must manually reduce batch sizes in DataLoaders:

# For 2 GPUs, use batch_size=32 instead of 64
dl = DataLoader(ds, batch_size=32, shuffle=False, num_workers=0)
# Effective batch size = 32 * 2 = 64
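
A slightly more portable version of the same workaround (a sketch, assuming the launcher exports WORLD_SIZE before the DataLoader is built, as torchrun does; other launchers may expose a different variable) derives the per-device batch size from the world size instead of hard-coding it:

import os
import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))

GLOBAL_BATCH_SIZE = 64  # desired effective batch size across all GPUs
world_size = int(os.environ.get("WORLD_SIZE", "1"))  # assumption: set by the launcher; falls back to 1
dl = DataLoader(ds, batch_size=GLOBAL_BATCH_SIZE // world_size, shuffle=False, num_workers=0)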

Request

Either:

  1. Fix automatic batch sharding to match documentation claims, OR
  2. Update documentation to clarify that only dataset sharding is automatic and batch sharding requires manual configuration

Environment

  • PyTorch Lightning version: [latest]
  • PyTorch version: [latest]
  • CUDA version: [latest]
  • Number of GPUs: 2

This issue cost several hours of debugging time and should be addressed to prevent other users from experiencing the same frustration.

What version are you seeing the problem on?

v2.5

Reproduced in studio

No response

How to reproduce the bug

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DDPStrategy

class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        
        if batch_idx == 0:
            print(f"[Rank {self.global_rank}] Batch shape: {x.shape}")
            print(f"[Rank {self.global_rank}] First element: {x[0, 0].item():.6f}")
        
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

# Create dataset and dataloader
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=0)

model = SimpleModel()
strategy = DDPStrategy(find_unused_parameters=False, static_graph=True)

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=strategy,
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
)

trainer.fit(model, dl)

Error messages and logs

# Error messages and logs here please

Environment

Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning(`conda`, `pip`, source): pip
Here is my slurm submission:

#!/bin/bash
#SBATCH --job-name=mwe_2gpu
#SBATCH --output=mwe-%x-%j.out
#SBATCH --error=mwe-%x-%j.err
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --gres=gpu:2
#SBATCH --mem=60GB
#SBATCH --time=01:00:00
#SBATCH -A pilotgpu

# Ensure proper SLURM environment variables for PyTorch Lightning
export SLURM_NTASKS_PER_NODE=2

# Print cluster info for debugging
echo "=== SLURM Job Info ==="
echo "Job ID: $SLURM_JOB_ID"
echo "Nodes: $SLURM_JOB_NUM_NODES"
echo "Tasks per node: $SLURM_NTASKS_PER_NODE"
echo "CPUs per task: $SLURM_CPUS_PER_TASK"
echo "GPUs requested: $SLURM_GPUS_PER_NODE"
echo "Available GPUs: $CUDA_VISIBLE_DEVICES"
echo "======================="

# Load modules and set up environment
module load cuda/12.1 || echo "Warning: Could not load CUDA module"
module load webproxy
pip install torch lightning

# nproc_per_node == number of GPUs (or else won't work)
torchrun \
    --nproc_per_node=2 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr=localhost \
    --master_port=12355 \
    mwe.py



More info

I would like confirmation about the batch size under DDP. If the batch size is 64 on a single GPU, and I run in DDP mode, should the batch size be 32 on each GPU, or 64 on each GPU?
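
For reference, a quick runtime check of what each rank actually receives might look like the following (a sketch; it assumes the module runs under a Trainer, so that self.trainer.world_size is available inside training_step):

import torch
from torch import nn
from lightning.pytorch import LightningModule

class BatchSizeCheck(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        if batch_idx == 0:
            per_rank = x.shape[0]
            # effective (global) batch size = per-rank batch * number of ranks
            print(f"[Rank {self.global_rank}] per-rank batch: {per_rank}, "
                  f"effective batch: {per_rank * self.trainer.world_size}")
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)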

cc @justusschock @lantiga
