
DDP model not synchronizing when static_graph=True #20704

@hyukkyukang

Bug description

When using PyTorch Lightning with DDP and static_graph=True, model parameters are not synchronized properly across processes. I tested this against vanilla PyTorch DDP and confirmed that the issue only appears in Lightning.

📄 Minimal Reproducible Example

I created a minimal script that compares model parameter changes across DDP processes after each optimizer step. It runs 2 training steps and logs the changed indices and weight deltas of the first fully connected layer (fc1).

This script runs 4 experiments:
• Lightning with static_graph=True
• Lightning with static_graph=False
• Vanilla PyTorch DDP with static_graph=True
• Vanilla PyTorch DDP with static_graph=False

Only the Lightning + static_graph=True case shows inconsistent or missing synchronization.

🔍 Observed Behavior
• When using Lightning + DDP + static_graph=True, each GPU maintains a different version of the model after training steps.
• When using Vanilla PyTorch DDP + static_graph=True, synchronization works as expected.

✅ Expected Behavior

Model parameters should remain synchronized across DDP processes, even when static_graph=True.
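
For reference, synchronization can also be checked directly inside a training step with a small cross-rank comparison. This is only a sketch, assuming an already-initialized process group on the NCCL backend; assert_weights_in_sync is a hypothetical helper and not part of the reproduction script below.

import torch
import torch.distributed as dist

def assert_weights_in_sync(module):
    """Gather the fc1 weights from every rank and compare them against rank 0."""
    local = module.fc1.weight.detach().clone()
    gathered = [torch.empty_like(local) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, local)
    for rank, weights in enumerate(gathered):
        if not torch.allclose(weights, gathered[0]):
            raise AssertionError(f"fc1 weights on rank {rank} diverged from rank 0")

A call such as assert_weights_in_sync(self.model) right after opt.step() would be expected to raise only in the Lightning + static_graph=True run.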

What version are you seeing the problem on?

master

How to reproduce the bug

import os
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

# Lightning imports
import lightning as L
from lightning import Trainer
from lightning.pytorch.strategies import DDPStrategy

# PyTorch DDP imports
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# ------------------------------------------------------------------------------
# Helper Function for Reporting Weight Changes
# ------------------------------------------------------------------------------
def report_weight_changes(rank, mode_name, step, prev_weights, prev_weights_sum, current_weights, suffix):
    """
    Computes and reports the change in parameters (for fc1 layer) between training steps.
    
    Args:
        rank (int): Process/GPU rank.
        mode_name (str): Mode name (e.g., "Lightning" or "PyTorch").
        step (int): The current training step (or batch index).
        prev_weights (Tensor or None): The weight vector from the previous step.
        prev_weights_sum (Tensor or None): Sum of the previous weights.
        current_weights (Tensor): The weight vector at the current step.
        suffix (str): Suffix to be appended to log file name (e.g., 'sgTrue' or 'sgFalse').
        
    Returns:
        current_weights_sum, current_weights: Updated weight sum and weight vector.
    """
    current_weights_sum = current_weights.sum()
    # Only report if we have previous weights to compare.
    if prev_weights_sum is not None:
        delta_weights = current_weights - prev_weights
        changed_indices = delta_weights.nonzero()[:10]
        file_name = f"{mode_name}_{suffix}_{rank}.txt"
        with open(file_name, "a") as f:
            f.write(f"[{mode_name} GPU {rank}] Step {step} Changed indices: {changed_indices.tolist()}\n")
            f.write(f"[{mode_name} GPU {rank}] Step {step} Weight delta: {delta_weights[changed_indices]}\n")
    return current_weights_sum, current_weights

# ------------------------------------------------------------------------------
# Shared Model Definition
# ------------------------------------------------------------------------------
class BaseModel(nn.Module):
    def __init__(self):
        super(BaseModel, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ------------------------------------------------------------------------------
# PyTorch Lightning Module & Data Module
# ------------------------------------------------------------------------------
class LitClassifier(L.LightningModule):
    def __init__(self, graph_mode):
        """
        Args:
            graph_mode (bool): True if using static_graph=True, False otherwise.
        """
        super().__init__()
        self.model = BaseModel()
        # For tracking weight changes for fc1
        self.prev_weights_sum = None
        self.prev_weights = None
        self.automatic_optimization = False  # using manual optimization.
        # Save the graph mode suffix for logging purposes.
        self.graph_suffix = f"sg{'True' if graph_mode else 'False'}"

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # manual optimizer access.
        x, y = batch
        logits = self.model(x)
        loss = F.cross_entropy(logits, y)
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()

        # Compute current weights for fc1 and compare with previous step.
        current_weights = self.model.fc1.weight.data.view(-1).clone().detach().cpu()
        self.prev_weights_sum, self.prev_weights = report_weight_changes(
            self.global_rank, "Lightning", batch_idx, self.prev_weights, self.prev_weights_sum, current_weights,
            self.graph_suffix
        )
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=1e-3)

class MNISTDataModule(L.LightningDataModule):
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        datasets.MNIST("data", train=True, download=True)
        datasets.MNIST("data", train=False, download=True)

    def setup(self, stage=None):
        transform = transforms.ToTensor()
        full_dataset = datasets.MNIST("data", train=True, transform=transform)
        self.train_set, _ = random_split(full_dataset, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=4)

# ------------------------------------------------------------------------------
# Vanilla PyTorch DDP Implementation
# ------------------------------------------------------------------------------
def setup_ddp(rank, world_size, port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(port)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup_ddp():
    dist.destroy_process_group()

def prepare_dataloader_ddp(rank, world_size, batch_size=64):
    transform = transforms.ToTensor()
    dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
    train_set, _ = random_split(dataset, [55000, 5000])
    sampler = torch.utils.data.distributed.DistributedSampler(train_set, num_replicas=world_size, rank=rank)
    return DataLoader(train_set, batch_size=batch_size, sampler=sampler, num_workers=4)

def ddp_train(rank, world_size, steps, static_graph, port):
    setup_ddp(rank, world_size, port)
    device = torch.device(f"cuda:{rank}")
    model = BaseModel().to(device)
    # Pass the static_graph flag from the argument.
    ddp_model = DDP(model, device_ids=[rank], static_graph=static_graph)
    optimizer = torch.optim.Adam(ddp_model.parameters(), lr=1e-3)
    train_loader = prepare_dataloader_ddp(rank, world_size)

    # Determine the suffix for the log filename.
    graph_suffix = f"sg{'True' if static_graph else 'False'}"

    step = 0
    prev_weights_sum = None
    prev_weights = None
    ddp_model.train()

    for epoch in range(10):  # Loop over epochs if necessary.
        for batch_idx, (x, y) in enumerate(train_loader):
            if step >= steps:
                break
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            output = ddp_model(x)
            loss = F.cross_entropy(output, y)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                current_weights = ddp_model.module.fc1.weight.data.view(-1).clone().detach().cpu()
                prev_weights_sum, prev_weights = report_weight_changes(
                    rank, "PyTorch", step, prev_weights, prev_weights_sum, current_weights, graph_suffix
                )
            step += 1
        if step >= steps:
            break
    cleanup_ddp()

# ------------------------------------------------------------------------------
# Main: Running Both Versions for Lightning and PyTorch
# ------------------------------------------------------------------------------
def run_lightning(static_graph):
    print(f"Running Lightning mode with static_graph={static_graph} for 2 training steps")
    model = LitClassifier(graph_mode=static_graph)
    dm = MNISTDataModule(batch_size=64)
    trainer = Trainer(
        max_epochs=1,
        accelerator="gpu",
        devices=torch.cuda.device_count(),
        strategy=DDPStrategy(static_graph=static_graph),
        num_sanity_val_steps=0,
        deterministic=True,
        limit_train_batches=2,
    )
    trainer.fit(model, dm)

def run_pytorch(static_graph, port):
    print(f"Running vanilla PyTorch DDP mode with static_graph={static_graph} for 2 training steps (port={port})")
    world_size = torch.cuda.device_count()
    mp.spawn(ddp_train, args=(world_size, 2, static_graph, port), nprocs=world_size, join=True)

if __name__ == "__main__":
    # Run Lightning with static_graph True and False:
    run_lightning(static_graph=True)
    run_lightning(static_graph=False)
    # Run vanilla PyTorch DDP with static_graph True and False on different ports.
    run_pytorch(static_graph=True, port=12356)
    run_pytorch(static_graph=False, port=12357)
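
After the runs finish, the per-rank log files written by report_weight_changes can be compared directly. The snippet below is a hypothetical post-run check, not part of the script above; it assumes a 2-GPU run and the Lightning_sgTrue_*.txt filenames produced for the Lightning + static_graph=True case.

with open("Lightning_sgTrue_0.txt") as f0, open("Lightning_sgTrue_1.txt") as f1:
    for lineno, (line0, line1) in enumerate(zip(f0, f1), start=1):
        # Drop the "[Lightning GPU r]" prefix so only the reported indices/deltas are compared.
        body0 = line0.split("]", 1)[-1]
        body1 = line1.split("]", 1)[-1]
        if body0 != body1:
            print(f"Rank logs differ at line {lineno}:")
            print("  rank 0:", body0.strip())
            print("  rank 1:", body1.strip())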

Error messages and logs

Running the reproduction script shows that PyTorch Lightning DDP with static_graph=True ends up with different model parameters on different processes after the training steps, while the other three configurations stay in sync.

Lightning_sgTrue_0.txt:

[Lightning GPU 0] Step 1 Changed indices: [[67], [68], [69], [70], [71], [72], [73], [74], [95], [96]]
[Lightning GPU 0] Step 1 Weight delta: tensor([[-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007]])

Lightning_sgTrue_1.txt:

[Lightning GPU 1] Step 1 Changed indices: [[39], [40], [66], [67], [68], [69], [70], [71], [72], [94]]
[Lightning GPU 1] Step 1 Weight delta: tensor([[-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0007],
        [-0.0008],
        [-0.0007],
        [-0.0010],
        [-0.0007],
        [-0.0007],
        [-0.0007]])

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - NVIDIA RTX A6000
    - available: True
    - version: 12.6
  • Lightning:
    - lightning: 2.5.0.post0
    - lightning-sdk: 0.2.5
    - lightning-utilities: 0.14.0
    - lion-pytorch: 0.2.3
    - pytorch-lightning: 2.5.0.post0
    - pytorch-triton: 3.3.0+git96316ce5
    - torch: 2.8.0.dev20250407+cu126
    - torch-tb-profiler: 0.4.3
    - torchmetrics: 1.6.2
    - torchvision: 0.22.0.dev20250407+cu126
  • Packages:
    - absl-py: 2.1.0
    - accelerate: 1.4.0
    - aiohappyeyeballs: 2.5.0
    - aiohttp: 3.11.13
    - aiosignal: 1.3.2
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.9.0
    - attn-gym: 0.0.4.dev12+g41a96b6
    - attrs: 25.1.0
    - autocommand: 2.2.2
    - backoff: 2.2.1
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.13.3
    - blinker: 1.4
    - boto3: 1.37.10
    - botocore: 1.37.10
    - bs4: 0.0.2
    - certifi: 2025.1.31
    - charset-normalizer: 3.4.1
    - click: 8.1.8
    - cloudpickle: 3.1.1
    - contourpy: 1.3.1
    - cryptography: 3.4.8
    - cssselect: 1.3.0
    - cycler: 0.12.1
    - dacite: 1.9.2
    - datasets: 3.3.2
    - dbus-python: 1.2.18
    - dill: 0.3.8
    - distro: 1.7.0
    - distro-info: 1.1+ubuntu0.2
    - docker: 7.1.0
    - docker-pycreds: 0.4.0
    - einops: 0.8.1
    - faiss: 1.10.0
    - fastapi: 0.115.11
    - feedfinder2: 0.0.4
    - feedparser: 6.0.11
    - filelock: 3.16.1
    - flash-attn: 2.7.4.post1
    - fonttools: 4.56.0
    - frozenlist: 1.5.0
    - fsspec: 2024.10.0
    - ftfy: 6.3.1
    - gitdb: 4.0.12
    - gitpython: 3.1.44
    - grpcio: 1.71.0
    - h11: 0.14.0
    - h5py: 3.13.0
    - hkkang-utils: 0.2.57
    - htmlmin: 0.1.12
    - httplib2: 0.20.2
    - huggingface-hub: 0.29.3
    - hydra-core: 1.3.2
    - idna: 3.10
    - importlib-metadata: 8.0.0
    - inflect: 7.3.1
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jeepney: 0.7.1
    - jieba3k: 0.35.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonargparse: 4.37.0
    - keyring: 23.5.0
    - kiwisolver: 1.4.8
    - langdetect: 1.0.9
    - launchpadlib: 1.10.16
    - lazr.restfulclient: 0.14.4
    - lazr.uri: 1.0.6
    - legacy-cgi: 2.6.2
    - lightning: 2.5.0.post0
    - lightning-sdk: 0.2.5
    - lightning-utilities: 0.14.0
    - lion-pytorch: 0.2.3
    - lxml: 5.3.1
    - lxml-html-clean: 0.4.1
    - markdown: 3.7
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.10.1
    - mdurl: 0.1.2
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multiprocess: 0.70.16
    - networkx: 3.4.2
    - newspaper3k: 0.2.8
    - nltk: 3.9.1
    - numpy: 2.2.3
    - nvidia-cublas-cu12: 12.6.4.1
    - nvidia-cuda-cupti-cu12: 12.6.80
    - nvidia-cuda-nvrtc-cu12: 12.6.77
    - nvidia-cuda-runtime-cu12: 12.6.77
    - nvidia-cudnn-cu12: 9.5.1.17
    - nvidia-cufft-cu12: 11.3.0.4
    - nvidia-cufile-cu12: 1.11.1.6
    - nvidia-curand-cu12: 10.3.7.77
    - nvidia-cusolver-cu12: 11.7.1.2
    - nvidia-cusparse-cu12: 12.5.4.2
    - nvidia-cusparselt-cu12: 0.6.3
    - nvidia-nccl-cu12: 2.26.2
    - nvidia-nvjitlink-cu12: 12.6.85
    - nvidia-nvtx-cu12: 12.6.77
    - oauthlib: 3.2.0
    - omegaconf: 2.3.0
    - orjson: 3.10.15
    - packaging: 24.2
    - pandas: 2.2.3
    - pglast: 7.3
    - pillow: 11.1.0
    - pip: 25.0.1
    - platformdirs: 4.2.2
    - propcache: 0.3.0
    - protobuf: 5.29.3
    - psutil: 7.0.0
    - psycopg: 3.2.5
    - psycopg-binary: 3.2.5
    - psycopg-pool: 3.2.6
    - pyarrow: 19.0.1
    - pydantic: 2.10.6
    - pydantic-core: 2.27.2
    - pygments: 2.19.1
    - pygobject: 3.42.1
    - pyjwt: 2.3.0
    - pyparsing: 2.4.7
    - python-apt: 2.4.0+ubuntu4
    - python-dateutil: 2.9.0.post0
    - python-dotenv: 1.0.1
    - pytorch-lightning: 2.5.0.post0
    - pytorch-triton: 3.3.0+git96316ce5
    - pytz: 2025.1
    - pyyaml: 6.0.2
    - regex: 2024.11.6
    - requests: 2.32.3
    - requests-file: 2.1.0
    - rich: 13.9.4
    - s3transfer: 0.11.4
    - safetensors: 0.5.3
    - secretstorage: 3.3.1
    - sentencepiece: 0.2.0
    - sentry-sdk: 2.22.0
    - setproctitle: 1.3.5
    - setuptools: 75.8.0
    - sgmllib3k: 1.0.0
    - simple-term-menu: 1.6.6
    - six: 1.16.0
    - slack-sdk: 3.34.0
    - smmap: 5.0.2
    - sniffio: 1.3.1
    - soupsieve: 2.6
    - standard-imghdr: 3.13.0
    - starlette: 0.46.1
    - sympy: 1.13.3
    - tensorboard: 2.19.0
    - tensorboard-data-server: 0.7.2
    - tensordict: 0.7.2
    - tinysegmenter: 0.3
    - tldextract: 5.1.3
    - tokenizers: 0.21.0
    - tomli: 2.0.1
    - torch: 2.8.0.dev20250407+cu126
    - torch-tb-profiler: 0.4.3
    - torchmetrics: 1.6.2
    - torchvision: 0.22.0.dev20250407+cu126
    - tqdm: 4.67.1
    - transformers: 4.49.0
    - triton: 3.2.0
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2025.1
    - ujson: 5.10.0
    - unattended-upgrades: 0.1
    - urllib3: 2.3.0
    - uvicorn: 0.34.0
    - wadllib: 1.3.6
    - wandb: 0.19.8
    - wcwidth: 0.2.13
    - websocket-client: 1.8.0
    - werkzeug: 3.1.3
    - wget: 3.2
    - wheel: 0.43.0
    - xxhash: 3.5.0
    - yarl: 1.18.3
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.13.2
    - release: 5.15.0-107-generic
    - version: #117-Ubuntu SMP Fri Apr 26 12:26:49 UTC 2024

More info

No response

cc @justusschock @lantiga
