
lightning.Fabric deadlocks when setting up an nn.Module #20536

@forestbat

Bug description

When I use lightning.Fabric.setup() to set up a torch.nn.Module in a multi-process run, the program deadlocks and hangs in lightning/fabric/strategies/launchers/subprocess_script.py.

I suspect the problem is related to the Popen start method used to launch the processes, but I have no further evidence.
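To make this suspicion concrete: as far as I understand, Lightning's subprocess-script launcher starts the extra ranks by re-running the launching command in new processes with subprocess.Popen, each with its own LOCAL_RANK environment variable, while the original process acts as rank 0. The following is a simplified, stdlib-only sketch of that pattern, not Lightning's actual code; launch_children and report_stuck are hypothetical helper names. It also shows why such a hang can be silent: a parent that waits on a child that never exits just blocks, and nothing raises an error.

```python
import os
import subprocess
import sys


def launch_children(command, num_processes, **popen_kwargs):
    """Start ranks 1..num_processes-1 by re-running `command` in new processes.

    Simplified sketch of a "re-run the same script" launcher (hypothetical
    helper, not Lightning's implementation). The calling process is rank 0.
    """
    children = []
    for local_rank in range(1, num_processes):
        env = os.environ.copy()
        env["LOCAL_RANK"] = str(local_rank)  # each child learns its rank from the environment
        # The child runs the whole command again, so module-level code
        # (e.g. a global Fabric(...) object) is executed once per process.
        children.append(subprocess.Popen(command, env=env, **popen_kwargs))
    return children


def report_stuck(children, timeout_s):
    """Return the LOCAL_RANKs of children still running after waiting.

    Each child is waited on in turn for up to `timeout_s` seconds, so the
    total wait can reach len(children) * timeout_s. A plain proc.wait()
    without a timeout would block forever on a hung child, with no error.
    """
    stuck = []
    for local_rank, proc in enumerate(children, start=1):
        try:
            proc.wait(timeout=timeout_s)
        except subprocess.TimeoutExpired:
            stuck.append(local_rank)
    return stuck


if __name__ == "__main__":
    # Demo: rank 2 "hangs" (sleeps) while rank 1 exits immediately.
    child_code = (
        "import os, time; "
        "time.sleep(30 if os.environ['LOCAL_RANK'] == '2' else 0)"
    )
    procs = launch_children([sys.executable, "-c", child_code], num_processes=3)
    print("stuck ranks:", report_stuck(procs, timeout_s=2.0))
    for proc in procs:
        proc.kill()  # no-op for children that have already been reaped
        proc.wait()
```

If one rank blocks before reaching a collective operation that the other ranks are already waiting in (DistributedDataParallel, for example, synchronizes parameters across ranks when the model is wrapped), every rank would wait indefinitely without an error, which would match what I see.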

What version are you seeing the problem on?

v2.4, v2.5

How to reproduce the bug

I tried to reproduce this bug with a smaller demo than my project, but failed; the bug seems to occur only in my actual project. This is the demo I tried:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
import lightning
fabric = lightning.Fabric(devices=[1, 2], num_nodes=1, strategy='ddp')


class MyEvaluator:
    def __init__(self):
        fabric.launch()

    def eval_model(self, dataset, crit):
        model = LinearModel()
        model = fabric.setup_module(module=model)
        model.eval()
        # I also replaced the model with my custom class, but still could not reproduce the problem
        # model = Seq2Seq_Min_LSTM_GNN(en_input_size=30, de_input_size=18, output_size=2, hidden_size=256, forecast_history=168, forecast_length=56, graph=dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 2]))))
        test_loader = fabric.setup_dataloaders(DataLoader(dataset, batch_size=10, shuffle=True,
                                                           num_workers=1, multiprocessing_context='spawn'))
        for x, y in test_loader:
            output = model(x)
            loss = crit(output, y)
            yield loss

class LinearModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear(x)


if __name__ == '__main__':
    x = torch.randn(100, 10)
    y = torch.rand(100, 2)
    dataset = TensorDataset(x, y)
    crit = nn.MSELoss()
    evaluator = MyEvaluator()
    for loss in evaluator.eval_model(dataset, crit):
        print(loss)
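To narrow down how far each process gets in a script like the one above, I could add rank-tagged progress markers that are flushed immediately, so the last marker each process prints shows where it stopped. A minimal sketch; trace is a hypothetical helper, and reading the rank from the LOCAL_RANK environment variable is an assumption (it falls back to "0" when unset).

```python
import os
import sys
import time


def trace(step, stream=None):
    """Print a rank-tagged progress marker and flush it immediately.

    Assumption: LOCAL_RANK identifies the process; "0" is used if it is unset.
    Returns the printed line so callers (or tests) can inspect it.
    """
    rank = os.environ.get("LOCAL_RANK", "0")
    line = f"[rank {rank} pid {os.getpid()} t={time.monotonic():.1f}s] {step}"
    # flush=True so the marker is visible even if the process hangs right after.
    print(line, file=stream if stream is not None else sys.stderr, flush=True)
    return line
```

For example, calling trace("before setup_module") and trace("after setup_module") around model = fabric.setup_module(module=model) in eval_model: a rank that prints the first marker but never the second is stuck inside the setup call.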

Error messages and logs

There are no error messages or logs when the deadlock occurs. What should I do to find out what is happening in my program and give you enough information?
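Since nothing is printed when the processes hang, one stdlib way to see where each process is stuck is Python's faulthandler module, which can write the stack of every thread to a file after a timeout or when the process receives a signal. A minimal sketch, assuming it is called near the top of the script so that it runs in every process; the function name, the log directory and the LOCAL_RANK-based file name are my own choices, not part of Lightning.

```python
import faulthandler
import os
import signal


def enable_hang_diagnostics(log_dir="hang_logs", timeout_s=600.0):
    """Write all thread stacks of this process to a per-process log file.

    Dumps happen every `timeout_s` seconds while the process runs, on fatal
    errors (e.g. segfaults), and, on Unix, whenever the process gets SIGUSR1.
    Assumption: LOCAL_RANK identifies the process; it defaults to "0" if unset.
    """
    rank = os.environ.get("LOCAL_RANK", "0")
    os.makedirs(log_dir, exist_ok=True)
    path = os.path.join(log_dir, f"stack_rank{rank}_pid{os.getpid()}.log")
    # The file must stay open: faulthandler writes to its descriptor later.
    log_file = open(path, "w")
    faulthandler.enable(file=log_file, all_threads=True)
    faulthandler.dump_traceback_later(timeout_s, repeat=True, file=log_file)
    if hasattr(signal, "SIGUSR1") and hasattr(faulthandler, "register"):
        faulthandler.register(signal.SIGUSR1, file=log_file, all_threads=True)
    return log_file
```

With this in place, running kill -USR1 <pid> on a stuck process writes its current Python stacks to its log file without stopping it. faulthandler only shows Python frames, so a hang inside native code (for example a NCCL collective) would appear as the Python line that called into it.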

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA RTX 5000 Ada Generation
    - NVIDIA A40
    - NVIDIA A40
    - available: True
    - version: 12.1
  • Lightning:
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - pytorch-lightning: 2.4.0
    - torch: 2.2.2
    - torchaudio: 2.2.2
    - torchdata: 0.7.1
    - torchmetrics: 1.6.0
    - torchvision: 0.17.2
  • Packages:
    - absl-py: 2.1.0
    - affine: 2.4.0
    - aiobotocore: 2.13.2
    - aiodns: 3.2.0
    - aiohappyeyeballs: 2.3.7
    - aiohttp: 3.10.4
    - aiohttp-client-cache: 0.11.1
    - aioitertools: 0.11.0
    - aiosignal: 1.3.1
    - aiosqlite: 0.20.0
    - annotated-types: 0.7.0
    - appdirs: 1.4.4
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - asciitree: 0.3.3
    - async-retriever: 0.17.0
    - attrs: 24.2.0
    - autocommand: 2.2.2
    - backports.tarfile: 1.2.0
    - black: 24.8.0
    - bleach: 6.1.0
    - bokeh: 3.5.1
    - boto3: 1.34.131
    - botocore: 1.34.131
    - branca: 0.7.2
    - brotli: 1.1.0
    - bump2version: 1.0.1
    - cachetools: 5.5.0
    - cartopy: 0.23.0
    - cattrs: 23.2.3
    - certifi: 2024.8.30
    - cffi: 1.17.0
    - cfgrib: 0.9.14.0
    - cftime: 1.6.4
    - chardet: 5.2.0
    - charset-normalizer: 3.3.2
    - click: 8.1.7
    - click-plugins: 1.1.1
    - cligj: 0.7.2
    - cloudpickle: 3.0.0
    - codetiming: 1.4.0
    - colorama: 0.4.6
    - contourpy: 1.2.1
    - cryptography: 43.0.0
    - cupy: 13.3.0
    - cycler: 0.12.1
    - cytoolz: 0.12.3
    - dask: 2024.8.1
    - dask-expr: 1.1.11
    - dataretrieval: 1.0.10
    - deepspeed: 0.16.1
    - defusedxml: 0.7.1
    - dgl: 2.2.1+cu121
    - distributed: 2024.8.1
    - docutils: 0.21.2
    - eccodes: 1.7.1
    - einops: 0.8.0
    - et-xmlfile: 1.1.0
    - exceptiongroup: 1.2.2
    - fasteners: 0.19
    - fastrlock: 0.8.2
    - filelock: 3.15.4
    - findlibs: 0.0.5
    - flake8: 7.1.1
    - flexcache: 0.3
    - flexparser: 0.3.1
    - folium: 0.17.0
    - fonttools: 4.53.1
    - frozenlist: 1.4.1
    - fsspec: 2024.6.1
    - geopandas: 1.0.1
    - gmpy2: 2.1.5
    - greenlet: 3.0.3
    - grpcio: 1.62.2
    - h2: 4.1.0
    - h5netcdf: 1.3.0
    - h5py: 3.11.0
    - hjson: 3.1.0
    - hpack: 4.0.0
    - hydrodataset: 0.1.13
    - hydrodatasource: 0.0.8
    - hydroerr: 1.24
    - hydrosignatures: 0.17.0
    - hydrotopo: 0.0.6
    - hydroutils: 0.0.12
    - hyperframe: 6.0.1
    - idna: 3.7
    - igraph: 0.11.6
    - importlib-metadata: 8.2.0
    - importlib-resources: 6.4.0
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - intake: 2.0.6
    - itsdangerous: 2.2.0
    - jaraco.classes: 3.4.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.2
    - jaraco.text: 3.12.1
    - jeepney: 0.8.0
    - jinja2: 3.1.4
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - kaggle: 1.6.17
    - kerchunk: 0.2.6
    - keyring: 25.3.0
    - kiwisolver: 1.4.5
    - lightning: 2.4.0
    - lightning-utilities: 0.11.9
    - llvmlite: 0.43.0
    - locket: 1.0.0
    - loguru: 0.7.2
    - lxml: 5.3.0
    - lz4: 4.3.3
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 2.1.5
    - matplotlib: 3.9.2
    - mccabe: 0.7.0
    - mdurl: 0.1.2
    - minio: 7.2.8
    - more-itertools: 10.4.0
    - mpmath: 1.3.0
    - msgpack: 1.0.8
    - multidict: 6.0.5
    - mypy-extensions: 1.0.0
    - netcdf4: 1.7.1.post2
    - networkx: 3.3
    - nh3: 0.2.18
    - ninja: 1.11.1.3
    - nuitka: 2.4.7
    - numba: 0.60.0
    - numcodecs: 0.13.0
    - numpy: 1.26.4
    - nvidia-ml-py: 12.535.161
    - nvitop: 1.3.2
    - openpyxl: 3.1.5
    - ordered-set: 4.1.0
    - owslib: 0.31.0
    - packaging: 24.1
    - pandas: 2.2.2
    - partd: 1.4.2
    - pathspec: 0.12.1
    - pillow: 10.4.0
    - pint: 0.24.3
    - pint-pandas: 0.6.2
    - pint-xarray: 0.4
    - pip: 24.2
    - pkginfo: 1.10.0
    - platformdirs: 4.2.2
    - pluggy: 1.5.0
    - polars: 1.17.1
    - protobuf: 4.25.3
    - psutil: 6.0.0
    - psycopg2-binary: 2.9.9
    - py-cpuinfo: 9.0.0
    - pyarrow: 17.0.0
    - pyarrow-hotfix: 0.6
    - pycairo: 1.27.0
    - pycares: 4.4.0
    - pycodestyle: 2.12.1
    - pycparser: 2.22
    - pycryptodome: 3.20.0
    - pydantic: 2.8.2
    - pydantic-core: 2.20.1
    - pyflakes: 3.2.0
    - pygeohydro: 0.17.0
    - pygeoogc: 0.17.0
    - pygeoutils: 0.17.0
    - pygments: 2.18.0
    - pykalman: 0.9.7
    - pynhd: 0.17.0
    - pyogrio: 0.9.0
    - pyparsing: 3.1.2
    - pyproj: 3.6.1
    - pyshp: 2.3.1
    - pysocks: 1.7.1
    - pytest: 8.3.2
    - python-dateutil: 2.9.0
    - python-slugify: 8.0.4
    - pytorch-lightning: 2.4.0
    - pytz: 2024.1
    - pyyaml: 6.0.2
    - rasterio: 1.3.10
    - readme-renderer: 44.0
    - requests: 2.32.3
    - requests-cache: 1.2.1
    - requests-toolbelt: 1.0.0
    - rfc3986: 2.0.0
    - rich: 13.7.1
    - rioxarray: 0.17.0
    - s3fs: 2024.6.1
    - s3transfer: 0.10.2
    - scikit-learn: 1.5.1
    - scipy: 1.14.0
    - seaborn: 0.13.2
    - secretstorage: 3.3.3
    - setuptools: 72.2.0
    - shap: 0.45.1
    - shapely: 2.0.1
    - six: 1.16.0
    - slicer: 0.0.8
    - snuggs: 1.4.7
    - sortedcontainers: 2.4.0
    - sqlalchemy: 2.0.32
    - sympy: 1.13.2
    - tblib: 3.0.0
    - tbparse: 0.0.9
    - tensorboard: 2.17.1
    - tensorboard-data-server: 0.7.0
    - termcolor: 2.5.0
    - text-unidecode: 1.3
    - texttable: 1.7.0
    - threadpoolctl: 3.5.0
    - tomli: 2.0.1
    - toolz: 0.12.1
    - torch: 2.2.2
    - torchaudio: 2.2.2
    - torchdata: 0.7.1
    - torchmetrics: 1.6.0
    - torchvision: 0.17.2
    - tornado: 6.4.1
    - tqdm: 4.66.5
    - triton: 2.2.0
    - twine: 5.1.1
    - typeguard: 4.3.0
    - typing-extensions: 4.12.2
    - tzdata: 2024.1
    - tzfpy: 0.15.5
    - ujson: 5.10.0
    - url-normalize: 1.4.3
    - urllib3: 2.2.2
    - webencodings: 0.5.1
    - werkzeug: 3.0.3
    - wget: 3.2
    - wheel: 0.44.0
    - wrapt: 1.16.0
    - xarray: 2024.7.0
    - xlrd: 2.0.1
    - xyzservices: 2024.6.0
    - yarl: 1.9.4
    - zarr: 2.18.2
    - zict: 3.0.0
    - zipp: 3.20.0
    - zstandard: 0.23.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.11.9
    - release: 5.4.0-195-generic
    - version: #215-Ubuntu SMP Fri Aug 2 18:28:05 UTC 2024

More info

In my project, the deadlock occurs during both training and evaluation.

I hope the file gives you more information and that you can tell me how to reproduce or fix the problem correctly.
