DDP and BackboneFinetuning: model weights get out of sync when unfreezing layers for training #20340

@ksikka

Bug description

When training a model with DDP and pl.callbacks.BackboneFinetuning, the model weights appear to drift out of sync across processes once the backbone is unfrozen. Before unfreezing, the weights stay in sync across processes, as expected.

I discovered this issue while adopting DDP. On the rank 0 process, validation loss trended downward during training, while on the other ranks it increased steadily. This led me to suspect that the model weights differed across processes, which I confirmed by printing a hash of the model weights on each process at every epoch.

What version are you seeing the problem on?

v2.4

How to reproduce the bug

The example below checks that model weights are in sync after every epoch. The assertion fails after epoch 3 (unfreeze_backbone_at_epoch).

import hashlib
import pytorch_lightning as pl
import torch
from torch import nn
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset


# 1. Define a simple dataset
class RandomDataset(Dataset):
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# 2. Define a LightningModule
class SimpleModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(32, 16)
        self.layer = nn.Linear(16, 2)

    def forward(self, x):
        x = torch.relu(self.backbone(x))
        x = self.layer(x)
        return x

    def training_step(self, batch, batch_idx):
        x = batch
        y_hat = self(x)
        loss = torch.nn.functional.mse_loss(y_hat, torch.ones_like(y_hat))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(
            filter(lambda p: p.requires_grad, self.parameters()), lr=0.1
        )

    def on_train_epoch_end(self):
        # Compute a hash of the model weights and check that it is equal across processes.
        hasher = hashlib.sha256()
        for param in self.parameters():
            hasher.update(param.data.cpu().numpy().tobytes())
        param_hash = hasher.hexdigest()
        all_param_hashes = [None] * dist.get_world_size()
        dist.all_gather_object(all_param_hashes, param_hash)
        if self.trainer.is_global_zero:
            assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
            print("Model weights in sync!")


# 3. Create data loaders
pl.seed_everything(0)
train_loader = DataLoader(RandomDataset(32, 64), batch_size=2)

# 4. Initialize the model and trainer
model = SimpleModel()
trainer = pl.Trainer(
    accelerator="cpu",
    strategy="ddp",
    devices=2,
    callbacks=[
        pl.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=3, verbose=True)
    ],
)

# 5. Train the model
trainer.fit(model, train_loader)

Output:

Epoch 0: 100%|████████████████████████████| 16/16 [00:00<00:00, 163.83it/s, loss=0.295, v_num=22]
Model weights in sync!
Epoch 1: 100%|███████████████████████████| 16/16 [00:00<00:00, 269.84it/s, loss=0.0667, v_num=22]
Model weights in sync!
Epoch 2: 100%|███████████████████████████| 16/16 [00:00<00:00, 257.68it/s, loss=0.0361, v_num=22]
Model weights in sync!
Current lr: 0.1, Backbone lr: 0.01
Current lr: 0.1, Backbone lr: 0.01
Epoch 3: 100%|███████████████████████████| 16/16 [00:00<00:00, 244.17it/s, loss=0.0243, v_num=22]Current lr: 0.1, Backbone lr: 0.02
[rank0]: Traceback (most recent call last):
...
[rank0]:   File "/home/ksikka/lightning-pose/example2.py", line 58, in _assert_model_weights_in_sync
[rank0]:     assert len(set(all_param_hashes)) == 1, "Model weights not in sync :("
[rank0]: AssertionError: Model weights not in sync :(
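
The check above hashes all parameters together, so it reports that the ranks diverged but not which parameters did. A minimal sketch of a per-parameter variant follows, assuming each rank can produce a (name, raw bytes) pair per parameter. The helper names hash_named_buffers and find_divergent are hypothetical and not part of Lightning or PyTorch, and the byte strings in the demo are made up for illustration.

```python
import hashlib
from typing import Dict, Iterable, List, Tuple


def hash_named_buffers(named_buffers: Iterable[Tuple[str, bytes]]) -> Dict[str, str]:
    """Return one SHA-256 hex digest per named buffer (hypothetical helper)."""
    return {name: hashlib.sha256(raw).hexdigest() for name, raw in named_buffers}


def find_divergent(per_rank_hashes: List[Dict[str, str]]) -> List[str]:
    """Return the names whose digest is not identical on every rank."""
    names = sorted(set().union(*per_rank_hashes))
    return [
        name
        for name in names
        if len({hashes.get(name) for hashes in per_rank_hashes}) > 1
    ]


# Stand-in data for two ranks: only the backbone bytes differ.
rank0 = hash_named_buffers([("backbone.weight", b"\x00\x01"), ("layer.weight", b"\x02")])
rank1 = hash_named_buffers([("backbone.weight", b"\x00\x09"), ("layer.weight", b"\x02")])
print(find_divergent([rank0, rank1]))  # ['backbone.weight']

# Possible wiring inside on_train_epoch_end (assumption, not run here):
#   local = hash_named_buffers(
#       (name, p.detach().cpu().numpy().tobytes())
#       for name, p in self.named_parameters()
#   )
#   gathered = [None] * dist.get_world_size()
#   dist.all_gather_object(gathered, local)
#   print(find_divergent(gathered))
```

In the reproduction script, this would show whether the mismatch is confined to the backbone parameters or also affects the layer parameters.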

Error messages and logs

No warning or error is emitted. Validation loss logged with sync_dist=True increases after unfreezing, while with sync_dist=False it decreases, although more slowly than in a single-process run.
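
One way to read the difference between the two logging modes: with sync_dist=True the logged value is reduced across processes (a mean by default), so it reflects every rank, while without it each process logs only its own value. The sketch below mimics that mean reduction in plain Python. The loss values are invented for illustration and are not measurements from this report, and reduce_mean is a hypothetical stand-in for the reduction, not a Lightning function.

```python
from statistics import mean
from typing import List


def reduce_mean(per_rank_values: List[float]) -> float:
    """Mimic a mean reduction across ranks (hypothetical stand-in)."""
    return mean(per_rank_values)


# Invented per-epoch validation losses, for illustration only.
rank0_losses = [0.30, 0.20, 0.10]  # improving on rank 0
rank1_losses = [0.30, 0.50, 0.90]  # degrading on rank 1

# The value a mean reduction across the two ranks would report per epoch.
synced = [reduce_mean([a, b]) for a, b in zip(rank0_losses, rank1_losses)]

for epoch, (local, reduced) in enumerate(zip(rank0_losses, synced)):
    print(f"epoch {epoch}: rank0 only={local:.2f}, mean over ranks={reduced:.2f}")
```

Under these made-up numbers the rank 0 curve falls while the reduced curve rises, which is the same qualitative pattern described above when ranks hold different weights.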

Environment

I originally noticed the issue in a multi-GPU Linux environment in Lightning Studio, but I reproduced it with the example code above in the following environment.

Current environment
  • CUDA:
    - GPU:
        - NVIDIA GeForce GTX 1080 Ti
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.4.0
    - lightning-bolts: 0.7.0
    - lightning-pose: 1.5.1
    - lightning-utilities: 0.11.7
    - pytorch-lightning: 1.9.5
    - torch: 2.4.1
    - torchmetrics: 1.4.2
    - torchtyping: 0.1.5
    - torchvision: 0.19.1
  • Packages:
    - absl-py: 2.1.0
    - aiofiles: 24.1.0
    - aiohappyeyeballs: 2.4.3
    - aiohttp: 3.10.8
    - aiosignal: 1.3.1
    - alabaster: 0.7.16
    - altair: 5.4.1
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.6.0
    - argcomplete: 3.5.0
    - astunparse: 1.6.3
    - async-timeout: 4.0.3
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - babel: 2.16.0
    - backports.tarfile: 1.2.0
    - beautifulsoup4: 4.12.3
    - black: 24.8.0
    - blinker: 1.8.2
    - boto3: 1.35.32
    - botocore: 1.35.32
    - brotli: 1.1.0
    - cachetools: 5.5.0
    - certifi: 2024.8.30
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - contourpy: 1.3.0
    - cycler: 0.12.1
    - dacite: 1.7.0
    - decorator: 4.4.2
    - deprecated: 1.2.14
    - dill: 0.3.9
    - dm-tree: 0.1.8
    - dnspython: 2.6.1
    - docutils: 0.20.1
    - exceptiongroup: 1.2.2
    - execnet: 2.1.1
    - fiftyone: 1.0.0
    - fiftyone-brain: 0.17.0
    - fiftyone-db: 1.1.6
    - filelock: 3.16.1
    - flake8: 7.1.1
    - fonttools: 4.54.1
    - frozenlist: 1.4.1
    - fsspec: 2024.9.0
    - ftfy: 6.2.3
    - future: 1.0.0
    - gast: 0.6.0
    - gitdb: 4.0.11
    - gitpython: 3.1.43
    - glob2: 0.7
    - graphql-core: 3.2.4
    - grpcio: 1.66.2
    - h11: 0.14.0
    - h2: 4.1.0
    - h5py: 3.12.1
    - hpack: 4.0.0
    - httpcore: 1.0.6
    - httpx: 0.27.2
    - humanize: 4.10.0
    - hydra-core: 1.3.2
    - hypercorn: 0.17.3
    - hyperframe: 6.0.1
    - idna: 3.10
    - imageio: 2.35.1
    - imageio-ffmpeg: 0.5.1
    - imagesize: 1.4.1
    - imgaug: 0.4.0
    - importlib-metadata: 8.0.0
    - importlib-resources: 6.4.0
    - inflate64: 1.0.0
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - isort: 5.13.2
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - jsonlines: 4.0.0
    - jsonschema: 4.23.0
    - jsonschema-specifications: 2023.12.1
    - kaleido: 0.2.1
    - kiwisolver: 1.4.7
    - kornia: 0.7.3
    - kornia-rs: 0.1.5
    - lazy-loader: 0.4
    - lightning: 2.4.0
    - lightning-bolts: 0.7.0
    - lightning-pose: 1.5.1
    - lightning-utilities: 0.11.7
    - markdown: 3.7
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.2
    - mccabe: 0.7.0
    - mdurl: 0.1.2
    - mongoengine: 0.24.2
    - more-itertools: 10.3.0
    - motor: 3.5.3
    - moviepy: 1.0.3
    - mpmath: 1.3.0
    - multidict: 6.1.0
    - multivolumefile: 0.2.3
    - mypy-extensions: 1.0.0
    - narwhals: 1.9.0
    - networkx: 3.3
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.1.3.1
    - nvidia-cuda-cupti-cu12: 12.1.105
    - nvidia-cuda-nvrtc-cu12: 12.1.105
    - nvidia-cuda-runtime-cu12: 12.1.105
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.0.2.54
    - nvidia-curand-cu12: 10.3.2.106
    - nvidia-cusolver-cu12: 11.4.5.107
    - nvidia-cusparse-cu12: 12.1.0.106
    - nvidia-dali-cuda110: 1.42.0
    - nvidia-nccl-cu12: 2.20.5
    - nvidia-nvimgcodec-cu11: 0.3.0.5
    - nvidia-nvjitlink-cu12: 12.6.77
    - nvidia-nvtx-cu12: 12.1.105
    - omegaconf: 2.3.0
    - opencv-python: 4.10.0.84
    - opencv-python-headless: 4.10.0.84
    - packaging: 24.1
    - pandas: 2.2.3
    - pathspec: 0.12.1
    - pillow: 10.4.0
    - pip: 24.2
    - platformdirs: 4.3.6
    - plotly: 5.24.1
    - pluggy: 1.5.0
    - pprintpp: 0.4.0
    - priority: 2.0.0
    - proglog: 0.1.10
    - protobuf: 5.28.2
    - psutil: 6.0.0
    - py7zr: 0.22.0
    - pyarrow: 17.0.0
    - pybcj: 1.0.2
    - pycodestyle: 2.12.1
    - pycryptodomex: 3.21.0
    - pydash: 8.0.3
    - pydeck: 0.9.1
    - pyflakes: 3.2.0
    - pygments: 2.18.0
    - pymongo: 4.8.0
    - pyparsing: 3.1.4
    - pyppmd: 1.1.0
    - pytest: 8.3.3
    - pytest-xdist: 3.6.1
    - python-dateutil: 2.9.0.post0
    - pytorch-lightning: 1.9.5
    - pytz: 2024.2
    - pyyaml: 6.0.2
    - pyzstd: 0.16.1
    - rarfile: 4.2
    - referencing: 0.35.1
    - regex: 2024.9.11
    - requests: 2.32.3
    - retrying: 1.3.4
    - rich: 13.9.1
    - rpds-py: 0.20.0
    - s3transfer: 0.10.2
    - scikit-image: 0.24.0
    - scikit-learn: 1.5.2
    - scipy: 1.14.1
    - seaborn: 0.13.2
    - segment-anything: 1.0
    - setuptools: 75.1.0
    - shapely: 2.0.6
    - six: 1.16.0
    - smmap: 5.0.1
    - sniffio: 1.3.1
    - snowballstemmer: 2.2.0
    - sortedcontainers: 2.4.0
    - soupsieve: 2.6
    - sphinx: 7.4.7
    - sphinx-automodapi: 0.18.0
    - sphinx-copybutton: 0.5.2
    - sphinx-design: 0.6.1
    - sphinx-rtd-dark-mode: 1.3.0
    - sphinx-rtd-theme: 2.0.0
    - sphinxcontrib-applehelp: 2.0.0
    - sphinxcontrib-devhelp: 2.0.0
    - sphinxcontrib-htmlhelp: 2.1.0
    - sphinxcontrib-jquery: 4.1
    - sphinxcontrib-jsmath: 1.0.1
    - sphinxcontrib-qthelp: 2.0.0
    - sphinxcontrib-serializinghtml: 2.0.0
    - sse-starlette: 0.10.3
    - sseclient-py: 1.8.0
    - starlette: 0.39.2
    - strawberry-graphql: 0.243.1
    - streamlit: 1.39.0
    - sympy: 1.13.3
    - tabulate: 0.9.0
    - taskgroup: 0.0.0a4
    - tenacity: 9.0.0
    - tensorboard: 2.18.0
    - tensorboard-data-server: 0.7.2
    - texttable: 1.7.0
    - threadpoolctl: 3.5.0
    - tifffile: 2024.9.20
    - toml: 0.10.2
    - tomli: 2.0.2
    - torch: 2.4.1
    - torchmetrics: 1.4.2
    - torchtyping: 0.1.5
    - torchvision: 0.19.1
    - tornado: 6.4.1
    - tqdm: 4.66.5
    - triton: 3.0.0
    - typeguard: 2.13.3
    - typing: 3.7.4.3
    - typing-extensions: 4.12.2
    - tzdata: 2024.2
    - tzlocal: 5.2
    - universal-analytics-python3: 1.1.1
    - urllib3: 2.2.3
    - voxel51-eta: 0.13.0
    - watchdog: 5.0.3
    - wcwidth: 0.2.13
    - werkzeug: 3.0.4
    - wheel: 0.44.0
    - wrapt: 1.16.0
    - wsproto: 1.2.0
    - xmltodict: 0.13.0
    - yarl: 1.13.1
    - zipp: 3.19.2
  • System:
    - OS: Linux
    - architecture:
        - 64bit
        - ELF
    - processor: x86_64
    - python: 3.10.0
    - release: 5.15.153.1-microsoft-standard-WSL2
    - version: #1 SMP Fri Mar 29 23:14:13 UTC 2024

More info

No response

cc @justusschock @lantiga
