-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Description
Bug description
When using PyTorch Lightning with DDP and static_graph=True, model parameters are not synchronized properly across processes. I tested this against vanilla PyTorch DDP and confirmed that the issue only appears in Lightning.
📄 Minimal Reproducible Example
I created a minimal script that compares model parameter changes across DDP processes after each optimizer step. It runs 2 training steps and logs the changed indices and delta of the weights from the first fully connected layer.
This script runs 4 experiments:
• Lightning with static_graph=True
• Lightning with static_graph=False
• Vanilla PyTorch DDP with static_graph=True
• Vanilla PyTorch DDP with static_graph=False
Only the Lightning + static_graph=True case shows inconsistent or missing synchronization.
🔍 Observed Behavior
• When using Lightning + DDP + static_graph=True, each GPU maintains a different version of the model after training steps.
• When using Vanilla PyTorch DDP + static_graph=True, synchronization works as expected.
✅ Expected Behavior
Model parameters should remain synchronized across DDP processes, even when static_graph=True.
What version are you seeing the problem on?
master
How to reproduce the bug
import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
# Lightning imports
import lightning as L
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy
# PyTorch DDP imports
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
# ------------------------------------------------------------------------------
# Helper Function for Reporting Weight Changes
# ------------------------------------------------------------------------------
def report_weight_changes(rank, mode_name, step, prev_weights, prev_weights_sum, current_weights, suffix):
    """
    Append the fc1 weight change since the previous step to a per-rank log file.

    Nothing is written on the first call (when ``prev_weights_sum`` is None).
    Otherwise two lines are appended to ``{mode_name}_{suffix}_{rank}.txt``:
    the first (at most 10) indices whose weight changed, and the deltas at
    those indices.

    Args:
        rank (int): Process/GPU rank, used in the log prefix and file name.
        mode_name (str): Mode name (e.g., "Lightning" or "PyTorch").
        step (int): The current training step (or batch index).
        prev_weights (Tensor or None): Flattened weights from the previous step.
        prev_weights_sum (Tensor or None): Sum of the previous weights; None
            marks the first call.
        current_weights (Tensor): Flattened weights at the current step.
        suffix (str): Log file name suffix (e.g., 'sgTrue' or 'sgFalse').

    Returns:
        tuple: ``(current_weights.sum(), current_weights)``, to be passed back
        in as the "previous" values on the next call.
    """
    weights_total = current_weights.sum()

    # First observation: there is no earlier snapshot to diff against.
    if prev_weights_sum is None:
        return weights_total, current_weights

    diff = current_weights - prev_weights
    # Cap the report at the first 10 changed positions to keep logs short.
    first_changed = diff.nonzero()[:10]
    log_path = f"{mode_name}_{suffix}_{rank}.txt"
    with open(log_path, "a") as log:
        log.writelines(
            [
                f"[{mode_name} GPU {rank}] Step {step} Changed indices: {first_changed.tolist()}\n",
                f"[{mode_name} GPU {rank}] Step {step} Weight delta: {diff[first_changed]}\n",
            ]
        )
    return weights_total, current_weights
# ------------------------------------------------------------------------------
# Shared Model Definition
# ------------------------------------------------------------------------------
class BaseModel(nn.Module):
    """Small MLP classifier shared by the Lightning and vanilla DDP runs.

    The input is flattened, passed through ``fc1`` (28*28 -> 128) and a ReLU,
    then through ``fc2`` (128 -> 10) to produce class logits.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return the logits of shape (batch, 10) for the batch ``x``."""
        hidden = self.relu(self.fc1(self.flatten(x)))
        return self.fc2(hidden)
