ModelCheckpoint broadcast fails on multiple GPUs #20597

@navotoz

Bug description

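When training on 4 A100 GPUs, the ModelCheckpoint callback is invoked on reaching on_train_epoch_end (no custom implementation).
Inside the callback, file_exists broadcasts a value to all ranks through trainer.strategy.broadcast (going by the traceback below, this is the result of a file-existence check rather than the weights).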

I was able to trace the issue to torch.distributed.distributed_c10d.py, lines 3382-3504.
Line 3490 tries to allocate a tensor whose size comes from object_sizes_tensor. On ranks where rank != src, object_sizes_tensor holds a huge arbitrary number, and object_tensor is allocated with this huge integer as its size.
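
To make the failure mode above concrete, here is a simplified, pure-Python model of a "send the size first, then the payload" exchange. It is only a sketch and not PyTorch's code: the names encode, receive, and stale_size_buf are hypothetical, and the garbage byte pattern is made up for illustration. It assumes the receiving side trusts whatever is in its size buffer; why that buffer would hold a stale value in the real run is not established here.

```python
import pickle
import struct

EXBIBYTE = 2**60  # 1 EiB, the scale mentioned in the OOM message


def encode(obj):
    """Source side: pickle the object and pack its byte length as an int64."""
    payload = pickle.dumps(obj)
    size_buf = struct.pack("<q", len(payload))
    return size_buf, payload


def receive(size_buf, payload, max_bytes=EXBIBYTE):
    """Receiving side: read the size first, then allocate a buffer of that size."""
    (size,) = struct.unpack("<q", size_buf)
    if size < 0 or size >= max_bytes:
        # Stand-in for the allocator rejecting an absurd request.
        raise MemoryError(f"refusing to allocate {size} bytes")
    buf = bytearray(size)  # plays the role of the uint8 receive tensor
    buf[:] = payload[:size]
    return pickle.loads(bytes(buf))


# Normal case: the size buffer was filled by the source.
size_buf, payload = encode(True)
result = receive(size_buf, payload)

# Failure case: the size buffer never received the source's value and
# still holds leftover bytes (a hypothetical garbage pattern).
stale_size_buf = bytes([0xEF, 0xBE, 0xAD, 0xDE, 0xEF, 0xBE, 0xAD, 0x5E])
(stale_size,) = struct.unpack("<q", stale_size_buf)
try:
    receive(stale_size_buf, payload)
    failed = False
except MemoryError:
    failed = True
```

In this model the payload for a single boolean is only a few bytes, so any request on the order of an exbibyte can only come from a size value that was never overwritten with the real one.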

Versions:
pytorch 2.6.0
lightning 2.5.0

What version are you seeing the problem on?

v2.5

How to reproduce the bug

Error messages and logs

Error 1:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank1]:     return _run_code(code, main_globals, None,
[rank1]:   File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
[rank1]:     exec(code, run_globals)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/menta/runner.py", line 459, in <module>
[rank1]:     main()
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/menta/runner.py", line 455, in main
[rank1]:     local_runner.run()
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/menta/runner.py", line 369, in run
[rank1]:     trainer.fit(model=module, datamodule=datamodule, ckpt_path=latest_checkpoint_path)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 539, in fit
[rank1]:     call._call_and_handle_interrupt(
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 46, in _call_and_handle_interrupt
[rank1]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
[rank1]:     return function(*args, **kwargs)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 575, in _fit_impl
[rank1]:     self._run(model, ckpt_path=ckpt_path)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 982, in _run
[rank1]:     results = self._run_stage()
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1026, in _run_stage
[rank1]:     self.fit_loop.run()
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 217, in run
[rank1]:     self.on_advance_end()
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py", line 470, in on_advance_end
[rank1]:     call._call_callback_hooks(trainer, "on_train_epoch_end", monitoring_callbacks=True)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 222, in _call_callback_hooks
[rank1]:     fn(trainer, trainer.lightning_module, *args, **kwargs)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 325, in on_train_epoch_end
[rank1]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank1]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 712, in _save_none_monitor_checkpoint
[rank1]:     filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 663, in _get_metric_interpolated_filepath_name
[rank1]:     while self.file_exists(filepath, trainer) and filepath != del_filepath:
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 776, in file_exists
[rank1]:     return trainer.strategy.broadcast(exists)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 307, in broadcast
[rank1]:     torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3140, in broadcast_object_list
[rank1]:     object_tensor = torch.empty(  # type: ignore[call-overload]
[rank1]: RuntimeError: /opt/conda/conda-bld/pytorch_1729647352509/work/build/aten/src/ATen/RegisterCUDA.cpp:7275: SymIntArrayRef expected to contain only concrete integers

Error 2:

[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 325, in on_train_epoch_end
[rank1]:     self._save_topk_checkpoint(trainer, monitor_candidates)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 387, in _save_topk_checkpoint
[rank1]:     self._save_none_monitor_checkpoint(trainer, monitor_candidates)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 712, in _save_none_monitor_checkpoint
[rank1]:     filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, self.best_model_path)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 663, in _get_metric_interpolated_filepath_name
[rank1]:     while self.file_exists(filepath, trainer) and filepath != del_filepath:
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py", line 776, in file_exists
[rank1]:     return trainer.strategy.broadcast(exists)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 307, in broadcast
[rank1]:     torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 83, in wrapper
[rank1]:     return func(*args, **kwargs)
[rank1]:   File "/opt/conda/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3140, in broadcast_object_list
[rank1]:     object_tensor = torch.empty(  # type: ignore[call-overload]
[rank1]: torch.OutOfMemoryError: CUDA out of memory. Tried to allocate more than 1EB memory.

Environment

Current environment
  • CUDA:
    - GPU:
    - NVIDIA A100-SXM4-80GB
    - NVIDIA A100-SXM4-80GB
    - NVIDIA A100-SXM4-80GB
    - NVIDIA A100-SXM4-80GB
    - available: True
    - version: 12.4
  • Lightning:
    - lightning: 2.5.0.post0
    - lightning-utilities: 0.12.0
    - pytorch-lightning: 2.5.0.post0
    - sagemaker-pytorch-training: 2.8.1
    - torch: 2.6.0
    - torchaudio: 2.5.1
    - torchmetrics: 1.6.1
    - torchvision: 0.20.1
  • Packages:
    - absl-py: 2.1.0
    - aiobotocore: 2.19.0
    - aiohappyeyeballs: 2.4.6
    - aiohttp: 3.11.12
    - aioitertools: 0.12.0
    - aiosignal: 1.3.2
    - angie-data-reader: 0.10.0
    - angie-shuffle-service: 0.18.1
    - annotated-types: 0.7.0
    - antlr4-python3-runtime: 4.9.3
    - anyio: 4.8.0
    - argon2-cffi: 23.1.0
    - argon2-cffi-bindings: 21.2.0
    - arrow: 1.3.0
    - asttokens: 3.0.0
    - async-lru: 2.0.4
    - async-timeout: 5.0.1
    - attrs: 23.2.0
    - autocommand: 2.2.2
    - babel: 2.17.0
    - backports.tarfile: 1.2.0
    - bcrypt: 4.2.1
    - beartype: 0.19.0
    - beautifulsoup4: 4.13.3
    - bleach: 6.2.0
    - blinker: 1.9.0
    - boto3: 1.36.3
    - botocore: 1.36.3
    - brotli: 1.1.0
    - cached-property: 1.5.2
    - cachetools: 5.5.1
    - certifi: 2025.1.31
    - cffi: 1.17.1
    - charset-normalizer: 3.4.1
    - click: 8.1.8
    - cloud-logging: 0.5.53
    - cloud-storage-utils: 1.0.63
    - cloudpickle: 2.2.1
    - colorama: 0.4.6
    - comet-ml: 3.49.1
    - comm: 0.2.2
    - configobj: 5.0.9
    - contextlib2: 21.6.0
    - contourpy: 1.3.1
    - cryptography: 42.0.8
    - cycler: 0.12.1
    - cython: 3.0.12
    - debugpy: 1.8.12
    - decorator: 5.1.1
    - defusedxml: 0.7.1
    - dill: 0.3.9
    - dl-optimizer: 25.10218.0
    - dl-optimizer-common: 25.10218.0
    - docker: 7.1.0
    - dulwich: 0.22.7
    - durationpy: 0.9
    - einops: 0.8.1
    - everett: 3.1.0
    - exceptiongroup: 1.2.2
    - execnet: 2.1.1
    - executing: 2.1.0
    - faker: 36.1.1
    - fastjsonschema: 2.21.1
    - filelock: 3.17.0
    - fonttools: 4.56.0
    - fpdf: 1.7.2
    - fqdn: 1.5.1
    - frozenlist: 1.5.0
    - fsspec: 2025.2.0
    - future: 1.0.0
    - getdaft: 0.3.0.dev0
    - gevent: 24.11.1
    - gmpy2: 2.1.5
    - google-api-core: 2.24.1
    - google-api-python-client: 2.161.0
    - google-auth: 2.38.0
    - google-auth-httplib2: 0.2.0
    - google-cloud-container: 2.56.0
    - google-cloud-core: 2.4.1
    - google-cloud-secret-manager: 2.23.0
    - google-cloud-storage: 3.0.0
    - google-crc32c: 1.1.2
    - google-pasta: 0.2.0
    - google-resumable-media: 2.7.2
    - googleapis-common-protos: 1.67.0
    - greenlet: 3.1.1
    - grpc-google-iam-v1: 0.14.0
    - grpcio: 1.62.2
    - grpcio-status: 1.62.2
    - h11: 0.14.0
    - h2: 4.2.0
    - hera: 5.18.0
    - hpack: 4.1.0
    - httpcore: 1.0.7
    - httplib2: 0.22.0
    - httpx: 0.28.1
    - huggingface-hub: 0.29.0
    - hydra-core: 1.3.2
    - hyperframe: 6.1.0
    - hyperopt: 0.2.7
    - idna: 3.10
    - importlib-metadata: 6.10.0
    - importlib-resources: 6.5.2
    - inflect: 7.3.1
    - iniconfig: 2.0.0
    - inotify-simple: 1.2.1
    - ipykernel: 6.29.5
    - ipython: 8.32.0
    - ipywidgets: 8.1.5
    - isoduration: 20.11.0
    - jaraco.collections: 5.1.0
    - jaraco.context: 5.3.0
    - jaraco.functools: 4.0.1
    - jaraco.text: 3.12.1
    - jedi: 0.19.2
    - jinja2: 3.1.5
    - jmespath: 1.0.1
    - joblib: 1.4.2
    - json5: 0.10.0
    - jsonpointer: 3.0.0
    - jsonschema: 4.23.0
    - jsonschema-specifications: 2024.10.1
    - jupyter-client: 8.6.3
    - jupyter-core: 5.7.2
    - jupyter-events: 0.12.0
    - jupyter-lsp: 2.2.5
    - jupyter-server: 2.15.0
    - jupyter-server-terminals: 0.5.3
    - jupyterlab: 4.3.5
    - jupyterlab-pygments: 0.3.0
    - jupyterlab-server: 2.27.3
    - jupyterlab-widgets: 3.0.13
    - kiwisolver: 1.4.7
    - kubernetes: 32.0.1
    - lightning: 2.5.0.post0
    - lightning-utilities: 0.12.0
    - lxml: 5.3.1
    - lz4: 4.3.3
    - markdown: 3.6
    - markdown-it-py: 3.0.0
    - markupsafe: 3.0.2
    - matplotlib: 3.10.0
    - matplotlib-inline: 0.1.7
    - mdurl: 0.1.2
    - me-auth-client: 0.34.0
    - menta3: 0.32.9
    - mistune: 3.1.2
    - mobileye-pyq: 0.4.0
    - mobileye.cloud-users: 0.2.6
    - mobileye.metal: 4.2.6
    - mobileye.nebula-tools: 2.0.12
    - more-itertools: 10.3.0
    - mpmath: 1.3.0
    - msal: 1.31.1
    - multidict: 6.1.0
    - multiprocess: 0.70.17
    - munkres: 1.1.4
    - nbclient: 0.10.2
    - nbconvert: 7.16.6
    - nbformat: 5.10.4
    - nest-asyncio: 1.6.0
    - networkx: 3.4.2
    - notebook: 7.3.2
    - notebook-shim: 0.2.4
    - nssd-errors: 1.6.8
    - numpy: 1.26.4
    - nvidia-cublas-cu12: 12.4.5.8
    - nvidia-cuda-cupti-cu12: 12.4.127
    - nvidia-cuda-nvrtc-cu12: 12.4.127
    - nvidia-cuda-runtime-cu12: 12.4.127
    - nvidia-cudnn-cu12: 9.1.0.70
    - nvidia-cufft-cu12: 11.2.1.3
    - nvidia-curand-cu12: 10.3.5.147
    - nvidia-cusolver-cu12: 11.6.1.9
    - nvidia-cusparse-cu12: 12.3.1.170
    - nvidia-cusparselt-cu12: 0.6.2
    - nvidia-nccl-cu12: 2.21.5
    - nvidia-nvjitlink-cu12: 12.4.127
    - nvidia-nvtx-cu12: 12.4.127
    - oauthlib: 3.2.2
    - omegaconf: 2.3.0
    - opencv-python: 4.11.0.86
    - overrides: 7.7.0
    - packaging: 24.2
    - pandas: 2.2.3
    - pandocfilters: 1.5.0
    - paramiko: 3.5.1
    - parso: 0.8.4
    - pathos: 0.3.3
    - pexpect: 4.9.0
    - pickleshare: 0.7.5
    - pillow: 11.1.0
    - pip: 25.0.1
    - pkgutil-resolve-name: 1.3.10
    - platformdirs: 4.3.6
    - pluggy: 1.5.0
    - pox: 0.3.5
    - ppft: 1.7.6.9
    - prometheus-client: 0.21.1
    - prompt-toolkit: 3.0.50
    - propcache: 0.2.1
    - proto-plus: 1.26.0
    - protobuf: 3.20.3
    - psutil: 6.1.1
    - ptyprocess: 0.7.0
    - pure-eval: 0.2.3
    - py4j: 0.10.9.9
    - pyarrow: 16.1.0
    - pyasn1: 0.6.1
    - pyasn1-modules: 0.4.1
    - pycocotools: 2.0.8
    - pycparser: 2.22
    - pydantic: 2.10.6
    - pydantic-core: 2.27.2
    - pydot: 3.0.4
    - pygments: 2.19.1
    - pyjwt: 2.10.1
    - pynacl: 1.5.0
    - pynamodb: 5.5.1
    - pyopenssl: 25.0.0
    - pyparsing: 3.2.1
    - pyside6: 6.8.2
    - pysocks: 1.7.1
    - pytest: 7.4.4
    - pytest-mock: 3.14.0
    - pytest-xdist: 3.6.1
    - python-box: 6.1.0
    - python-dateutil: 2.9.0.post0
    - python-json-logger: 2.0.7
    - pytorch-lightning: 2.5.0.post0
    - pytz: 2024.1
    - pyu2f: 0.1.5
    - pyyaml: 6.0.2
    - pyzmq: 26.2.1
    - referencing: 0.36.2
    - requests: 2.32.3
    - requests-oauthlib: 2.0.0
    - requests-toolbelt: 1.0.0
    - retrying: 1.3.4
    - rfc3339-validator: 0.1.4
    - rfc3986-validator: 0.1.1
    - rich: 13.9.4
    - rpds-py: 0.22.3
    - rsa: 4.9
    - ruamel.yaml: 0.18.10
    - ruamel.yaml.clib: 0.2.8
    - s3fs: 2025.2.0
    - s3path: 0.6.0
    - s3transfer: 0.11.2
    - safetensors: 0.5.2
    - sagemaker: 2.210.0
    - sagemaker-pytorch-training: 2.8.1
    - sagemaker-training: 4.9.0
    - schema: 0.7.7
    - scikit-learn: 1.6.1
    - scipy: 1.15.2
    - semantic-version: 2.10.0
    - send2trash: 1.8.3
    - sentry-sdk: 2.22.0
    - setuptools: 75.8.0
    - shiboken6: 6.8.2
    - six: 1.17.0
    - smart-open: 7.1.0
    - smdebug-rulesconfig: 1.0.1
    - sniffio: 1.3.1
    - soupsieve: 2.5
    - stack-data: 0.6.3
    - sympy: 1.13.1
    - tabulate: 0.9.0
    - tblib: 2.0.0
    - tenacity: 9.0.0
    - tensorboard: 2.18.0
    - tensorboard-data-server: 0.7.0
    - termcolor: 2.5.0
    - terminado: 0.18.1
    - threadpoolctl: 3.5.0
    - timm: 1.0.14
    - tinycss2: 1.4.0
    - toml: 0.10.2
    - tomli: 2.2.1
    - torch: 2.6.0
    - torchaudio: 2.5.1
    - torchmetrics: 1.6.1
    - torchvision: 0.20.1
    - tornado: 6.4.2
    - tqdm: 4.67.1
    - traitlets: 5.14.3
    - triton: 3.2.0
    - typeguard: 4.3.0
    - types-python-dateutil: 2.9.0.20241206
    - typing-extensions: 4.12.2
    - typing-utils: 0.1.0
    - tzdata: 2025.1
    - unicodedata2: 16.0.0
    - uri-template: 1.3.0
    - uritemplate: 4.1.1
    - urllib3: 1.26.19
    - wcwidth: 0.2.13
    - webcolors: 24.11.1
    - webencodings: 0.5.1
    - websocket-client: 1.8.0
    - werkzeug: 3.1.3
    - wheel: 0.45.1
    - widgetsnbextension: 4.0.13
    - wrapt: 1.17.2
    - wurlitzer: 3.1.1
    - yarl: 1.18.3
    - zipp: 3.21.0
    - zope.event: 5.0
    - zope.interface: 7.2
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.16
    - release: 5.15.0-105-generic
    - version: #115-Ubuntu SMP Mon Apr 15 09:52:04 UTC 2024

More info

No response

cc @justusschock
