Bug description
DDP Strategy Does Not Automatically Shard Batch Sizes Despite Documentation Claims
Issue Description
The Lightning documentation claims that DDP automatically shards both datasets and batch sizes, but experimental testing shows that only dataset sharding happens automatically. Batch sizes are not divided across GPUs, so each GPU processes the full DataLoader batch size and users do not see the expected memory savings or speedup from multi-GPU training.
Expected Behavior (According to Documentation)
According to the GPU training documentation:
"Each GPU gets visibility into a subset of the overall dataset. It will only ever see that subset."
This suggests that DDP should automatically:
- Shard the dataset (different data per GPU) ✅ Works (see the sampler sketch below)
- Shard the batch size (smaller batches per GPU) ❌ Does NOT work
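For reference, the dataset-sharding half of this behaviour can be reproduced outside Lightning with a plain `DistributedSampler`. The following is a minimal sketch of my understanding; it simulates both ranks in a single process purely for illustration, which is not how Lightning wires it up:

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy dataset whose values encode the sample index, so we can see which
# samples each rank receives.
ds = TensorDataset(torch.arange(1000).float().unsqueeze(1))

# Simulate what each of two ranks would see. DistributedSampler shards the
# *indices*; the DataLoader's batch_size is applied per process, unchanged.
for rank in range(2):
    sampler = DistributedSampler(ds, num_replicas=2, rank=rank, shuffle=False)
    dl = DataLoader(ds, batch_size=64, sampler=sampler)
    (xb,) = next(iter(dl))
    print(f"rank {rank}: batch shape {tuple(xb.shape)}, first sample index {int(xb[0, 0])}")

# Both "ranks" print batch shape (64, 1) but different first indices:
# the dataset is sharded, the per-process batch size is not.
```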
Actual Behavior
When using `strategy="ddp"` or `DDPStrategy()` with multiple GPUs:
- Dataset sharding works: Each GPU gets different data (confirmed by different input/target values)
- Gradient synchronization works: Weight matrices remain identical across GPUs (confirmed by weight-norm checks; a sketch of such a check follows this list)
- Batch size sharding does NOT work: Each GPU processes the full batch size instead of `batch_size / num_gpus`
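For completeness, here is a hypothetical version of the weight-norm check referenced above (not the exact code that was run, just one way to confirm that parameters stay in sync across ranks), written as a hook to add to the `SimpleModel` used in the reproduction script below:

```python
    # Inside SimpleModel (illustrative verification hook): if DDP gradient
    # synchronization works, every rank prints the same norm after each step.
    def on_train_batch_end(self, outputs, batch, batch_idx):
        weight_norm = self.layer.weight.detach().norm().item()
        print(f"[Rank {self.global_rank}] weight norm after batch {batch_idx}: {weight_norm:.6f}")
```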
Reproduction Steps
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DDPStrategy


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        if batch_idx == 0:
            print(f"[Rank {self.global_rank}] Batch shape: {x.shape}")
            print(f"[Rank {self.global_rank}] First element: {x[0, 0].item():.6f}")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# Create dataset and dataloader
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=0)

model = SimpleModel()
strategy = DDPStrategy(find_unused_parameters=False, static_graph=True)
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=strategy,
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(model, dl)
```
Expected vs Actual Output
Expected (if batch sharding worked):
```
[Rank 0] Batch shape: torch.Size([32, 10])  # 64 / 2 = 32
[Rank 1] Batch shape: torch.Size([32, 10])  # 64 / 2 = 32
```
Actual output:
```
[Rank 0] Batch shape: torch.Size([64, 10])  # full batch size
[Rank 1] Batch shape: torch.Size([64, 10])  # full batch size
```
Impact
- No speedup or memory savings from multi-GPU training: Each GPU processes the full batch and uses the full amount of memory instead of a reduced share
- Memory pressure: Users hit OOM errors that could be avoided with proper batch sharding
- Misleading documentation: Users spend hours debugging why they are not getting the expected performance
- Manual workaround required: Users must manually divide batch sizes by the number of GPUs
Workaround
Users must manually reduce batch sizes in DataLoaders:
```python
# For 2 GPUs, use batch_size=32 instead of 64
dl = DataLoader(ds, batch_size=32, shuffle=False, num_workers=0)
# Effective (global) batch size = 32 * 2 = 64
```
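To avoid hard-coding the divisor, one option (an assumption on my side, not an official Lightning recipe) is to build the DataLoader inside the LightningModule and derive the per-device batch size from `self.trainer.world_size`:

```python
class SimpleModel(LightningModule):
    GLOBAL_BATCH_SIZE = 64  # hypothetical constant, kept fixed regardless of device count

    # ... __init__ / forward / training_step / configure_optimizers as above ...

    def train_dataloader(self):
        # world_size is known once the Trainer has attached the module, so the
        # per-device batch size adapts to however many GPUs are requested.
        per_device = self.GLOBAL_BATCH_SIZE // self.trainer.world_size
        return DataLoader(ds, batch_size=per_device, shuffle=False, num_workers=0)
```

This sketch assumes the same imports and the same `ds` dataset as the reproduction script above.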
Request
Either:
- Fix automatic batch sharding to match documentation claims, OR
- Update documentation to clarify that only dataset sharding is automatic and batch sharding requires manual configuration
Environment
- PyTorch Lightning version: [latest]
- PyTorch version: [latest]
- CUDA version: [latest]
- Number of GPUs: 2
This issue cost several hours of debugging time and should be addressed to prevent other users from experiencing the same frustration.
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.strategies import DDPStrategy


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        if batch_idx == 0:
            print(f"[Rank {self.global_rank}] Batch shape: {x.shape}")
            print(f"[Rank {self.global_rank}] First element: {x[0, 0].item():.6f}")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)


# Create dataset and dataloader
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)
ds = TensorDataset(X, y)
dl = DataLoader(ds, batch_size=64, shuffle=False, num_workers=0)

model = SimpleModel()
strategy = DDPStrategy(find_unused_parameters=False, static_graph=True)
trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=strategy,
    max_epochs=1,
    logger=False,
    enable_checkpointing=False,
)
trainer.fit(model, dl)
```
Error messages and logs
# Error messages and logs here please
Environment
Current environment
#- PyTorch Lightning Version (e.g., 2.5.0):
#- PyTorch Version (e.g., 2.5):
#- Python version (e.g., 3.12):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source): pip
```bash
# Ensure proper SLURM environment variables for PyTorch Lightning
export SLURM_NTASKS_PER_NODE=2

# Print cluster info for debugging
echo "=== SLURM Job Info ==="
echo "Job ID: $SLURM_JOB_ID"
echo "Nodes: $SLURM_JOB_NUM_NODES"
echo "Tasks per node: $SLURM_NTASKS_PER_NODE"
echo "CPUs per task: $SLURM_CPUS_PER_TASK"
echo "GPUs requested: $SLURM_GPUS_PER_NODE"
echo "Available GPUs: $CUDA_VISIBLE_DEVICES"
echo "======================="

# Load modules and set up the environment
module load cuda/12.1 || echo "Warning: Could not load CUDA module"
module load webproxy
pip install torch lightning

# nproc_per_node must equal the number of GPUs (or else it won't work)
torchrun \
    --nproc_per_node=2 \
    --nnodes=1 \
    --node_rank=0 \
    --master_addr=localhost \
    --master_port=12355 \
    mwe.py
```
### More info
I would like confirmation about the batch size under DDP. If the batch size is 64 on a single GPU and I run in DDP mode, should the batch size be 32 on each GPU, or 64 on each GPU?
cc @justusschock @lantiga