# ------------------------------------------------------------------------------
# PyTorch Lightning Module & Data Module
# ------------------------------------------------------------------------------
class LitClassifier(L.LightningModule):
    """LightningModule that trains ``BaseModel`` with manual optimization.

    After every optimizer step the flattened ``fc1`` weights are compared with
    the previous step's snapshot and the difference is logged per rank through
    ``report_weight_changes``.
    """

    def __init__(self, graph_mode):
        """
        Args:
            graph_mode (bool): True if using static_graph=True, False otherwise.
        """
        super().__init__()
        self.model = BaseModel()
        # The training step drives the optimizer itself.
        self.automatic_optimization = False
        # Snapshot of fc1 from the previous step (None until the first step ran).
        self.prev_weights = None
        self.prev_weights_sum = None
        # Suffix used in the log file name to tell the two graph modes apart.
        self.graph_suffix = "sgTrue" if graph_mode else "sgFalse"

    def training_step(self, batch, batch_idx):
        """Run one manual optimization step and log the fc1 weight delta."""
        optimizer = self.optimizers()
        inputs, targets = batch

        loss = F.cross_entropy(self.model(inputs), targets)
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()

        # Copy the updated fc1 weights to the CPU as a flat vector.
        fc1_weight = self.model.fc1.weight.data
        snapshot = fc1_weight.view(-1).clone().detach().cpu()
        self.prev_weights_sum, self.prev_weights = report_weight_changes(
            self.global_rank,
            "Lightning",
            batch_idx,
            self.prev_weights,
            self.prev_weights_sum,
            snapshot,
            self.graph_suffix,
        )
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer (lr=1e-3) over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)
class MNISTDataModule(L.LightningDataModule):
    """Data module serving a 55,000-sample training subset of MNIST."""

    def __init__(self, batch_size=64):
        """
        Args:
            batch_size (int): Batch size of the training dataloader.
        """
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        """Download the MNIST train and test splits into ``data``."""
        for is_train in (True, False):
            datasets.MNIST("data", train=is_train, download=True)

    def setup(self, stage=None):
        """Load the MNIST training split and keep the 55,000-sample part of a random split."""
        mnist_train = datasets.MNIST("data", train=True, transform=transforms.ToTensor())
        # No generator is passed, so the split draws from the default RNG.
        self.train_set = random_split(mnist_train, [55000, 5000])[0]

    def train_dataloader(self):
        """Return the training dataloader (4 worker processes)."""
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=4)
# ------------------------------------------------------------------------------
# Vanilla PyTorch DDP Implementation
# ------------------------------------------------------------------------------
def setup_ddp(rank, world_size, port):
    """Initialize the NCCL process group for this worker.

    Args:
        rank (int): Rank of the calling process.
        world_size (int): Total number of processes in the group.
        port (int): Rendezvous port, exported as ``MASTER_PORT``.
    """
    # Rendezvous happens on this machine, on the caller-chosen port.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": str(port)})
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
def cleanup_ddp():
    """Destroy the default process group, if one is currently initialized.

    The call is a no-op when no process group exists, so it is safe to use in
    error-handling paths and to call more than once.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
def prepare_dataloader_ddp(rank, world_size, batch_size=64):
    """Build the training dataloader for one vanilla-DDP worker.

    Args:
        rank (int): Rank of the calling process; selects its sampler shard.
        world_size (int): Number of replicas the sampler partitions across.
        batch_size (int): Per-process batch size.

    Returns:
        DataLoader: Loader over a 55,000-sample random subset of the MNIST
        training split, sharded with a ``DistributedSampler``.
    """
    full_train = datasets.MNIST("data", train=True, download=True, transform=transforms.ToTensor())
    # No generator is passed, so the split draws from the default RNG.
    train_subset = random_split(full_train, [55000, 5000])[0]
    shard_sampler = torch.utils.data.distributed.DistributedSampler(
        train_subset, num_replicas=world_size, rank=rank
    )
    return DataLoader(train_subset, batch_size=batch_size, sampler=shard_sampler, num_workers=4)
def ddp_train(rank, world_size, steps, static_graph, port):
    """Worker entry point: train ``BaseModel`` under vanilla PyTorch DDP.

    Runs ``steps`` optimizer steps (capped at 10 passes over the dataloader)
    and, after each step, logs the change of the flattened ``fc1`` weights via
    ``report_weight_changes`` under the mode name "PyTorch".

    Args:
        rank (int): Rank of this process; also used as its CUDA device index.
        world_size (int): Total number of DDP processes.
        steps (int): Number of optimizer steps to run.
        static_graph (bool): Forwarded to ``DistributedDataParallel``.
        port (int): Rendezvous port passed to ``setup_ddp``.
    """
    # Bind the process to its own GPU before any collective is set up, so
    # communication does not fall back to the default device.
    torch.cuda.set_device(rank)
    setup_ddp(rank, world_size, port)
    try:
        device = torch.device(f"cuda:{rank}")
        model = BaseModel().to(device)
        # Pass the static_graph flag from the argument.
        ddp_model = DDP(model, device_ids=[rank], static_graph=static_graph)
        optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
        train_loader = prepare_dataloader_ddp(rank, world_size)
        # Determine the suffix for the log filename.
        graph_suffix = f"sg{'True' if static_graph else 'False'}"

        step = 0
        prev_weights_sum = None
        prev_weights = None
        ddp_model.train()
        for epoch in range(10):  # Loop over epochs if necessary.
            if step >= steps:
                break
            # A DistributedSampler reshuffles per epoch only when told the
            # epoch number; epoch 0 matches its default, so the first pass
            # is unchanged.
            set_epoch = getattr(train_loader.sampler, "set_epoch", None)
            if set_epoch is not None:
                set_epoch(epoch)
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                output = ddp_model(x)
                loss = F.cross_entropy(output, y)
                loss.backward()
                optimizer.step()
                with torch.no_grad():
                    current_weights = ddp_model.module.fc1.weight.data.view(-1).clone().detach().cpu()
                prev_weights_sum, prev_weights = report_weight_changes(
                    rank, "PyTorch", step, prev_weights, prev_weights_sum, current_weights, graph_suffix
                )
                step += 1
                # Stop right after the last requested step instead of
                # fetching one more batch first.
                if step >= steps:
                    break
    finally:
        # Tear the process group down even when training raises.
        cleanup_ddp()
# ------------------------------------------------------------------------------
# Main: Running Both Versions for Lightning and PyTorch
# ------------------------------------------------------------------------------
def run_lightning(static_graph):
    """Run the Lightning DDP experiment for two training batches.

    Args:
        static_graph (bool): Forwarded to ``DDPStrategy`` and used by
            ``LitClassifier`` to pick the log file suffix.
    """
    print(f"Running Lightning mode with static_graph={static_graph} for 2 training steps")
    classifier = LitClassifier(graph_mode=static_graph)
    data_module = MNISTDataModule(batch_size=64)
    trainer_options = {
        "max_epochs": 1,
        "accelerator": "gpu",
        "devices": torch.cuda.device_count(),
        "strategy": DDPStrategy(static_graph=static_graph),
        "num_sanity_val_steps": 0,
        "deterministic": True,
        # Two batches are enough to compare one weight update across ranks.
        "limit_train_batches": 2,
    }
    Trainer(**trainer_options).fit(classifier, data_module)
def run_pytorch(static_graph, port):
    """Spawn one vanilla-DDP worker per visible GPU and run two training steps.

    Args:
        static_graph (bool): Forwarded to each ``ddp_train`` worker.
        port (int): Rendezvous port shared by the spawned workers.
    """
    print(f"Running vanilla PyTorch DDP mode with static_graph={static_graph} for 2 training steps (port={port})")
    gpu_count = torch.cuda.device_count()
    # mp.spawn prepends the process index (rank) to these arguments.
    worker_args = (gpu_count, 2, static_graph, port)
    mp.spawn(ddp_train, args=worker_args, nprocs=gpu_count, join=True)
if __name__ == "__main__":
    # NOTE(review): Lightning's default DDP launcher presumably re-executes
    # this script in its worker processes; confirm the vanilla runs below are
    # not started once per worker.
    # Run Lightning with static_graph True and False:
    for graph_flag in (True, False):
        run_lightning(static_graph=graph_flag)
    # Run vanilla PyTorch DDP with static_graph True and False on different ports.
    for graph_flag, rendezvous_port in ((True, 12356), (False, 12357)):
        run_pytorch(static_graph=graph_flag, port=rendezvous_port)
Error messages and logs
Running the reproduction script shows that PyTorch Lightning DDP with static_graph=True ends up with different model parameters on different processes after the training steps.
Lightning_sgTrue_0.txt:
[Lightning GPU 0] Step 1 Changed indices: [[67], [68], [69], [70], [71], [72], [73], [74], [95], [96]]
[Lightning GPU 0] Step 1 Weight delta: tensor([[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007]])
Lightning_sgTrue_1.txt:
[Lightning GPU 1] Step 1 Changed indices: [[39], [40], [66], [67], [68], [69], [70], [71], [72], [94]]
[Lightning GPU 1] Step 1 Weight delta: tensor([[-0.0007],
[-0.0007],
[-0.0007],
[-0.0007],
[-0.0008],
[-0.0007],
[-0.0010],
[-0.0007],
[-0.0007],
[-0.0007]])
Environment
Current environment
- CUDA:
- GPU:
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- NVIDIA RTX A6000
- available: True
- version: 12.6
- Lightning:
- lightning: 2.5.0.post0
- lightning-sdk: 0.2.5
- lightning-utilities: 0.14.0
- lion-pytorch: 0.2.3
- pytorch-lightning: 2.5.0.post0
- pytorch-triton: 3.3.0+git96316ce5
- torch: 2.8.0.dev20250407+cu126
- torch-tb-profiler: 0.4.3
- torchmetrics: 1.6.2
- torchvision: 0.22.0.dev20250407+cu126
- Packages:
- absl-py: 2.1.0
- accelerate: 1.4.0
- aiohappyeyeballs: 2.5.0
- aiohttp: 3.11.13
- aiosignal: 1.3.2
- annotated-types: 0.7.0
- antlr4-python3-runtime: 4.9.3
- anyio: 4.9.0
- attn-gym: 0.0.4.dev12+g41a96b6
- attrs: 25.1.0
- autocommand: 2.2.2
- backoff: 2.2.1
- backports.tarfile: 1.2.0
- beautifulsoup4: 4.13.3
- blinker: 1.4
- boto3: 1.37.10
- botocore: 1.37.10
- bs4: 0.0.2
- certifi: 2025.1.31
- charset-normalizer: 3.4.1
- click: 8.1.8
- cloudpickle: 3.1.1
- contourpy: 1.3.1
- cryptography: 3.4.8
- cssselect: 1.3.0
- cycler: 0.12.1
- dacite: 1.9.2
- datasets: 3.3.2
- dbus-python: 1.2.18
- dill: 0.3.8
- distro: 1.7.0
- distro-info: 1.1+ubuntu0.2
- docker: 7.1.0
- docker-pycreds: 0.4.0
- einops: 0.8.1
- faiss: 1.10.0
- fastapi: 0.115.11
- feedfinder2: 0.0.4
- feedparser: 6.0.11
- filelock: 3.16.1
- flash-attn: 2.7.4.post1
- fonttools: 4.56.0
- frozenlist: 1.5.0
- fsspec: 2024.10.0
- ftfy: 6.3.1
- gitdb: 4.0.12
- gitpython: 3.1.44
- grpcio: 1.71.0
- h11: 0.14.0
- h5py: 3.13.0
- hkkang-utils: 0.2.57
- htmlmin: 0.1.12
- httplib2: 0.20.2
- huggingface-hub: 0.29.3
- hydra-core: 1.3.2
- idna: 3.10
- importlib-metadata: 8.0.0
- inflect: 7.3.1
- jaraco.collections: 5.1.0
- jaraco.context: 5.3.0
- jaraco.functools: 4.0.1
- jaraco.text: 3.12.1
- jeepney: 0.7.1
- jieba3k: 0.35.1
- jinja2: 3.1.4
- jmespath: 1.0.1
- joblib: 1.4.2
- jsonargparse: 4.37.0
- keyring: 23.5.0
- kiwisolver: 1.4.8
- langdetect: 1.0.9
- launchpadlib: 1.10.16
- lazr.restfulclient: 0.14.4
- lazr.uri: 1.0.6
- legacy-cgi: 2.6.2
- lightning: 2.5.0.post0
- lightning-sdk: 0.2.5
- lightning-utilities: 0.14.0
- lion-pytorch: 0.2.3
- lxml: 5.3.1
- lxml-html-clean: 0.4.1
- markdown: 3.7
- markdown-it-py: 3.0.0
- markupsafe: 2.1.5
- matplotlib: 3.10.1
- mdurl: 0.1.2
- more-itertools: 10.3.0
- mpmath: 1.3.0
- multidict: 6.1.0
- multiprocess: 0.70.16
- networkx: 3.4.2
- newspaper3k: 0.2.8
- nltk: 3.9.1
- numpy: 2.2.3
- nvidia-cublas-cu12: 12.6.4.1
- nvidia-cuda-cupti-cu12: 12.6.80
- nvidia-cuda-nvrtc-cu12: 12.6.77
- nvidia-cuda-runtime-cu12: 12.6.77
- nvidia-cudnn-cu12: 9.5.1.17
- nvidia-cufft-cu12: 11.3.0.4
- nvidia-cufile-cu12: 1.11.1.6
- nvidia-curand-cu12: 10.3.7.77
- nvidia-cusolver-cu12: 11.7.1.2
- nvidia-cusparse-cu12: 12.5.4.2
- nvidia-cusparselt-cu12: 0.6.3
- nvidia-nccl-cu12: 2.26.2
- nvidia-nvjitlink-cu12: 12.6.85
- nvidia-nvtx-cu12: 12.6.77
- oauthlib: 3.2.0
- omegaconf: 2.3.0
- orjson: 3.10.15
- packaging: 24.2
- pandas: 2.2.3
- pglast: 7.3
- pillow: 11.1.0
- pip: 25.0.1
- platformdirs: 4.2.2
- propcache: 0.3.0
- protobuf: 5.29.3
- psutil: 7.0.0
- psycopg: 3.2.5
- psycopg-binary: 3.2.5
- psycopg-pool: 3.2.6
- pyarrow: 19.0.1
- pydantic: 2.10.6
- pydantic-core: 2.27.2
- pygments: 2.19.1
- pygobject: 3.42.1
- pyjwt: 2.3.0
- pyparsing: 2.4.7
- python-apt: 2.4.0+ubuntu4
- python-dateutil: 2.9.0.post0
- python-dotenv: 1.0.1
- pytorch-lightning: 2.5.0.post0
- pytorch-triton: 3.3.0+git96316ce5
- pytz: 2025.1
- pyyaml: 6.0.2
- regex: 2024.11.6
- requests: 2.32.3
- requests-file: 2.1.0
- rich: 13.9.4
- s3transfer: 0.11.4
- safetensors: 0.5.3
- secretstorage: 3.3.1
- sentencepiece: 0.2.0
- sentry-sdk: 2.22.0
- setproctitle: 1.3.5
- setuptools: 75.8.0
- sgmllib3k: 1.0.0
- simple-term-menu: 1.6.6
- six: 1.16.0
- slack-sdk: 3.34.0
- smmap: 5.0.2
- sniffio: 1.3.1
- soupsieve: 2.6
- standard-imghdr: 3.13.0
- starlette: 0.46.1
- sympy: 1.13.3
- tensorboard: 2.19.0
- tensorboard-data-server: 0.7.2
- tensordict: 0.7.2
- tinysegmenter: 0.3
- tldextract: 5.1.3
- tokenizers: 0.21.0
- tomli: 2.0.1
- torch: 2.8.0.dev20250407+cu126
- torch-tb-profiler: 0.4.3
- torchmetrics: 1.6.2
- torchvision: 0.22.0.dev20250407+cu126
- tqdm: 4.67.1
- transformers: 4.49.0
- triton: 3.2.0
- typeguard: 4.3.0
- typing-extensions: 4.12.2
- tzdata: 2025.1
- ujson: 5.10.0
- unattended-upgrades: 0.1
- urllib3: 2.3.0
- uvicorn: 0.34.0
- wadllib: 1.3.6
- wandb: 0.19.8
- wcwidth: 0.2.13
- websocket-client: 1.8.0
- werkzeug: 3.1.3
- wget: 3.2
- wheel: 0.43.0
- xxhash: 3.5.0
- yarl: 1.18.3
- zipp: 3.19.2
- System:
- OS: Linux
- architecture:
- 64bit
- ELF
- processor: x86_64
- python: 3.13.2
- release: 5.15.0-107-generic
- version: #117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024
More info
No response