diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml
index 8cceb5870af8d..6c37281fc04ec 100644
--- a/.github/workflows/_legacy-checkpoints.yml
+++ b/.github/workflows/_legacy-checkpoints.yml
@@ -57,28 +57,32 @@ jobs:
steps:
- uses: actions/checkout@v5
- - uses: actions/setup-python@v5
+ - name: Install uv and set Python version
+ uses: astral-sh/setup-uv@v6
with:
- # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt.
python-version: "3.9"
+ # TODO: Avoid activating environment like this
+ # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
+ activate-environment: true
+ enable-cache: true
- name: Install PL from source
env:
PACKAGE_NAME: pytorch
FREEZE_REQUIREMENTS: 1
timeout-minutes: 20
- run: pip install . --extra-index-url="${TORCH_URL}"
+ run: uv pip install . --extra-index-url="${TORCH_URL}"
if: inputs.pl_version == ''
- name: Install PL version
timeout-minutes: 20
- run: pip install "pytorch-lightning==${{ inputs.pl_version }}" --extra-index-url="${TORCH_URL}"
+ run: uv pip install "pytorch-lightning==${{ inputs.pl_version }}" --extra-index-url="${TORCH_URL}"
if: inputs.pl_version != ''
- name: Adjust tests -> PL
if: ${{ matrix.pkg-name != 'lightning' }}
run: |
- pip install -q -r .actions/requirements.txt
+ uv pip install -q -r .actions/requirements.txt
python .actions/assistant.py copy_replace_imports --source_dir="./tests" \
--source_import="lightning.fabric,lightning.pytorch" \
--target_import="lightning_fabric,pytorch_lightning"
@@ -115,7 +119,7 @@ jobs:
# export to env bool if secrets.AWS_REGION is not empty
run: echo "WITH_SECRETS=$([ -n '${{ secrets.AWS_REGION }}' ] && echo 1 || echo 0)" >> $GITHUB_ENV
- - run: pip install -r requirements/ci.txt
+ - run: uv pip install -r requirements/ci.txt
- name: Upload checkpoints to S3
if: ${{ env.WITH_SECRETS == '1' }}
working-directory: ${{ env.LEGACY_FOLDER }}
diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml
index f4c66f425cc71..615633c5311bf 100644
--- a/.github/workflows/ci-tests-fabric.yml
+++ b/.github/workflows/ci-tests-fabric.yml
@@ -38,44 +38,30 @@ jobs:
strategy:
fail-fast: false
matrix:
- include:
+ os: [macOS-14, ubuntu-22.04, windows-2022]
+ config:
# only run PyTorch latest
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+ - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+ - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
+ - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
+ - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
+ - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+
# only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues
- - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
- - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
- - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
- # "oldest" versions tests, only on minimum Python
- - { os: "macOS-14", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
- - { os: "ubuntu-22.04", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
- - { os: "windows-2022", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
+ - { pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" }
+
# "fabric" installs the standalone package
- - { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
- - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
- - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
+ - { pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" }
+
# adding recently cut Torch 2.7 - FUTURE
- - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
- - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
- - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
+ - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" }
+
+ # "oldest" versions tests, only on minimum Python
+ - { pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" }
timeout-minutes: 25 # because of building grpcio on Mac
env:
- PACKAGE_NAME: ${{ matrix.pkg-name }}
+ PACKAGE_NAME: ${{ matrix.config.pkg-name }}
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
- PYPI_CACHE_DIR: "_pip-wheels"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/"
# TODO: Remove this - Enable running MPS tests on this platform
@@ -83,68 +69,67 @@ jobs:
steps:
- uses: actions/checkout@v5
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ - name: Install uv and set Python version
+ uses: astral-sh/setup-uv@v6
with:
- python-version: ${{ matrix.python-version || '3.9' }}
+ python-version: ${{ matrix.config.python-version || '3.9' }}
+ # TODO: Avoid activating environment like this
+ # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
+ activate-environment: true
+ enable-cache: true
+
+ - name: Basic setup
+ run: uv pip install -q -r .actions/requirements.txt
- - name: basic setup
- run: pip install -q -r .actions/requirements.txt
+ - name: Append Env. vars for Linux
+ if: ${{ runner.os == 'Linux' }}
+ run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV
+
+ - name: Append Env. vars for MacOS
+ if: ${{ runner.os == 'macOS' }}
+ run: echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV
+
+ - name: Append Env. vars for Windows
+ if: ${{ runner.os == 'windows' }}
+ run: |
+ # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
+ echo "USE_LIBUV=0" >> $GITHUB_ENV
- name: Set min. dependencies
- if: ${{ matrix.requires == 'oldest' }}
+ if: ${{ matrix.config.requires == 'oldest' }}
run: |
cd requirements/fabric
- pip install -U "lightning-utilities[cli]"
+ uv pip install -U "lightning-utilities[cli]"
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt', 'test.txt']"
- pip install "cython<3.0" wheel
- pip install "pyyaml==5.4" --no-build-isolation
+ uv pip install "cython<3.0" wheel
+ uv pip install "pyyaml==5.4" --no-build-isolation
- name: Adjust PyTorch versions in requirements files
- if: ${{ matrix.requires != 'oldest' }}
+ if: ${{ matrix.config.requires != 'oldest' }}
run: |
- pip install -q -r requirements/ci.txt
+ uv pip install -q -r requirements/ci.txt
python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py
for fpath in `ls requirements/**/*.txt`; do \
- python ./adjust-torch-versions.py $fpath ${{ matrix.pytorch-version }}; \
+ python ./adjust-torch-versions.py $fpath ${{ matrix.config.pytorch-version }}; \
done
- - name: pip wheels cache
- uses: actions/cache/restore@v4
- with:
- path: ${{ env.PYPI_CACHE_DIR }}
- key: pypi_wheels
- - run: |
- mkdir -p $PYPI_CACHE_DIR
- ls -lh $PYPI_CACHE_DIR
-
- name: Expand Env. variables
run: |
# Switch PyTorch URL between stable and test/future
- python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
+ python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.config.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
- python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
+ python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.config.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
- python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'fabric-'))" >> $GITHUB_ENV
- - name: Append Env. vars for MacOS
- if: ${{ runner.os == 'macOS' }}
- run: |
- # trying to avoid "gloo" issue with SIGABRT
- echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV
- - name: Append Env. vars for Windows
- if: ${{ runner.os == 'windows' }}
- run: |
- # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
- echo "USE_LIBUV=0" >> $GITHUB_ENV
+ python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.config.pkg-name}}' != 'lightning' else 'fabric-'))" >> $GITHUB_ENV
- name: Install package & dependencies
timeout-minutes: 20
run: |
- pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
- -U --upgrade-strategy=eager --prefer-binary \
- --extra-index-url="${TORCH_URL}" \
- --find-links="${PYPI_CACHE_DIR}"
- pip list
+ uv pip install ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
+ --upgrade \
+ --find-links="${TORCH_URL}"
+ uv pip list
+
- name: Dump handy wheels
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
continue-on-error: true
@@ -155,7 +140,7 @@ jobs:
cache-key: "pypi_wheels"
- name: Adjust tests
- if: ${{ matrix.pkg-name != 'lightning' }}
+ if: ${{ matrix.config.pkg-name != 'lightning' }}
run: |
python .actions/assistant.py copy_replace_imports --source_dir="./tests" \
--source_import="lightning.fabric" --target_import="lightning_fabric"
@@ -188,10 +173,13 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: tests/tests_fabric/coverage.xml
- flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest,python${{ matrix.python-version }}
+ flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest,python${{ matrix.config.python-version }}
name: CPU-coverage
fail_ci_if_error: false
+ - name: Minimize uv cache
+ run: uv cache prune --ci
+
fabric-cpu-guardian:
runs-on: ubuntu-latest
needs: fabric-cpu
diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml
index e5e9ddd23c06e..41d895f14ae86 100644
--- a/.github/workflows/ci-tests-pytorch.yml
+++ b/.github/workflows/ci-tests-pytorch.yml
@@ -42,115 +42,106 @@ jobs:
strategy:
fail-fast: false
matrix:
- include:
+ os: [macOS-14, ubuntu-22.04, windows-2022]
+ config:
# only run PyTorch latest
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
- - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+ - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+ - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" }
+ - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" }
+ - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" }
+ - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" }
+
# only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues
- - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
- - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
- - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
- # "oldest" versions tests, only on minimum Python
- - { os: "macOS-14", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
- - { os: "ubuntu-22.04", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
- - { os: "windows-2022", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
+ - { pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" }
+
# "pytorch" installs the standalone package
- - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
- - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
- - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
+ - { pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" }
+
# adding recently cut Torch 2.7 - FUTURE
- - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
- - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
- - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
+ - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" }
+
+ # "oldest" versions tests, only on minimum Python
+ - { pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" }
timeout-minutes: 50
env:
- PACKAGE_NAME: ${{ matrix.pkg-name }}
+ PACKAGE_NAME: ${{ matrix.config.pkg-name }}
TORCH_URL: "https://download.pytorch.org/whl/cpu/"
TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/"
TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/"
FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }}
- PYPI_CACHE_DIR: "_pip-wheels"
# TODO: Remove this - Enable running MPS tests on this platform
DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }}
steps:
- uses: actions/checkout@v5
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v5
+ - name: Install uv and set Python version
+ uses: astral-sh/setup-uv@v6
with:
- python-version: ${{ matrix.python-version || '3.9' }}
+ python-version: ${{ matrix.config.python-version || '3.9' }}
+ # TODO: Avoid activating environment like this
+ # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment
+ activate-environment: true
+ enable-cache: true
+
+ - name: Basic setup
+ run: uv pip install -q -r .actions/requirements.txt
- - name: basic setup
- run: pip install -q -r .actions/requirements.txt
+ - name: Append Env. vars for Linux
+ if: ${{ runner.os == 'Linux' }}
+ run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV
+
+ - name: Append Env. vars for MacOS
+ if: ${{ runner.os == 'macOS' }}
+ run: echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV
- name: Set min. dependencies
- if: ${{ matrix.requires == 'oldest' }}
+ if: ${{ matrix.config.requires == 'oldest' }}
run: |
cd requirements/pytorch
- pip install -U "lightning-utilities[cli]"
+ uv pip install -U "lightning-utilities[cli]"
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt', 'test.txt']"
- pip install "cython<3.0" wheel
- pip install "pyyaml==5.4" --no-build-isolation
+ uv pip install "cython<3.0" wheel
+ uv pip install "pyyaml==5.4" --no-build-isolation
- name: Adjust PyTorch versions in requirements files
- if: ${{ matrix.requires != 'oldest' }}
+ if: ${{ matrix.config.requires != 'oldest' }}
run: |
- pip install -q -r requirements/ci.txt
+ uv pip install -q -r requirements/ci.txt
python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py
for fpath in `ls requirements/**/*.txt`; do \
- python ./adjust-torch-versions.py $fpath ${{ matrix.pytorch-version }}; \
+ python ./adjust-torch-versions.py $fpath ${{ matrix.config.pytorch-version }}; \
done
cat requirements/pytorch/base.txt
- - name: pip wheels cache
- uses: actions/cache/restore@v4
- with:
- path: ${{ env.PYPI_CACHE_DIR }}
- key: pypi_wheels
- - run: |
- mkdir -p $PYPI_CACHE_DIR
- ls -lh $PYPI_CACHE_DIR
-
- name: Env. variables
run: |
# Switch PyTorch URL between stable and test/future
- python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
+ python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.config.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV
# Switch coverage scope
- python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
+ python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.config.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV
# if you install mono-package set dependency only for this subpackage
- python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
+ python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.config.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV
# Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support"
- python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV
+ python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.config.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV
- name: Install package & dependencies
timeout-minutes: 20
run: |
- pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
- -U --upgrade-strategy=eager --prefer-binary \
+ uv pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \
+ --upgrade \
-r requirements/_integrations/accelerators.txt \
- --extra-index-url="${TORCH_URL}" \
- --find-links="${PYPI_CACHE_DIR}"
- pip list
+ --find-links="${TORCH_URL}"
+ uv pip list
+
- name: Drop LAI from extensions
- if: ${{ matrix.pkg-name != 'lightning' }}
+ if: ${{ matrix.config.pkg-name != 'lightning' }}
# Lightning is dependency of Habana or other accelerators/integrations so in case we test PL we need to remove it
- run: pip uninstall -y lightning
+ run: uv pip uninstall lightning
+
- name: Drop PL for LAI
- if: ${{ matrix.pkg-name == 'lightning' }}
- run: pip uninstall -y pytorch-lightning
+ if: ${{ matrix.config.pkg-name == 'lightning' }}
+ run: uv pip uninstall pytorch-lightning
+
- name: Dump handy wheels
if: github.event_name == 'push' && github.ref == 'refs/heads/master'
continue-on-error: true
@@ -170,10 +161,10 @@ jobs:
run: |
set -e
python requirements/pytorch/check-avail-extras.py
- python -c "from torch import __version__ as ver; assert ver.startswith('${{ matrix.pytorch-version }}'), ver"
+ python -c "from torch import __version__ as ver; assert ver.startswith('${{ matrix.config.pytorch-version }}'), ver"
- name: Adjust tests / env. -> PL
- if: ${{ matrix.pkg-name != 'lightning' }}
+ if: ${{ matrix.config.pkg-name != 'lightning' }}
run: |
python .actions/assistant.py copy_replace_imports --source_dir="./tests" \
--source_import="lightning.fabric,lightning.pytorch" \
@@ -223,10 +214,13 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}
file: tests/tests_pytorch/coverage.xml
- flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }}
+ flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest-full,python${{ matrix.config.python-version }},pytorch${{ matrix.config.pytorch-version }}
name: CPU-coverage
fail_ci_if_error: false
+ - name: Minimize uv cache
+ run: uv cache prune --ci
+
pl-cpu-guardian:
runs-on: ubuntu-latest
needs: pl-cpu
diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml
index 69d35e605db5b..5fb961798ad22 100644
--- a/.github/workflows/docs-build.yml
+++ b/.github/workflows/docs-build.yml
@@ -165,12 +165,12 @@ jobs:
path: docs/build/html/
- name: Authenticate to Google Cloud
- uses: google-github-actions/auth@v2
+ uses: google-github-actions/auth@v3
with:
credentials_json: ${{ secrets.GCS_SA_KEY }}
- name: Setup gcloud
- uses: google-github-actions/setup-gcloud@v2
+ uses: google-github-actions/setup-gcloud@v3
with:
project_id: ${{ secrets.GCS_PROJECT }}
diff --git a/.lightning/workflows/fabric.yml b/.lightning/workflows/fabric.yml
index 438f56ef7fe94..e6d97ca1b1887 100644
--- a/.lightning/workflows/fabric.yml
+++ b/.lightning/workflows/fabric.yml
@@ -4,20 +4,22 @@ trigger:
pull_request:
branches: ["master"]
-timeout: "75" # minutes
-machine: "L4_X_2"
+timeout: "55" # minutes
parametrize:
matrix: {}
include:
- # note that this is setting also all oldest requirements which is linked to Torch == 2.0
+ # note that this is setting also all oldest requirements which is linked to Torch == 2.1
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
PACKAGE_NAME: "fabric"
- - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
+ machine: "A100_X_2"
+ - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
PACKAGE_NAME: "fabric"
+ machine: "L4_X_2"
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
# PACKAGE_NAME: "fabric"
- - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
+ - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
PACKAGE_NAME: "lightning"
+ machine: "L4_X_2"
exclude: []
env:
@@ -30,6 +32,7 @@ run: |
python --version
pip --version
pip install -q fire wget packaging
+ pip list
set -ex
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
echo "Torch URL: ${TORCH_URL}"
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))')
echo "collecting coverage for: ${COVERAGE_SOURCE}"
+ TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
if [ "${TORCH_VER}" == "2.1" ]; then
echo "Set oldest versions"
- cd requirements/fabric
+ pip uninstall -y deepspeed
pip install -U "lightning-utilities[cli]"
+ cd requirements/fabric
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']"
+ python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
cd ../..
pip install "cython<3.0" wheel # for compatibility
fi
@@ -92,6 +98,7 @@ run: |
export PL_RUN_STANDALONE_TESTS=1
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
bash ./run_standalone_tests.sh "tests_fabric"
+ export PL_RUN_STANDALONE_TESTS=0
# echo "Reporting coverage" # todo
# python -m coverage report
diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml
index 5c92bf881d969..21551565b1fed 100644
--- a/.lightning/workflows/pytorch.yml
+++ b/.lightning/workflows/pytorch.yml
@@ -4,20 +4,22 @@ trigger:
pull_request:
branches: ["master"]
-timeout: "75" # minutes
-machine: "L4_X_2"
+timeout: "55" # minutes
parametrize:
matrix: {}
include:
- # note that this is setting also all oldest requirements which is linked to Torch == 2.0
+ # note that this is setting also all oldest requirements which is linked to Torch == 2.1
- image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1"
PACKAGE_NAME: "pytorch"
- - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
+ machine: "A100_X_2"
+ - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
PACKAGE_NAME: "pytorch"
+ machine: "L4_X_2"
# - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
# PACKAGE_NAME: "pytorch"
- - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7"
+ - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8"
PACKAGE_NAME: "lightning"
+ machine: "L4_X_2"
exclude: []
env:
@@ -30,6 +32,7 @@ run: |
python --version
pip --version
pip install -q fire wget packaging
+ pip list
set -ex
CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda"
@@ -40,12 +43,15 @@ run: |
echo "Torch URL: ${TORCH_URL}"
COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="pytorch_lightning").get(n, n))')
echo "collecting coverage for: ${COVERAGE_SOURCE}"
+ TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])")
if [ "${TORCH_VER}" == "2.1" ]; then
- recho "Set oldest versions"
- cd requirements/pytorch
+ echo "Set oldest versions"
+ pip uninstall -y deepspeed
pip install -U "lightning-utilities[cli]"
+ cd requirements/pytorch
python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']"
+ python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt
cd ../..
pip install "cython<3.0" wheel # for compatibility
fi
@@ -108,6 +114,7 @@ run: |
export PL_RUN_STANDALONE_TESTS=1
wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
bash ./run_standalone_tests.sh "tests_pytorch"
+ export PL_RUN_STANDALONE_TESTS=0
echo "Testing: PyTorch standalone tasks"
cd tests_pytorch/
diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile
index 2b6f48771c7f7..41faf0ca55113 100644
--- a/dockers/base-cuda/Dockerfile
+++ b/dockers/base-cuda/Dockerfile
@@ -13,7 +13,7 @@
# limitations under the License.
ARG UBUNTU_VERSION=22.04
-ARG CUDA_VERSION=11.7.1
+ARG CUDA_VERSION=12.1.1
FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION}
diff --git a/docs/source-fabric/guide/index.rst b/docs/source-fabric/guide/index.rst
index 669501ec09864..6444a6bbdd17a 100644
--- a/docs/source-fabric/guide/index.rst
+++ b/docs/source-fabric/guide/index.rst
@@ -78,7 +78,7 @@ Build your own Trainer
.. displayitem::
- :header: Organize your model code with with LightningModule
+ :header: Organize your model code with LightningModule
:description: Organize your code in a LightningModule and use it with Fabric
:button_link: lightning_module.html
:col_css: col-md-4
diff --git a/docs/source-fabric/levels/intermediate.rst b/docs/source-fabric/levels/intermediate.rst
index f21e7c96608ab..30f27bdd1d544 100644
--- a/docs/source-fabric/levels/intermediate.rst
+++ b/docs/source-fabric/levels/intermediate.rst
@@ -19,7 +19,7 @@ Intermediate skills
.. displayitem::
- :header: Organize your model code with with LightningModule
+ :header: Organize your model code with LightningModule
:description: Organize your code in a LightningModule and use it with Fabric
:button_link: ../guide/lightning_module.html
:col_css: col-md-4
diff --git a/docs/source-pytorch/accelerators/gpu_faq.rst b/docs/source-pytorch/accelerators/gpu_faq.rst
index 4cc05555bb559..346843e316bce 100644
--- a/docs/source-pytorch/accelerators/gpu_faq.rst
+++ b/docs/source-pytorch/accelerators/gpu_faq.rst
@@ -5,31 +5,71 @@
GPU training (FAQ)
==================
-******************************************************************
-How should I adjust the learning rate when using multiple devices?
-******************************************************************
+***************************************************************
+How should I adjust the batch size when using multiple devices?
+***************************************************************
-When using distributed training make sure to modify your learning rate according to your effective
-batch size.
+Lightning automatically shards your data across multiple GPUs, meaning that each device only sees a unique subset of your
+data, but the ``batch_size`` in your DataLoader remains the same. This means that the effective batch size, i.e. the
+total number of samples processed in one forward/backward pass is
-Let's say you have a batch size of 7 in your dataloader.
+.. math::
-.. testcode::
+ \text{Effective Batch Size} = \text{DataLoader Batch Size} \times \text{Number of Devices} \times \text{Number of Nodes}
- class LitModel(LightningModule):
- def train_dataloader(self):
- return Dataset(..., batch_size=7)
-
-Whenever you use multiple devices and/or nodes, your effective batch size will be 7 * devices * num_nodes.
+A couple of examples to illustrate this:
.. code-block:: python
- # effective batch size = 7 * 8
+ dataloader = DataLoader(..., batch_size=7)
+
+ # Single GPU: effective batch size = 7
+ Trainer(accelerator="gpu", devices=1)
+
+ # Multi-GPU: effective batch size = 7 * 8 = 56
Trainer(accelerator="gpu", devices=8, strategy=...)
- # effective batch size = 7 * 8 * 10
+ # Multi-node: effective batch size = 7 * 8 * 10 = 560
Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...)
+In general you should be able to use the same ``batch_size`` in your DataLoader regardless of the number of devices you are
+using.
+
+.. note::
+
+ If you want distributed training to work exactly the same as single GPU training, you need to set the ``batch_size``
+ in your DataLoader to ``original_batch_size / num_devices`` to maintain the same effective batch size. However, this
+ can lead to poor GPU utilization.
+
+----
+
+******************************************************************
+How should I adjust the learning rate when using multiple devices?
+******************************************************************
+
+Because the effective batch size is larger when using multiple devices, you need to adjust your learning rate
+accordingly. The learning rate controls how much the model weights change in response to the estimated error at
+each update, so it is important to scale it with the effective batch size.
+
+In general, there are two common scaling rules:
+
+1. **Linear scaling**: Increase the learning rate linearly with the number of devices.
+
+ .. code-block:: python
+
+ # Example: Linear scaling
+ base_lr = 1e-3
+ num_devices = 8
+ scaled_lr = base_lr * num_devices # 8e-3
+
+2. **Square root scaling**: Increase the learning rate by the square root of the number of devices.
+
+ .. code-block:: python
+
+ # Example: Square root scaling
+ base_lr = 1e-3
+ num_devices = 8
+ scaled_lr = base_lr * (num_devices ** 0.5) # 2.83e-3
.. note:: Huge batch sizes are actually really bad for convergence. Check out:
`Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour `_
diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst
new file mode 100644
index 0000000000000..94531ab818d98
--- /dev/null
+++ b/docs/source-pytorch/common/hooks.rst
@@ -0,0 +1,325 @@
+##########################
+Hooks in PyTorch Lightning
+##########################
+
+Hooks in PyTorch Lightning allow you to customize the training, validation, and testing logic of your models. They
+provide a way to insert custom behavior at specific points during the training process without modifying the core
+training loop. There are several categories of hooks available in PyTorch Lightning:
+
+1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases
+2. **Training Hooks**: Called during the training loop
+3. **Validation Hooks**: Called during validation
+4. **Test Hooks**: Called during testing
+5. **Prediction Hooks**: Called during prediction
+6. **Optimizer Hooks**: Called around optimizer operations
+7. **Checkpoint Hooks**: Called during checkpoint save/load operations
+8. **Exception Hooks**: Called when exceptions occur
+
+Nearly all hooks can be implemented in three places within your code:
+
+- **LightningModule**: The main module where you define your model and training logic.
+- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events.
+- **Strategy**: Custom strategies for distributed training.
+
+Because logic can be placed in the same hook but in different places, the call order of hooks is
+important to understand. The following order is always used:
+
+1. Callbacks, called in the order they are passed to the Trainer.
+2. ``LightningModule``
+3. Strategy
+
+.. testcode::
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import Callback
+ from lightning.pytorch.demos import BoringModel
+
+ class MyModel(BoringModel):
+ def on_train_start(self):
+ print("Model: Training is starting!")
+
+ class MyCallback(Callback):
+ def on_train_start(self, trainer, pl_module):
+ print("Callback: Training is starting!")
+
+ model = MyModel()
+ callback = MyCallback()
+ trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1, enable_progress_bar=False)
+ trainer.fit(model)
+
+.. testoutput::
+ :hide:
+ :options: +ELLIPSIS, +NORMALIZE_WHITESPACE
+
+ Callback: Training is starting!
+ Model: Training is starting!
+
+
+.. note::
+ There are a few exceptions to this pattern:
+
+ - **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring callbacks
+ - **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and ``LightningModule`` are called
+ - Some internal hooks may only call ``LightningModule`` or Strategy
+
+************************
+Training Loop Hook Order
+************************
+
+The following diagram shows the execution order of hooks during a typical training loop, e.g. when calling ``trainer.fit()``,
+with the source of each hook indicated:
+
+.. code-block:: text
+
+ Training Process Flow:
+
+ trainer.fit()
+ │
+ ├── setup(stage="fit")
+ │ ├── [LightningDataModule]
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ ├── [LightningModule.configure_shared_model()]
+ │ ├── [LightningModule.configure_model()]
+ │ ├── Strategy.restore_checkpoint_before_setup
+ │ │ ├── [LightningModule.on_load_checkpoint()]
+ │ │ ├── [LightningModule.load_state_dict()]
+ │ │ ├── [LightningDataModule.load_state_dict()]
+ │ │ ├── [Callbacks.on_load_checkpoint()]
+ │ │ └── [Callbacks.load_state_dict()]
+ │ └── [Strategy]
+ │
+ ├── on_fit_start()
+ │ ├── [Callbacks]
+ │ └── [LightningModule]
+ │
+ ├── Strategy.restore_checkpoint_after_setup
+ │ ├── [LightningModule.on_load_checkpoint()]
+ │ ├── [LightningModule.load_state_dict()]
+ │ ├── [LightningDataModule.load_state_dict()]
+ │ ├── [Callbacks.on_load_checkpoint()]
+ │ └── [Callbacks.load_state_dict()]
+ │
+ ├── on_sanity_check_start()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │ ├── on_validation_start()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ ├── on_validation_epoch_start()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ │ ├── [for each validation batch]
+ │ │ │ ├── on_validation_batch_start()
+ │ │ │ │ ├── [Callbacks]
+ │ │ │ │ ├── [LightningModule]
+ │ │ │ │ └── [Strategy]
+ │ │ │ └── on_validation_batch_end()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ └── [end validation batches]
+ │ ├── on_validation_epoch_end()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ └── on_validation_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ ├── on_sanity_check_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ ├── on_train_start()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ ├── [Training Epochs Loop]
+ │ │
+ │ ├── on_train_epoch_start()
+ │ │ ├── [Callbacks]
+ │ │ └── [LightningModule]
+ │ │
+ │ ├── [Training Batches Loop]
+ │ │ │
+ │ │ ├── on_train_batch_start()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ │
+ │ │ ├── [Forward Pass - training_step()]
+ │ │ │ └── [Strategy only]
+ │ │ │
+ │ │ ├── on_before_zero_grad()
+ │ │ │ ├── [Callbacks]
+ │ │ │ └── [LightningModule]
+ │ │ │
+ │ │ ├── optimizer_zero_grad()
+ │ │ │ └── [LightningModule only - optimizer_zero_grad()]
+ │ │ │
+ │ │ ├── [Backward Pass - Strategy.backward()]
+ │ │ │ ├── on_before_backward()
+ │ │ │ │ ├── [Callbacks]
+ │ │ │ │ └── [LightningModule]
+ │ │ │ ├── LightningModule.backward()
+ │ │ │ └── on_after_backward()
+ │ │ │ ├── [Callbacks]
+ │ │ │ └── [LightningModule]
+ │ │ │
+ │ │ ├── on_before_optimizer_step()
+ │ │ │ ├── [Callbacks]
+ │ │ │ └── [LightningModule]
+ │ │ │
+ │ │ ├── [Optimizer Step]
+ │ │ │ └── [LightningModule only - optimizer_step()]
+ │ │ │
+ │ │ └── on_train_batch_end()
+ │ │ ├── [Callbacks]
+ │ │ └── [LightningModule]
+ │ │
+ │ │ [Optional: Validation during training]
+ │ │ ├── on_validation_start()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ ├── on_validation_epoch_start()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ │ ├── [for each validation batch]
+ │ │ │ │ ├── on_validation_batch_start()
+ │ │ │ │ │ ├── [Callbacks]
+ │ │ │ │ │ ├── [LightningModule]
+ │ │ │ │ │ └── [Strategy]
+ │ │ │ │ └── on_validation_batch_end()
+ │ │ │ │ ├── [Callbacks]
+ │ │ │ │ ├── [LightningModule]
+ │ │ │ │ └── [Strategy]
+ │ │ │ └── [end validation batches]
+ │ │ ├── on_validation_epoch_end()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ └── on_validation_end()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ │
+ │ └── on_train_epoch_end() **SPECIAL CASE**
+ │ ├── [Callbacks - Non-monitoring only]
+ │ ├── [LightningModule]
+ │ └── [Callbacks - Monitoring only]
+ │
+ ├── [End Training Epochs]
+ │
+ ├── on_train_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ └── teardown(stage="fit")
+ ├── [Strategy]
+ ├── on_fit_end()
+ │ ├── [Callbacks]
+ │ └── [LightningModule]
+ ├── [LightningDataModule]
+ ├── [Callbacks]
+ └── [LightningModule]
+
+***********************
+Testing Loop Hook Order
+***********************
+
+When running tests with ``trainer.test()``:
+
+.. code-block:: text
+
+ trainer.test()
+ │
+ ├── setup(stage="test")
+ │ └── [Callbacks only]
+ ├── on_test_start()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ ├── [Test Epochs Loop]
+ │ │
+ │ ├── on_test_epoch_start()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ │
+ │ ├── [Test Batches Loop]
+ │ │ │
+ │ │ ├── on_test_batch_start()
+ │ │ │ ├── [Callbacks]
+ │ │ │ ├── [LightningModule]
+ │ │ │ └── [Strategy]
+ │ │ │
+ │ │ └── on_test_batch_end()
+ │ │ ├── [Callbacks]
+ │ │ ├── [LightningModule]
+ │ │ └── [Strategy]
+ │ │
+ │ └── on_test_epoch_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ ├── on_test_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ └── teardown(stage="test")
+ └── [Callbacks only]
+
+**************************
+Prediction Loop Hook Order
+**************************
+
+When running predictions with ``trainer.predict()``:
+
+.. code-block:: text
+
+ trainer.predict()
+ │
+ ├── setup(stage="predict")
+ │ └── [Callbacks only]
+ ├── on_predict_start()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ │
+ ├── [Prediction Epochs Loop]
+ │ │
+ │ ├── on_predict_epoch_start()
+ │ │ ├── [Callbacks]
+ │ │ └── [LightningModule]
+ │ │
+ │ ├── [Prediction Batches Loop]
+ │ │ │
+ │ │ ├── on_predict_batch_start()
+ │ │ │ ├── [Callbacks]
+ │ │ │ └── [LightningModule]
+ │ │ │
+ │ │ └── on_predict_batch_end()
+ │ │ ├── [Callbacks]
+ │ │ └── [LightningModule]
+ │ │
+ │ └── on_predict_epoch_end()
+ │ ├── [Callbacks]
+ │ └── [LightningModule]
+ │
+ ├── on_predict_end()
+ │ ├── [Callbacks]
+ │ ├── [LightningModule]
+ │ └── [Strategy]
+ └── teardown(stage="predict")
+ └── [Callbacks only]
diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst
index 86ee52f41f0c9..a81934c104b81 100644
--- a/docs/source-pytorch/common/trainer.rst
+++ b/docs/source-pytorch/common/trainer.rst
@@ -510,6 +510,7 @@ limit_train_batches
How much of training dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
+Value is per device.
.. testcode::
@@ -535,7 +536,7 @@ limit_test_batches
:width: 400
:muted:
-How much of test dataset to check.
+How much of test dataset to check. Value is per device.
.. testcode::
@@ -560,6 +561,7 @@ limit_val_batches
How much of validation dataset to check.
Useful when debugging or testing something that happens at the end of an epoch.
+Value is per device.
.. testcode::
diff --git a/docs/source-pytorch/expertise_levels.rst b/docs/source-pytorch/expertise_levels.rst
index 7c34123f5b3f7..672ccf48ab7eb 100644
--- a/docs/source-pytorch/expertise_levels.rst
+++ b/docs/source-pytorch/expertise_levels.rst
@@ -84,7 +84,7 @@ Learn to scale up your models and enable collaborative model development at acad
.. Add callout items below this line
.. displayitem::
- :header: Level 7: Interactive cloud development
+ :header: Level 7: Hardware acceleration
:description: Learn how to access GPUs and TPUs on the cloud.
:button_link: levels/intermediate_level_7.html
:col_css: col-md-6
diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst
index 45683c67c1708..c904932737b02 100644
--- a/docs/source-pytorch/glossary/index.rst
+++ b/docs/source-pytorch/glossary/index.rst
@@ -20,6 +20,7 @@
FSDP <../advanced/model_parallel/fsdp>
GPU <../accelerators/gpu>
Half precision <../common/precision>
+ Hooks <../common/hooks>
HPU <../integrations/hpu/index>
Inference <../deploy/production_intermediate>
Lightning CLI <../cli/lightning_cli>
@@ -179,6 +180,13 @@ Glossary
:button_link: ../common/precision.html
:height: 100
+.. displayitem::
+ :header: Hooks
+ :description: How to customize the training, validation, and testing logic
+ :col_css: col-md-12
+ :button_link: ../common/hooks.html
+ :height: 100
+
.. displayitem::
:header: HPU
:description: Habana Gaudi AI Processor Unit for faster training
diff --git a/docs/source-pytorch/levels/intermediate.rst b/docs/source-pytorch/levels/intermediate.rst
index 282ac7bc98c90..797cfbdf0c2e2 100644
--- a/docs/source-pytorch/levels/intermediate.rst
+++ b/docs/source-pytorch/levels/intermediate.rst
@@ -16,7 +16,7 @@ Learn to scale up your models and enable collaborative model development at acad
.. Add callout items below this line
.. displayitem::
- :header: Level 7: Interactive cloud development
+ :header: Level 7: Hardware acceleration
:description: Learn how to access GPUs and TPUs on the cloud.
:button_link: intermediate_level_7.html
:col_css: col-md-6
diff --git a/docs/source-pytorch/levels/intermediate_level_7.rst b/docs/source-pytorch/levels/intermediate_level_7.rst
index ef4122d0d3150..660fd9962c8d2 100644
--- a/docs/source-pytorch/levels/intermediate_level_7.rst
+++ b/docs/source-pytorch/levels/intermediate_level_7.rst
@@ -1,8 +1,8 @@
:orphan:
-######################################
-Level 7: Interactive cloud development
-######################################
+##############################
+Level 7: Hardware acceleration
+##############################
Learn to develop models on cloud GPUs and TPUs.
diff --git a/docs/source-pytorch/tuning/profiler_basic.rst b/docs/source-pytorch/tuning/profiler_basic.rst
index 880381268fa78..01f4d8a51daaf 100644
--- a/docs/source-pytorch/tuning/profiler_basic.rst
+++ b/docs/source-pytorch/tuning/profiler_basic.rst
@@ -121,3 +121,22 @@ This can be measured with the :class:`~lightning.pytorch.callbacks.device_stats_
CPU metrics will be tracked by default on the CPU accelerator. To enable it for other accelerators set ``DeviceStatsMonitor(cpu_stats=True)``. To disable logging
CPU metrics, you can specify ``DeviceStatsMonitor(cpu_stats=False)``.
+
+.. warning::
+
+ **Do not wrap** ``Trainer.fit()``, ``Trainer.validate()``, or other Trainer methods inside a manual
+ ``torch.profiler.profile`` context manager. This will cause unexpected crashes and cryptic errors due to
+ incompatibility between PyTorch Profiler's context management and Lightning's internal training loop.
+ Instead, always use the ``profiler`` argument in the ``Trainer`` constructor or the
+ :class:`~lightning.pytorch.profilers.pytorch.PyTorchProfiler` profiler class if you want to customize the profiling.
+
+ Example:
+
+ .. code-block:: python
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.profilers import PyTorchProfiler
+
+ trainer = Trainer(profiler="pytorch")
+ # or
+ trainer = Trainer(profiler=PyTorchProfiler(dirpath=".", filename="perf_logs"))
diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst
index 4a04bd1534de9..296f8052e863d 100644
--- a/docs/source-pytorch/versioning.rst
+++ b/docs/source-pytorch/versioning.rst
@@ -67,6 +67,13 @@ PyTorch Lightning follows `NEP 29 =2.1.0, <2.9.0
fsspec[http] >=2022.5.0, <2025.8.0
packaging >=20.0, <=25.0
-typing-extensions >4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt
index a9b4271cac2c3..222e9cf412b3f 100644
--- a/requirements/fabric/test.txt
+++ b/requirements/fabric/test.txt
@@ -1,9 +1,9 @@
-coverage ==7.10.5
+coverage ==7.10.6
numpy >=1.21.0, <1.27.0
pytest ==8.4.1
pytest-cov ==6.2.1
pytest-timeout ==2.4.0
-pytest-rerunfailures ==15.1
+pytest-rerunfailures ==16.0
pytest-random-order ==1.2.0
click ==8.1.8; python_version < "3.11"
click ==8.2.1; python_version > "3.10"
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index ef798883c12ef..6e684d995c1b7 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -7,5 +7,5 @@ PyYAML >5.4, <6.1.0
fsspec[http] >=2022.5.0, <2025.8.0
torchmetrics >0.7.0, <1.9.0
packaging >=20.0, <=25.0
-typing-extensions >4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.16.0
lightning-utilities >=0.10.0, <0.16.0
diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt
index a3e2e88967f75..7ee6c8bb309cb 100644
--- a/requirements/pytorch/docs.txt
+++ b/requirements/pytorch/docs.txt
@@ -1,7 +1,7 @@
-r ../docs.txt
nbformat # used for generate empty notebook
-ipython[notebook] <9.5.0
+ipython[notebook] <9.6.0
setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is invalid.`
onnxscript >= 0.2.2, < 0.5.0
diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt
index 84ea80df6ff0c..c34309a0234f2 100644
--- a/requirements/pytorch/examples.txt
+++ b/requirements/pytorch/examples.txt
@@ -3,5 +3,5 @@
requests <2.33.0
torchvision >=0.16.0, <0.24.0
-ipython[all] <8.19.0
+ipython[all] >=8.0.0, <9.0.0
torchmetrics >=0.10.0, <1.9.0
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index 1fd3ec790055f..86b765c37237f 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -1,15 +1,15 @@
-coverage ==7.10.5
+coverage ==7.10.6
pytest ==8.4.1
pytest-cov ==6.2.1
pytest-timeout ==2.4.0
-pytest-rerunfailures ==15.1
+pytest-rerunfailures ==16.0
pytest-random-order ==1.2.0
# needed in tests
cloudpickle >=1.3, <3.2.0
scikit-learn >0.22.1, <1.8.0
numpy >1.20.0, <1.27.0
-onnx >1.12.0, <1.19.0
+onnx >1.12.0, <1.20.0
onnxruntime >=1.12.0, <1.23.0
onnxscript >= 0.1.0, < 0.5.0
psutil <7.0.1 # for `DeviceStatsMonitor`
diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md
index b1102cdce06b7..ae9ab0fae131f 100644
--- a/src/lightning/fabric/CHANGELOG.md
+++ b/src/lightning/fabric/CHANGELOG.md
@@ -6,11 +6,25 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
---
+## [2.5.5] - 2025-09-05
+
+### Changed
+
+- Include `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
+- Let `_get_default_process_group_backend_for_device` support more hardware platforms (
+ [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093))
+
+### Fixed
+
+- Fixed a missing device id for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
+- Fixed `seed_everything` not respecting `verbose=False` when no seed is provided ([#21161](https://github.com/Lightning-AI/pytorch-lightning/pull/21161))
+
+
## [2.5.4] - 2025-08-29
### Changed
-- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
+- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#21119](https://github.com/Lightning-AI/pytorch-lightning/pull/21119))
## [2.5.3] - 2025-08-13
@@ -45,7 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Removed
-- Removed legacy support for `lightning run model`. Use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
+- Removed legacy support for `lightning run model`; use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588))
## [2.5.0] - 2024-12-19
diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py
index ce47e4e403c34..af182ad7f422f 100644
--- a/src/lightning/fabric/strategies/ddp.py
+++ b/src/lightning/fabric/strategies/ddp.py
@@ -41,6 +41,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.rank_zero import rank_zero_only
_DDP_FORK_ALIASES = (
@@ -159,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None:
if torch.distributed.get_backend() == "nccl":
torch.distributed.barrier(device_ids=self._determine_ddp_device_ids())
else:
- torch.distributed.barrier()
+ # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error
+ try:
+ torch.distributed.barrier()
+ except RuntimeError as e:
+ if "PrivateUse1HooksInterface" in str(e):
+ # Fallback: Use all_reduce as barrier - all processes must participate
+ # This achieves the same synchronization effect as barrier()
+ dummy_tensor = torch.tensor(0.0, device=self.root_device)
+ torch.distributed.all_reduce(dummy_tensor)
+ else:
+ raise
@override
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
@@ -212,7 +223,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py
index 322fa1899b0ee..d21182b525f66 100644
--- a/src/lightning/fabric/strategies/deepspeed.py
+++ b/src/lightning/fabric/strategies/deepspeed.py
@@ -99,6 +99,7 @@ def __init__(
precision: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
+ exclude_frozen_parameters: bool = False,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
billion parameter models. `For more information: https://pytorch-
@@ -228,6 +229,8 @@ def __init__(
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
+ exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
"""
if not _DEEPSPEED_AVAILABLE:
raise ImportError(
@@ -288,6 +291,7 @@ def __init__(
self.remote_device = remote_device
self.load_full_weights = load_full_weights
+ self.exclude_frozen_parameters = exclude_frozen_parameters
# default FP16 parameters.
self.loss_scale = loss_scale
@@ -444,7 +448,9 @@ def save_checkpoint(
# there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict
state = self._convert_stateful_objects_in_state(state, filter={})
# use deepspeed's internal checkpointing function to handle partitioned weights across processes
- engine.save_checkpoint(path, client_state=state, tag="checkpoint")
+ engine.save_checkpoint(
+ path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters
+ )
@override
def load_checkpoint(
diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py
index 74af1dd0e8f43..25b50d6a67332 100644
--- a/src/lightning/fabric/strategies/fsdp.py
+++ b/src/lightning/fabric/strategies/fsdp.py
@@ -662,7 +662,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py
index ace23a9c7a2c5..0d49ddf91a0bc 100644
--- a/src/lightning/fabric/strategies/model_parallel.py
+++ b/src/lightning/fabric/strategies/model_parallel.py
@@ -302,7 +302,10 @@ def _setup_distributed(self) -> None:
self._set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py
index ec4eb261f2d3e..500f3a3e2aa92 100644
--- a/src/lightning/fabric/utilities/distributed.py
+++ b/src/lightning/fabric/utilities/distributed.py
@@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None:
def _get_default_process_group_backend_for_device(device: torch.device) -> str:
- return "nccl" if device.type == "cuda" else "gloo"
+ """Return corresponding distributed backend for a given device."""
+ device_backend_map = torch.distributed.Backend.default_device_backend_map
+ if device.type in device_backend_map:
+ return device_backend_map[device.type]
+ return "gloo"
class _DatasetSamplerWrapper(Dataset):
diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py
index 534e5e3db653e..841fa195696a2 100644
--- a/src/lightning/fabric/utilities/seed.py
+++ b/src/lightning/fabric/utilities/seed.py
@@ -40,7 +40,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose:
env_seed = os.environ.get("PL_GLOBAL_SEED")
if env_seed is None:
seed = 0
- rank_zero_warn(f"No seed found, seed set to {seed}")
+ if verbose:
+ rank_zero_warn(f"No seed found, seed set to {seed}")
else:
try:
seed = int(env_seed)
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index 01b64c38051b3..55519273e0b27 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -6,6 +6,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
---
+## [2.5.5] - 2025-09-05
+
+### Changed
+
+- Include `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060))
+- Include `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146))
+
+### Fixed
+
+- Fixed `LightningCLI` not using `ckpt_path` hyperparameters to instantiate classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116))
+- Fixed `ModelCheckpoint` by deferring step/time-triggered saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106))
+- Fixed a missing device id for PyTorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105))
+- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147))
+- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162))
+
+
## [2.5.4] - 2025-08-29
### Fixed
diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py
index 6279dd13be4af..873c4c05f5aed 100644
--- a/src/lightning/pytorch/callbacks/device_stats_monitor.py
+++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py
@@ -34,6 +34,67 @@ class DeviceStatsMonitor(Callback):
r"""Automatically monitors and logs device stats during training, validation and testing stage.
``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``.
+ **Logged Metrics**
+
+ Logs device statistics with keys prefixed as ``DeviceStatsMonitor.{hook_name}/{base_metric_name}``.
+ The actual metrics depend on the active accelerator and the ``cpu_stats`` flag. Below is an overview of the
+ possible available metrics and their meaning.
+
+ - CPU (via ``psutil``)
+
+ - ``cpu_percent`` — System-wide CPU utilization (%)
+ - ``cpu_vm_percent`` — System-wide virtual memory (RAM) utilization (%)
+ - ``cpu_swap_percent`` — System-wide swap memory utilization (%)
+
+ - CUDA GPU (via ``torch.cuda.memory_stats``)
+
+ Logs memory statistics from PyTorch caching allocator (all in bytes).
+ GPU compute utilization is not logged by default.
+
+ - General Memory Usage:
+
+ - ``allocated_bytes.all.current`` — Current allocated GPU memory
+ - ``allocated_bytes.all.peak`` — Peak allocated GPU memory
+ - ``reserved_bytes.all.current`` — Current reserved GPU memory (allocated + cached)
+ - ``reserved_bytes.all.peak`` — Peak reserved GPU memory
+ - ``active_bytes.all.current`` — Current GPU memory in active use
+ - ``active_bytes.all.peak`` — Peak GPU memory in active use
+ - ``inactive_split_bytes.all.current`` — Memory in inactive, splittable blocks
+
+ - Allocator Pool Statistics (for ``small_pool`` and ``large_pool``):
+
+ - ``allocated_bytes.{pool_type}.current`` / ``allocated_bytes.{pool_type}.peak``
+ - ``reserved_bytes.{pool_type}.current`` / ``reserved_bytes.{pool_type}.peak``
+ - ``active_bytes.{pool_type}.current`` / ``active_bytes.{pool_type}.peak``
+
+ - Allocator Events:
+
+ - ``num_ooms`` — Cumulative out-of-memory errors
+ - ``num_alloc_retries`` — Number of allocation retries
+ - ``num_device_alloc`` — Number of device allocations
+ - ``num_device_free`` — Number of device deallocations
+
+ For a full list of CUDA memory stats, see the
+ `PyTorch documentation <https://pytorch.org/docs/stable/generated/torch.cuda.memory_stats.html>`_.
+
+ - TPU (via ``torch_xla``)
+
+ - *Memory Metrics* (per device, e.g., ``xla:0``):
+
+ - ``memory.free.xla:0`` — Free HBM memory (MB)
+ - ``memory.used.xla:0`` — Used HBM memory (MB)
+ - ``memory.percent.xla:0`` — Percentage of HBM memory used (%)
+
+ - *XLA Operation Counters*:
+
+ - ``CachedCompile.xla``
+ - ``CreateXlaTensor.xla``
+ - ``DeviceDataCacheMiss.xla``
+ - ``UncachedCompile.xla``
+ - ``xla::add.xla``, ``xla::addmm.xla``, etc.
+
+ These counters can be retrieved using: ``torch_xla.debug.metrics.counter_names()``
+
Args:
cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU.
If ``True``, it will log CPU stats regardless of the accelerator.
diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py
index 68fed2ff82d31..452e8bdecbba3 100644
--- a/src/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/src/lightning/pytorch/callbacks/model_checkpoint.py
@@ -260,6 +260,9 @@ def __init__(
self.best_model_path = ""
self.last_model_path = ""
self._last_checkpoint_saved = ""
+ # When using step/time-based checkpointing with a validation-only monitored metric,
+ # defer the save until validation has produced the metric
+ self._defer_save_until_validation: bool = False
self.kth_value: Tensor
self.dirpath: Optional[_PATH]
@@ -306,14 +309,17 @@ def on_train_batch_end(
batch_idx: int,
) -> None:
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
- if self._should_skip_saving_checkpoint(trainer):
- return
+ # Do not return early here because we may need to set deferral flags even
+ # if a save already happened at this global step. We'll enforce the skip
+ # just before actually saving below.
+ skip_due_to_state = self._should_skip_saving_checkpoint(trainer)
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
train_time_interval = self._train_time_interval
skip_time = True
now = time.monotonic()
- if train_time_interval:
+ # Important: allow zero timedelta as a valid interval
+ if train_time_interval is not None:
prev_time_check = self._last_time_checked
skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
# in case we have time differences across ranks
@@ -326,6 +332,42 @@ def on_train_batch_end(
self._last_time_checked = now
monitor_candidates = self._monitor_candidates(trainer)
+ # If monitoring a metric that is not yet available (e.g., validation-only),
+ # defer saving until validation end so the metric is present.
+ if self.monitor is not None and self.monitor not in monitor_candidates:
+ # Defer both top-k and last to avoid blocking with `_last_global_step_saved`
+ self._defer_save_until_validation = True
+ return
+
+ # Even if the monitored key exists, it could be stale from a previous validation.
+ # If validation is scheduled to run right after this batch (e.g., last batch of epoch)
+ # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics.
+ if (
+ self.monitor is not None
+ and not self._should_save_on_train_epoch_end(trainer)
+ and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False)
+ ):
+ # Only defer if a validation loop is expected to run after this batch.
+ will_run_val = False
+ if getattr(trainer, "enable_validation", False):
+ num_val_batches = (
+ sum(trainer.num_val_batches)
+ if isinstance(trainer.num_val_batches, list)
+ else trainer.num_val_batches
+ )
+ if num_val_batches and num_val_batches > 0:
+ cve = trainer.check_val_every_n_epoch
+ if cve is None or ((trainer.current_epoch + 1) % cve == 0):
+ will_run_val = True
+
+ if will_run_val:
+ self._defer_save_until_validation = True
+ return
+
+ # Only proceed to save if not skipping due to trainer/callback state
+ if skip_due_to_state:
+ return
+
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
@@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul
"""Save a checkpoint at the end of the validation stage."""
if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
monitor_candidates = self._monitor_candidates(trainer)
+ # If a step/time-triggered save was deferred due to a missing monitored metric,
+ # perform the save now that validation metrics are available.
+ if self._defer_save_until_validation:
+ self._save_topk_checkpoint(trainer, monitor_candidates)
+ self._save_last_checkpoint(trainer, monitor_candidates)
+ self._defer_save_until_validation = False
+ return
+
if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
index 4ef260f00006d..4a0b0d67041ba 100644
--- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
+++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py
@@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None:
def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
if self._leave:
self.train_progress_bar = self.init_train_tqdm()
- self.train_progress_bar.reset(convert_inf(self.total_train_batches))
+ total = convert_inf(self.total_train_batches)
+ self.train_progress_bar.reset()
+ self.train_progress_bar.total = total
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")
@@ -306,7 +308,9 @@ def on_validation_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return
- self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader))
+ total = convert_inf(self.total_val_batches_current_dataloader)
+ self.val_progress_bar.reset()
+ self.val_progress_bar.total = total
self.val_progress_bar.initial = 0
desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description
self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}")
@@ -348,7 +352,9 @@ def on_test_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return
- self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader))
+ total = convert_inf(self.total_test_batches_current_dataloader)
+ self.test_progress_bar.reset()
+ self.test_progress_bar.total = total
self.test_progress_bar.initial = 0
self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}")
@@ -387,7 +393,9 @@ def on_predict_batch_start(
if not self.has_dataloader_changed(dataloader_idx):
return
- self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader))
+ total = convert_inf(self.total_predict_batches_current_dataloader)
+ self.predict_progress_bar.reset()
+ self.predict_progress_bar.total = total
self.predict_progress_bar.initial = 0
self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}")
diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py
index 225296240674a..91247127f6c87 100644
--- a/src/lightning/pytorch/cli.py
+++ b/src/lightning/pytorch/cli.py
@@ -16,6 +16,7 @@
import sys
from collections.abc import Iterable
from functools import partial, update_wrapper
+from pathlib import Path
from types import MethodType
from typing import Any, Callable, Optional, TypeVar, Union
@@ -397,6 +398,7 @@ def __init__(
main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs)
self.setup_parser(run, main_kwargs, subparser_kwargs)
self.parse_arguments(self.parser, args)
+ self._parse_ckpt_path()
self.subcommand = self.config["subcommand"] if run else None
@@ -551,6 +553,24 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)
+    def _parse_ckpt_path(self) -> None:
+        """Merge ``hyper_parameters`` stored in the ``ckpt_path`` checkpoint (if any) into the parsed config."""
+        subcommand = self.config.get("subcommand")
+        ckpt_path = self.config[subcommand].get("ckpt_path") if subcommand else None
+        if not ckpt_path or not Path(ckpt_path).is_file():
+            return  # no subcommand, no checkpoint given, or the path is not an existing file
+        # load onto CPU with ``weights_only=True``; only the ``hyper_parameters`` entry is used here
+        stored = torch.load(ckpt_path, weights_only=True, map_location="cpu").get("hyper_parameters", {})
+        stored.pop("_instantiator", None)
+        if not stored:
+            return
+        try:
+            # nest under ``<subcommand>.model`` and hand to the parser together with the current config
+            self.config = self.parser.parse_object({subcommand: {"model": stored}}, self.config)
+        except SystemExit:
+            sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n")
+            raise
+
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py
index 31d6724a043a3..f25c33359a78a 100644
--- a/src/lightning/pytorch/loops/fit_loop.py
+++ b/src/lightning/pytorch/loops/fit_loop.py
@@ -414,6 +414,9 @@ def on_run_start(self) -> None:
self.epoch_loop.val_loop.setup_data()
trainer.training = True
+ # Check for modules in eval mode at training start
+ self._warn_if_modules_in_eval_mode()
+
call._call_callback_hooks(trainer, "on_train_start")
call._call_lightning_module_hook(trainer, "on_train_start")
call._call_strategy_hook(trainer, "on_train_start")
@@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None:
self._combined_loader_states_to_load = state_dict.get("combined_loader", [])
super().on_load_checkpoint(state_dict)
+    def _warn_if_modules_in_eval_mode(self) -> None:
+        """Emit a ``PossibleUserWarning`` if any module of the LightningModule currently is in eval mode."""
+        # count every module yielded by ``named_modules()`` whose ``training`` flag is off
+        num_eval = sum(1 for _, module in self.trainer.lightning_module.named_modules() if not module.training)
+        if num_eval == 0:
+            return
+        rank_zero_warn(
+            f"Found {num_eval} module(s) in eval mode at the start of training."
+            " This may lead to unexpected behavior during training. If this is intentional,"
+            " you can ignore this warning.",
+            category=PossibleUserWarning,
+        )
+
def _should_accumulate(self) -> bool:
"""Whether the gradients should be accumulated."""
return self.epoch_loop._should_accumulate()
diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..5ea62233e1f69 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -112,7 +112,8 @@ def clip_gradients(
super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
def autocast_context_manager(self) -> torch.autocast:
- return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+ dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half
+ return torch.autocast(self.device, dtype=dtype, cache_enabled=False)
@override
@contextmanager
diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py
index fd3f66ef42471..92206e1accc31 100644
--- a/src/lightning/pytorch/strategies/ddp.py
+++ b/src/lightning/pytorch/strategies/ddp.py
@@ -36,7 +36,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.optimizer import _optimizers_to_device
from lightning.fabric.utilities.seed import reset_seed
from lightning.fabric.utilities.types import ReduceOp
@@ -200,7 +200,10 @@ def setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py
index dabfde70242b9..c5253f77cdedb 100644
--- a/src/lightning/pytorch/strategies/deepspeed.py
+++ b/src/lightning/pytorch/strategies/deepspeed.py
@@ -122,16 +122,16 @@ def __init__(
precision_plugin: Optional[Precision] = None,
process_group_backend: Optional[str] = None,
timeout: Optional[timedelta] = default_pg_timeout,
+ exclude_frozen_parameters: bool = False,
) -> None:
"""Provides capabilities to run training using the DeepSpeed library, with training optimizations for large
- billion parameter models. `For more information: https://pytorch-
- lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed`.
+ billion parameter models. *For more information:* :ref:`deepspeed_advanced`.
 .. warning:: This is an :ref:`experimental <versioning:Experimental API>` feature.
Defaults have been set to enable ZeRO-Offload and some have been taken from the link below.
These defaults have been set generally, but may require tuning for optimum performance based on your model size.
- `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`.
+ *For more information:* https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training.
Arguments:
@@ -253,6 +253,8 @@ def __init__(
when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards
per worker.
+ exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints.
+
"""
if not _DEEPSPEED_AVAILABLE:
raise MisconfigurationException(
@@ -311,6 +313,7 @@ def __init__(
self.remote_device = remote_device
self.load_full_weights = load_full_weights
+ self.exclude_frozen_parameters = exclude_frozen_parameters
# default FP16 parameters.
self.loss_scale = loss_scale
@@ -648,7 +651,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op
# dump states as a checkpoint dictionary object
_exclude_keys = ["state_dict", "optimizer_states"]
checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys}
- self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint")
+ self.deepspeed_engine.save_checkpoint(
+ filepath,
+ client_state=checkpoint,
+ tag="checkpoint",
+ exclude_frozen_parameters=self.exclude_frozen_parameters,
+ )
@override
def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]:
diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py
index 55ea354a5cb60..3fbd0f9cd5f0a 100644
--- a/src/lightning/pytorch/strategies/fsdp.py
+++ b/src/lightning/pytorch/strategies/fsdp.py
@@ -61,7 +61,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers
from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -260,7 +260,10 @@ def setup_environment(self) -> None:
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
# if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here
if isinstance(self.kwargs.get("device_mesh"), tuple):
diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py
index 82fec205af731..e0286dbe2e0e6 100644
--- a/src/lightning/pytorch/strategies/model_parallel.py
+++ b/src/lightning/pytorch/strategies/model_parallel.py
@@ -39,7 +39,7 @@
_sync_ddp_if_available,
)
from lightning.fabric.utilities.distributed import group as _group
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4
from lightning.fabric.utilities.init import _materialize_distributed_module
from lightning.fabric.utilities.load import _METADATA_FILENAME
from lightning.fabric.utilities.optimizer import _optimizers_to_device
@@ -350,7 +350,10 @@ def _setup_distributed(self) -> None:
self.set_world_ranks()
self._process_group_backend = self._get_process_group_backend()
assert self.cluster_environment is not None
- _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
+ kwargs: dict[str, Any] = {"timeout": self._timeout}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None
+ _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs)
def _get_process_group_backend(self) -> str:
return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index b78843990af30..f3d45a17d4a3b 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -184,16 +184,16 @@ def __init__(
:class:`datetime.timedelta`.
limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
- Default: ``1.0``.
+ Value is per device. Default: ``1.0``.
limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).
- Default: ``1.0``.
+ Value is per device. Default: ``1.0``.
limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).
- Default: ``1.0``.
+ Value is per device. Default: ``1.0``.
limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).
- Default: ``1.0``.
+ Value is per device. Default: ``1.0``.
overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int).
Default: ``0.0``.
diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py
index 99badd84bb8ad..78d2aa52f5725 100644
--- a/src/lightning/pytorch/tuner/batch_size_scaling.py
+++ b/src/lightning/pytorch/tuner/batch_size_scaling.py
@@ -76,24 +76,27 @@ def _scale_batch_size(
if trainer.progress_bar_callback:
trainer.progress_bar_callback.disable()
- new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
-
- if mode == "power":
- new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
- elif mode == "binsearch":
- new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+ try:
+ new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val)
- garbage_collection_cuda()
+ if mode == "power":
+ new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params)
+ elif mode == "binsearch":
+ new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params)
- log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+ garbage_collection_cuda()
- __scale_batch_restore_params(trainer, params)
+ log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}")
+ except Exception as ex:
+ raise ex
+ finally:
+ __scale_batch_restore_params(trainer, params)
- if trainer.progress_bar_callback:
- trainer.progress_bar_callback.enable()
+ if trainer.progress_bar_callback:
+ trainer.progress_bar_callback.enable()
- trainer._checkpoint_connector.restore(ckpt_path)
- trainer.strategy.remove_checkpoint(ckpt_path)
+ trainer._checkpoint_connector.restore(ckpt_path)
+ trainer.strategy.remove_checkpoint(ckpt_path)
return new_size
diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py
index b4b61d5cf0f93..5ef35dcd6d992 100644
--- a/src/lightning/pytorch/tuner/lr_finder.py
+++ b/src/lightning/pytorch/tuner/lr_finder.py
@@ -257,40 +257,45 @@ def _lr_find(
# Initialize lr finder object (stores results)
lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)
- # Configure optimizer and scheduler
- lr_finder._exchange_scheduler(trainer)
-
- # Fit, lr & loss logged in callback
- _try_loop_run(trainer, params)
-
- # Prompt if we stopped early
- if trainer.global_step != num_training + start_steps:
- log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
-
- # Transfer results from callback to lr finder object
- lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
- lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose
-
- __lr_finder_restore_params(trainer, params)
-
- if trainer.progress_bar_callback:
- trainer.progress_bar_callback.enable()
-
- # Update results across ranks
- lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
-
- # Restore initial state of model (this will also restore the original optimizer state)
- trainer._checkpoint_connector.restore(ckpt_path)
- trainer.strategy.remove_checkpoint(ckpt_path)
- trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
- trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
- trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
- trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit
- trainer.fit_loop.setup_data()
+ lr_finder_finished = False
+ try:
+ # Configure optimizer and scheduler
+ lr_finder._exchange_scheduler(trainer)
+
+ # Fit, lr & loss logged in callback
+ _try_loop_run(trainer, params)
+
+ # Prompt if we stopped early
+ if trainer.global_step != num_training + start_steps:
+ log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")
+
+ # Transfer results from callback to lr finder object
+ lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
+ lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose
+
+ __lr_finder_restore_params(trainer, params)
+
+ if trainer.progress_bar_callback:
+ trainer.progress_bar_callback.enable()
+
+ # Update results across ranks
+ lr_finder.results = trainer.strategy.broadcast(lr_finder.results)
+ lr_finder_finished = True
+ except Exception as ex:
+ raise ex
+ finally:
+ # Restore initial state of model (this will also restore the original optimizer state)
+ trainer._checkpoint_connector.restore(ckpt_path)
+ trainer.strategy.remove_checkpoint(ckpt_path)
+ trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
+ trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
+ trainer.fit_loop.epoch_loop.val_loop._combined_loader = None
+ trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit
+ trainer.fit_loop.setup_data()
# Apply LR suggestion after restoring so it persists for the real training run
# When used as a callback, the suggestion would otherwise be lost due to checkpoint restore
- if update_attr:
+ if update_attr and lr_finder_finished:
lr = lr_finder.suggestion()
if lr is not None:
# update the attribute on the LightningModule (e.g., lr or learning_rate)
diff --git a/src/version.info b/src/version.info
index fe16b348d97f7..0cadbc1e33b27 100644
--- a/src/version.info
+++ b/src/version.info
@@ -1 +1 @@
-2.5.4
+2.5.5
diff --git a/tests/legacy/back-compatible-versions.txt b/tests/legacy/back-compatible-versions.txt
index 9032e5c13cc54..885dd6f85b36a 100644
--- a/tests/legacy/back-compatible-versions.txt
+++ b/tests/legacy/back-compatible-versions.txt
@@ -107,3 +107,4 @@
2.5.1
2.5.2
2.5.3
+2.5.4
diff --git a/tests/legacy/generate_checkpoints.sh b/tests/legacy/generate_checkpoints.sh
index 1d083a2a8e052..d3cfa693c2906 100644
--- a/tests/legacy/generate_checkpoints.sh
+++ b/tests/legacy/generate_checkpoints.sh
@@ -7,18 +7,17 @@
set -e
LEGACY_FOLDER=$(cd $(dirname $0); pwd -P)
-printf "LEGACY_FOLDER: $LEGACY_FOLDER"
+printf "LEGACY_FOLDER: $LEGACY_FOLDER\n"
TESTS_FOLDER=$(dirname $LEGACY_FOLDER)
-ENV_PATH=$LEGACY_FOLDER/vEnv
-printf "ENV_PATH: $ENV_PATH"
+ENV_PATH=$LEGACY_FOLDER/.venv
+printf "ENV_PATH: $ENV_PATH\n"
export PYTHONPATH=$TESTS_FOLDER # for `import tests_pytorch`
-printf "PYTHONPATH: $PYTHONPATH"
+printf "PYTHONPATH: $PYTHONPATH\n"
rm -rf $ENV_PATH
function create_and_save_checkpoint {
- python --version
- python -m pip --version
- python -m pip list
+ uv --version
+ uv pip list
python $LEGACY_FOLDER/simple_classif_training.py $pl_ver
@@ -33,10 +32,10 @@ do
printf "\n\n processing version: $pl_ver\n"
# Don't install/update anything before activating venv to avoid breaking any existing environment.
- python -m venv $ENV_PATH
+ uv venv $ENV_PATH
source $ENV_PATH/bin/activate
- python -m pip install "pytorch_lightning==$pl_ver" \
+ uv pip install "pytorch_lightning==$pl_ver" \
-r $LEGACY_FOLDER/requirements.txt \
-r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \
-f https://download.pytorch.org/whl/cpu/torch_stable.html
@@ -52,7 +51,7 @@ done
if [[ -z "$@" ]]; then
printf "\n\n processing local version\n"
- python -m pip install \
+ uv pip install \
-r $LEGACY_FOLDER/requirements.txt \
-r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \
-f https://download.pytorch.org/whl/cpu/torch_stable.html
diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py
index fa5c975228a5e..f302da5d1bc4f 100644
--- a/tests/tests_fabric/strategies/test_ddp.py
+++ b/tests/tests_fabric/strategies/test_ddp.py
@@ -25,6 +25,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies import DDPStrategy
from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from tests_fabric.helpers.runif import RunIf
@@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
+ )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices(init_process_group_mock):
+    """Check the ``device_id`` handed to ``init_process_group``: ``None`` for CPU, the root device for CUDA."""
+
+    def _set_up_strategy(device):
+        # single-device strategy with a mocked accelerator; `setup_environment` is expected to reach the patched init
+        strategy = DDPStrategy(parallel_devices=[device])
+        strategy.cluster_environment = LightningEnvironment()
+        strategy.accelerator = Mock()
+        strategy.setup_environment()
+        return strategy
+
+    def _expected_device_kwargs(strategy):
+        # `device_id` is only part of the expected call when `_TORCH_GREATER_EQUAL_2_3` is set
+        if not _TORCH_GREATER_EQUAL_2_3:
+            return {}
+        root = strategy.root_device
+        return {"device_id": None if root.type == "cpu" else root}
+
+    # Test with CPU device - device_id should be None
+    cpu_strategy = _set_up_strategy(torch.device("cpu"))
+    init_process_group_mock.assert_called_with(
+        cpu_strategy._get_process_group_backend(),
+        rank=cpu_strategy.cluster_environment.global_rank(),
+        world_size=cpu_strategy.cluster_environment.world_size(),
+        timeout=cpu_strategy._timeout,
+        **_expected_device_kwargs(cpu_strategy),
+    )
+
+    # start the CUDA case from a clean call history
+    init_process_group_mock.reset_mock()
+
+    # Test with CUDA device - device_id should be the device
+    cuda_strategy = _set_up_strategy(torch.device("cuda", 0))
+    init_process_group_mock.assert_called_with(
+        cuda_strategy._get_process_group_backend(),
+        rank=cuda_strategy.cluster_environment.global_rank(),
+        world_size=cuda_strategy.cluster_environment.world_size(),
+        timeout=cuda_strategy._timeout,
+        **_expected_device_kwargs(cuda_strategy),
     )
diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py
index 032ee63cd4721..817121c9e2f59 100644
--- a/tests/tests_fabric/strategies/test_deepspeed.py
+++ b/tests/tests_fabric/strategies/test_deepspeed.py
@@ -193,7 +193,9 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
- model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+ model.save_checkpoint.assert_called_with(
+ tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+ )
# Model and optimizer
optimizer = Mock()
@@ -201,7 +203,9 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path):
model.modules.return_value = [model]
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"})
# the client_state should not contain any deepspeed engine or deepspeed optimizer
- model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint")
+ model.save_checkpoint.assert_called_with(
+ tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False
+ )
@RunIf(deepspeed=True)
@@ -218,6 +222,27 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path):
strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2})
+@RunIf(deepspeed=True)
+@pytest.mark.parametrize("exclude_frozen_parameters", [True, False])
+def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_parameters):
+    """The ``exclude_frozen_parameters`` flag is stored on the strategy and forwarded to the engine on save."""
+    from deepspeed import DeepSpeedEngine
+
+    # a mock with the ``DeepSpeedEngine`` spec stands in for the model
+    engine = Mock(spec=DeepSpeedEngine, optimizer=None)
+    engine.modules.return_value = [engine]
+
+    strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters)
+    assert strategy.exclude_frozen_parameters is exclude_frozen_parameters
+    strategy.save_checkpoint(path="test_path", state={"model": engine, "extra": "data"})
+
+    # the non-model entry of ``state`` is expected as ``client_state``, next to the flag under test
+    expected = {"client_state": {"extra": "data"}, "tag": "checkpoint"}
+    engine.save_checkpoint.assert_called_with(
+        "test_path", exclude_frozen_parameters=exclude_frozen_parameters, **expected
+    )
+
+
@RunIf(deepspeed=True)
def test_deepspeed_load_checkpoint_validate_path(tmp_path):
"""Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error."""
diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py
index d5f82752a9176..6be379d36582c 100644
--- a/tests/tests_fabric/strategies/test_fsdp.py
+++ b/tests/tests_fabric/strategies/test_fsdp.py
@@ -31,7 +31,7 @@
_get_full_state_dict_context,
_is_sharded_checkpoint,
)
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
def test_custom_mixed_precision():
@@ -381,8 +381,11 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)
diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py
index 78622adf66fa6..4b6cb6f8fad85 100644
--- a/tests/tests_fabric/strategies/test_model_parallel.py
+++ b/tests/tests_fabric/strategies/test_model_parallel.py
@@ -25,6 +25,7 @@
from lightning.fabric.strategies import ModelParallelStrategy
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from tests_fabric.helpers.runif import RunIf
@@ -316,8 +317,11 @@ def test_set_timeout(init_process_group_mock, _):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)
diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py
index d65eaa810ff4d..51c4b320d5525 100644
--- a/tests/tests_fabric/utilities/test_distributed.py
+++ b/tests/tests_fabric/utilities/test_distributed.py
@@ -17,6 +17,7 @@
from lightning.fabric.utilities.distributed import (
_destroy_dist_connection,
_gather_all_tensors,
+ _get_default_process_group_backend_for_device,
_InfiniteBarrier,
_init_dist_connection,
_is_dtensor,
@@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock):
atexit_mock.register.assert_not_called()
+def test_get_default_process_group_backend_for_device():
+    """Check the default process group backend chosen for CPU, CUDA and a custom device type."""
+    # Register a custom "pcu" device name and a dummy "pccl" backend for it (test-only).
+    torch.utils.rename_privateuse1_backend("pcu")
+
+    def noop_backend(store, group_rank, group_size, timeout):
+        """Placeholder backend creator; does nothing and returns ``None``."""
+
+    torch.distributed.Backend.register_backend(
+        "pccl",
+        lambda store, group_rank, group_size, timeout: noop_backend(store, group_rank, group_size, timeout),
+        devices=["pcu"],
+    )
+
+    # Map every probed device to the backend name the helper is expected to return.
+    expected_backend = {"cpu": "gloo", "cuda:0": "nccl", "pcu:0": "pccl"}
+    for device_name, backend in expected_backend.items():
+        resolved = _get_default_process_group_backend_for_device(torch.device(device_name))
+        assert resolved == backend
+
+
@RunIf(min_torch="2.4")
def test_is_dtensor(monkeypatch):
from torch.distributed._tensor import DTensor
diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py
index 2700213747f9a..81fde5aae3ef5 100644
--- a/tests/tests_fabric/utilities/test_seed.py
+++ b/tests/tests_fabric/utilities/test_seed.py
@@ -72,6 +72,14 @@ def test_seed_everything_accepts_valid_seed_from_env():
assert seed_everything() == 17
+@mock.patch.dict(os.environ, {}, clear=True)
+def test_seed_everything_non_verbose_no_warning():
+    """Ensure that no warning is emitted when verbose is False and no seed is provided."""
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")  # escalate every warning so an inherited "ignore" filter cannot hide one
+        seed_everything(verbose=False)
+
+
def test_reset_seed_no_op():
"""Test that the reset_seed function is a no-op when seed_everything() was not used."""
assert "PL_GLOBAL_SEED" not in os.environ
diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py
index e96a5f77df384..e7f0bedb8e9e9 100644
--- a/tests/tests_fabric/utilities/test_spike.py
+++ b/tests/tests_fabric/utilities/test_spike.py
@@ -30,7 +30,7 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise):
)
-@pytest.mark.flaky(max_runs=3)
+@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
# NOTE FOR ALL FOLLOWING TESTS:
diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
index d93bf1cf60e9c..89d2fa73e3b6b 100644
--- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
+++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py
@@ -18,7 +18,7 @@
from collections import defaultdict
from typing import Union
from unittest import mock
-from unittest.mock import ANY, Mock, PropertyMock, call
+from unittest.mock import ANY, Mock, PropertyMock, call, patch
import pytest
import torch
@@ -801,3 +801,50 @@ def test_tqdm_leave(leave, tmp_path):
)
trainer.fit(model)
assert pbar.init_train_tqdm.call_count == (4 if leave else 1)
+
+
+@patch("lightning.pytorch.callbacks.progress.rich_progress._RICH_AVAILABLE", False)
+def test_tqdm_progress_bar_reset_behavior(tmp_path):
+    """Check that both bars call ``reset()`` without a total and have ``total`` assigned separately afterwards."""
+    model = BoringModel()
+
+    class ResetTrackingTqdm(MockTqdm):
+        """``MockTqdm`` that additionally records the ``total`` argument of every ``reset`` call."""
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.reset_calls_with_params = []
+
+        def reset(self, total=None):
+            self.reset_calls_with_params.append(total)
+            super().reset(total)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        enable_checkpointing=False,
+        logger=False,
+    )
+    pbar = trainer.progress_bar_callback
+
+    # swap the tqdm class used by the progress bar callback for the recording variant while fitting
+    with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm):
+        trainer.fit(model)
+
+    train_bar = pbar.train_progress_bar
+    # a recorded ``None`` means reset() was invoked without an explicit total
+    assert None in train_bar.reset_calls_with_params, (
+        f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}"
+    )
+    assert 2 in train_bar.total_values, (
+        f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}"
+    )
+    # the same pair of checks for the validation bar
+    val_bar = pbar.val_progress_bar
+    assert None in val_bar.reset_calls_with_params, (
+        f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}"
+    )
+    assert 2 in val_bar.total_values, (
+        f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
+    )
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py
new file mode 100644
index 0000000000000..fba9e865debfd
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py
@@ -0,0 +1,208 @@
+import math
+from datetime import timedelta
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):  # toy regression data in which every target equals its input
+    def __init__(self, n: int = 4):
+        values = torch.arange(n, dtype=torch.float32).view(-1, 1)  # column holding 0.0 .. n-1
+        self.x, self.y = values, values.clone()
+
+    def __len__(self):
+        return self.x.shape[0]
+
+    def __getitem__(self, idx):
+        return (self.x[idx], self.y[idx])
+
+
+class TrainMetricModule(LightningModule):
+    """Single linear layer that logs ``train_score`` = number of training steps taken so far."""
+
+    def __init__(self):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._counter = 0.0
+
+    def training_step(self, batch, batch_idx):
+        inputs, targets = batch
+        loss = F.mse_loss(self.layer(inputs), targets)
+        self._counter += 1.0  # bumped once per step, so the logged metric grows strictly
+        self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True)
+        return {"loss": loss}
+
+    def validation_step(self, batch, batch_idx):
+        """No-op: this module logs nothing during validation."""
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+def _make_loaders(n=4):
+    """Return separate (train, val) loaders over one shared ``TinyDataset(n)``; batch size 2, no shuffling."""
+    dataset = TinyDataset(n=n)
+    loader_kwargs = {"batch_size": 2, "shuffle": False}
+    return DataLoader(dataset, **loader_kwargs), DataLoader(dataset, **loader_kwargs)
+
+
+def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path):
+    """A monitored metric logged in ``training_step`` is available at every step boundary, so a per-step checkpoint
+    must be written right away and ``best_model_score`` must end at the last logged value."""
+    seed_everything(123)
+
+    train_dl, val_dl = _make_loaders(n=4)
+    module = TrainMetricModule()
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="train_score",
+        mode="max",
+        every_n_train_steps=1,
+        every_n_epochs=0,
+        train_time_interval=None,
+        save_on_train_epoch_end=False,
+    )
+
+    # two epochs of two batches each -> four training steps, i.e. four step-triggered save opportunities
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=2,
+        limit_train_batches=2,
+        limit_val_batches=0,  # validation is switched off for this scenario
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    best_score = checkpoint_cb.best_model_score
+    assert best_score is not None
+    # the metric counts steps: 2 epochs * 2 steps = 4
+    steps_total = 4.0
+    assert math.isclose(float(best_score), steps_total, rel_tol=0, abs_tol=1e-6)
+
+
+@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]])
+def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores):
+    """Time-interval checkpointing on a metric that is only logged during validation must wait for validation to
+    finish instead of saving with a stale value at a train-step boundary."""
+    seed_everything(123)
+
+    train_dl, val_dl = _make_loaders(n=4)
+
+    module = ValMetricModule(val_scores=val_scores)
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="auroc",
+        mode="max",
+        train_time_interval=timedelta(seconds=0),  # zero interval: due at every opportunity
+        every_n_train_steps=0,  # step-based trigger off
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=len(val_scores),
+        limit_train_batches=2,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    best_score = checkpoint_cb.best_model_score
+    assert best_score is not None
+    highest = max(val_scores)
+    assert math.isclose(float(best_score), highest, rel_tol=0, abs_tol=1e-6)
+
+
+class ValMetricModule(LightningModule):
+    """Linear model whose only logged metric is ``auroc``, taken from ``val_scores[current_epoch]``."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = list(map(float, val_scores))
+
+    def training_step(self, batch, batch_idx):
+        inputs, targets = batch
+        return {"loss": F.mse_loss(self.layer(inputs), targets)}
+
+    def validation_step(self, batch, batch_idx):
+        """No per-batch work; the metric is emitted in ``on_validation_epoch_end``."""
+
+    def on_validation_epoch_end(self):
+        epoch_score = torch.tensor(self._val_scores[self.current_epoch], dtype=torch.float32)
+        self.log("auroc", epoch_score, prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]])
+def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores):
+    """When validation only runs every second epoch, a step-triggered save that falls on an epoch without
+    validation must be postponed and carried out once the next validation has produced the monitored metric."""
+    seed_everything(123)
+
+    train_dl, val_dl = _make_loaders(n=4)
+
+    module = ValMetricModule(val_scores=val_scores)
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="auroc",
+        mode="max",
+        every_n_train_steps=2,  # equals limit_train_batches, so the trigger fires on the last step of each epoch
+        every_n_epochs=0,
+        train_time_interval=None,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=len(val_scores),
+        check_val_every_n_epoch=2,  # validation on every second epoch only
+        limit_train_batches=2,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    best_score = checkpoint_cb.best_model_score
+    assert best_score is not None
+    highest = max(val_scores)  # also the final entry of the parametrized list
+    assert math.isclose(float(best_score), highest, rel_tol=0, abs_tol=1e-6)
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py
new file mode 100644
index 0000000000000..a265f8bc5f194
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py
@@ -0,0 +1,174 @@
+import math
+from datetime import timedelta
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):  # (x, y) pairs with y == x, one pair per integer in [0, n)
+    def __init__(self, n: int = 8):
+        column = torch.arange(n, dtype=torch.float32).view(-1, 1)
+        self.x, self.y = column, column.clone()
+
+    def __len__(self):
+        return self.x.size(0)
+
+    def __getitem__(self, idx):
+        return (self.x[idx], self.y[idx])
+
+
+def _make_loaders(n=8, batch_size=2):
+    """Create un-shuffled train and validation loaders that both read the same ``TinyDataset(n)``."""
+    shared = TinyDataset(n=n)
+    train_dl, val_dl = (DataLoader(shared, batch_size=batch_size, shuffle=False) for _ in range(2))
+    return train_dl, val_dl
+
+
+class MultiValPerEpochModule(LightningModule):
+    """Emits the next entry of ``val_scores`` as ``auroc`` on each validation run, however often validation runs."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = list(map(float, val_scores))
+        self._val_call_idx = 0
+
+    def training_step(self, batch, batch_idx):
+        inputs, targets = batch
+        return {"loss": F.mse_loss(self.layer(inputs), targets)}
+
+    def validation_step(self, batch, batch_idx):
+        """Nothing to do per batch; the metric is logged in ``on_validation_epoch_end``."""
+
+    def on_validation_epoch_end(self):
+        # scores are consumed in call order (one per validation run), not looked up by epoch
+        next_score = self._val_scores[self._val_call_idx]
+        self._val_call_idx += 1
+        metric = torch.tensor(next_score, dtype=torch.float32)
+        self.log("auroc", metric, prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+class ValOnceEveryTwoEpochsModule(LightningModule):
+    """Logs ``auroc`` = ``val_scores[current_epoch]`` whenever a validation run finishes (e.g. every 2 epochs)."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = list(map(float, val_scores))
+
+    def training_step(self, batch, batch_idx):
+        inputs, targets = batch
+        predictions = self.layer(inputs)
+        return {"loss": F.mse_loss(predictions, targets)}
+
+    def validation_step(self, batch, batch_idx):
+        """Nothing to do per batch; the metric is logged in ``on_validation_epoch_end``."""
+
+    def on_validation_epoch_end(self):
+        # this hook only runs when validation runs; the epoch index selects the score to report
+        epoch_idx = self.current_epoch
+        metric = torch.tensor(self._val_scores[epoch_idx], dtype=torch.float32)
+        self.log("auroc", metric, prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.9]])
+def test_checkpoint_defers_with_mid_epoch_validation(tmp_path, val_scores):
+    """With ``val_check_interval=0.5`` validation runs mid-epoch and at epoch end; per-step checkpointing on a
+    validation metric has to wait for each validation to finish so that the monitored value is up to date."""
+    seed_everything(123)
+
+    # n=8 with batch_size=2 -> 4 train batches per epoch -> validation after batch 2 and after batch 4
+    train_dl, val_dl = _make_loaders(n=8, batch_size=2)
+
+    module = MultiValPerEpochModule(val_scores=val_scores)
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="auroc",
+        mode="max",
+        every_n_train_steps=1,  # due on every step, yet each save has to be held back until validation
+        every_n_epochs=0,
+        train_time_interval=None,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=1,
+        val_check_interval=0.5,
+        limit_train_batches=4,  # exactly 4 steps, hence validations at the 0.5 and 1.0 marks
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    best_score = checkpoint_cb.best_model_score
+    assert best_score is not None
+    highest = max(val_scores)
+    assert math.isclose(float(best_score), highest, rel_tol=0, abs_tol=1e-6)
+
+
+@pytest.mark.parametrize("val_scores", [[0.2, 0.6]])
+def test_time_interval_defers_across_epoch_until_first_validation(tmp_path, val_scores):
+    """Time-interval saving combined with validation every 2 epochs: no checkpoint may be written from a stale or
+    missing validation metric, so the first save is expected at the end of the first validation (epoch 2)."""
+    seed_everything(123)
+
+    train_dl, val_dl = _make_loaders(n=4, batch_size=2)
+
+    module = ValOnceEveryTwoEpochsModule(val_scores=val_scores)
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="auroc",
+        mode="max",
+        train_time_interval=timedelta(seconds=0),  # zero interval: due at every opportunity
+        every_n_train_steps=0,  # step-based trigger off
+        every_n_epochs=0,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=2,
+        check_val_every_n_epoch=2,  # validation happens for the first time after the second epoch
+        limit_train_batches=2,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    best_score = checkpoint_cb.best_model_score
+    assert best_score is not None
+    only_logged = val_scores[1]  # the single validation run happens at epoch index 1
+    assert math.isclose(float(best_score), only_logged, rel_tol=0, abs_tol=1e-6)
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py
new file mode 100644
index 0000000000000..c3fa0bfcd2e38
--- /dev/null
+++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py
@@ -0,0 +1,106 @@
+import math
+
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, Dataset
+
+from lightning.pytorch import LightningModule, Trainer, seed_everything
+from lightning.pytorch.callbacks import ModelCheckpoint
+
+
+class TinyDataset(Dataset):  # n samples; sample i is the pair ([i], [i]) as float32 tensors
+    def __init__(self, n: int = 4):
+        base = torch.arange(n, dtype=torch.float32).view(-1, 1)
+        self.x, self.y = base, base.clone()
+
+    def __len__(self):
+        return self.x.shape[0]
+
+    def __getitem__(self, idx):
+        return (self.x[idx], self.y[idx])
+
+
+class ValMetricModule(LightningModule):
+    """Minimal module: one linear layer, MSE training loss, and an ``auroc`` value logged after validation."""
+
+    def __init__(self, val_scores: list[float]):
+        super().__init__()
+        self.layer = nn.Linear(1, 1)
+        self._val_scores = list(map(float, val_scores))
+
+    def training_step(self, batch, batch_idx):
+        inputs, targets = batch
+        predictions = self.layer(inputs)
+        return {"loss": F.mse_loss(predictions, targets)}
+
+    def validation_step(self, batch, batch_idx):
+        """Intentionally empty; the metric is logged once per validation run in ``on_validation_epoch_end``."""
+
+    def on_validation_epoch_end(self):
+        # the metric is logged here only, i.e. once a validation run has finished;
+        # the score to report is picked by the index of the current epoch
+        epoch_idx = self.current_epoch
+        score = torch.tensor(self._val_scores[epoch_idx], dtype=torch.float32)
+        self.log("auroc", score, prog_bar=False, logger=True)
+
+    def configure_optimizers(self):
+        """Plain SGD over all parameters."""
+        return torch.optim.SGD(self.parameters(), lr=0.01)
+
+
+@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0]])
+def test_model_checkpoint_every_n_train_steps_with_val_metric_saves_after_val(tmp_path, val_scores):
+    """Regression test for #20919: with ``every_n_train_steps`` and a metric that only exists after validation, the
+    best checkpoint must be chosen once validation has computed the metric, not at the train-step boundary.
+
+    Expected outcome: ``best_model_score`` is the last (and largest) validation score.
+
+    """
+    seed_everything(123)
+
+    # four samples in batches of two -> two train batches per epoch, so the step trigger lands on the epoch boundary
+    dataset = TinyDataset(n=4)
+    train_dl = DataLoader(dataset, batch_size=2, shuffle=False)
+    val_dl = DataLoader(dataset, batch_size=2, shuffle=False)
+
+    module = ValMetricModule(val_scores=val_scores)
+
+    checkpoint_cb = ModelCheckpoint(
+        dirpath=tmp_path,
+        save_top_k=1,
+        save_weights_only=True,
+        monitor="auroc",
+        mode="max",
+        # the scenario under test: the save is driven by train steps rather than by the end of the epoch
+        every_n_train_steps=2,  # same as the number of train batches per epoch
+        every_n_epochs=0,
+        train_time_interval=None,
+        save_on_train_epoch_end=False,
+    )
+
+    trainer = Trainer(
+        accelerator="cpu",
+        devices=1,
+        max_epochs=len(val_scores),
+        limit_train_batches=2,
+        limit_val_batches=1,
+        num_sanity_val_steps=0,
+        log_every_n_steps=1,
+        callbacks=[checkpoint_cb],
+        enable_checkpointing=True,
+        enable_model_summary=False,
+        logger=False,
+    )
+
+    trainer.fit(module, train_dataloaders=train_dl, val_dataloaders=val_dl)
+
+    assert checkpoint_cb.best_model_score is not None
+    # the parametrized scores rise monotonically, so the maximum is also the last one
+    expected = max(val_scores)
+    actual = float(checkpoint_cb.best_model_score)
+    assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6), (
+        f"best_model_score should be {expected} (last/maximum val score), got {actual}.\n"
+        f"This indicates the checkpoint was saved before the validation metric was computed."
+    )
diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py
index f61a6c59ca9db..86e3ac88e93cf 100644
--- a/tests/tests_pytorch/callbacks/test_spike.py
+++ b/tests/tests_pytorch/callbacks/test_spike.py
@@ -48,7 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
-@pytest.mark.flaky(max_runs=3)
+@pytest.mark.flaky(reruns=3)
@pytest.mark.parametrize(
("global_rank_spike", "num_devices", "spike_value", "finite_only"),
# NOTE FOR ALL FOLLOWING TESTS:
diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py
index e3a4c37f6a284..f5aaa18095fc5 100644
--- a/tests/tests_pytorch/loops/test_training_loop.py
+++ b/tests/tests_pytorch/loops/test_training_loop.py
@@ -13,12 +13,14 @@
# limitations under the License.
import itertools
import logging
+import warnings
from unittest.mock import Mock
import pytest
import torch
from torch.utils.data import DataLoader
+from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.pytorch import Trainer, seed_everything
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.loops import _FitLoop
@@ -277,3 +279,29 @@ def __iter__(self):
# assert progress bar callback uses correct total steps
assert pbar.train_progress_bar.total == max_steps
+
+
+@pytest.mark.parametrize("warn", [True, False])
+def test_eval_mode_warning(tmp_path, warn):
+    """Test that a warning is raised if any module is in eval mode at the start of training."""
+    model = BoringModel()
+    if warn:
+        model.some_eval_module = torch.nn.Linear(32, 16)  # extra submodule, deliberately switched to eval mode
+        model.some_eval_module.eval()
+
+    trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
+
+    if warn:
+        # `match` pins the eval-mode message (same text the negative branch filters on); without it, any
+        # unrelated PossibleUserWarning emitted during fit() would satisfy the check.
+        with pytest.warns(PossibleUserWarning, match="eval mode"):
+            trainer.fit(model)
+    else:
+        # record everything, then look only for the eval-mode warning so unrelated warnings do not fail the test
+        with warnings.catch_warnings(record=True) as warning_list:
+            warnings.simplefilter("always")
+            trainer.fit(model)
+        eval_warnings = [
+            w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message)
+        ]
+        assert len(eval_warnings) == 0, "Expected no eval mode warnings"
diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index cb061c540b2be..3894c4256e0b8 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,6 +14,8 @@
from unittest.mock import Mock
import pytest
+import torch
+from torch import nn
from torch.optim import Optimizer
from lightning.pytorch.plugins import MixedPrecision
@@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():
with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+def test_amp_with_no_grad():
+    """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
+    tracking."""
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+        loss = layer(x).mean()
+        assert loss.grad_fn is not None  # checked before backward(), which would raise first on a graph-less loss
+        loss.backward()
diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py
index 915e57440b40f..823d77d0d5848 100644
--- a/tests/tests_pytorch/strategies/test_ddp.py
+++ b/tests/tests_pytorch/strategies/test_ddp.py
@@ -20,6 +20,7 @@
from torch.nn.parallel import DistributedDataParallel
from lightning.fabric.plugins.environments import LightningEnvironment
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision
@@ -132,8 +133,41 @@ def test_set_timeout(mock_init_process_group):
process_group_backend = trainer.strategy._get_process_group_backend()
global_rank = trainer.strategy.cluster_environment.global_rank()
world_size = trainer.strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None
mock_init_process_group.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
+ )
+
+
+@mock.patch("torch.distributed.init_process_group")
+def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group):
+    """Test that ``device_id=None`` is passed to ``init_process_group`` when DDP runs on CPU (torch >= 2.3 only)."""
+    # NOTE: despite the test name, only the CPU path is exercised here; no CUDA device is involved.
+    model = BoringModel()
+    ddp_strategy = DDPStrategy()
+    trainer = Trainer(
+        max_epochs=1,
+        accelerator="cpu",
+        strategy=ddp_strategy,
+    )
+    trainer.strategy.connect(model)
+    trainer.lightning_module.trainer = trainer
+    trainer.strategy.setup_environment()
+
+    process_group_backend = trainer.strategy._get_process_group_backend()
+    global_rank = trainer.strategy.cluster_environment.global_rank()
+    world_size = trainer.strategy.cluster_environment.world_size()
+    assert trainer.strategy.root_device.type == "cpu"
+    # pin the expectation instead of recomputing it: on torch >= 2.3 a CPU root device must yield device_id=None
+    kwargs = {"device_id": None} if _TORCH_GREATER_EQUAL_2_3 else {}
+    mock_init_process_group.assert_called_with(
+        process_group_backend,
+        rank=global_rank,
+        world_size=world_size,
+        timeout=trainer.strategy._timeout,
+        **kwargs,
)
diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py
index 048403366ebc7..fc3a8cfebbac0 100644
--- a/tests/tests_pytorch/strategies/test_ddp_integration.py
+++ b/tests/tests_pytorch/strategies/test_ddp_integration.py
@@ -66,7 +66,7 @@ def test_multi_gpu_model_ddp_fit_test(tmp_path):
assert out["test_acc"] > 0.7
-@RunIf(skip_windows=True)
+@RunIf(skip_windows=True, max_torch="2.7")
@mock.patch("torch.cuda.set_device")
@mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision")
@mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory")
diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py
index 7e7d2eacd0617..503d1ea0e630b 100644
--- a/tests/tests_pytorch/strategies/test_deepspeed.py
+++ b/tests/tests_pytorch/strategies/test_deepspeed.py
@@ -562,6 +562,46 @@ def test_deepspeed_multigpu_single_file(tmp_path):
trainer.test(model, ckpt_path=checkpoint_path)
+@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
+def test_deepspeed_strategy_exclude_frozen_parameters_integration(tmp_path):
+    """Fit a model containing a frozen layer with ``exclude_frozen_parameters=True`` and save a checkpoint."""
+
+    class TestModelWithFrozenParams(BoringModel):
+        """``BoringModel`` preceded by an extra 32x32 linear layer that is frozen in ``configure_model``."""
+
+        def __init__(self):
+            super().__init__()
+            self.frozen_layer = torch.nn.Linear(32, 32)
+
+        def configure_model(self) -> None:
+            super().configure_model()
+            # turn off gradients for every parameter of the extra layer
+            for frozen_param in self.frozen_layer.parameters():
+                frozen_param.requires_grad = False
+
+        def forward(self, x):
+            return super().forward(self.frozen_layer(x))
+
+    model = TestModelWithFrozenParams()
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        accelerator="gpu",
+        devices=1,
+        precision="16-mixed",
+        strategy=DeepSpeedStrategy(exclude_frozen_parameters=True),
+        fast_dev_run=True,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    trainer.fit(model)
+    checkpoint_path = os.path.join(tmp_path, "checkpoint_exclude_frozen.ckpt")
+    trainer.save_checkpoint(checkpoint_path)
+    # the save must have produced something at the requested location
+    assert os.path.exists(checkpoint_path)
+
+
class ModelParallelClassificationModel(LightningModule):
def __init__(self, lr: float = 0.01, num_blocks: int = 5):
super().__init__()
diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py
index 560ab19f823ca..f7c15b5930be8 100644
--- a/tests/tests_pytorch/strategies/test_fsdp.py
+++ b/tests/tests_pytorch/strategies/test_fsdp.py
@@ -18,7 +18,7 @@
from lightning.fabric.plugins.environments import LightningEnvironment
from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint
-from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3
from lightning.fabric.utilities.load import _load_distributed_checkpoint
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
@@ -532,8 +532,11 @@ def test_set_timeout(init_process_group_mock):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)
diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py
index 86a95944ac20d..c803c10afa4b4 100644
--- a/tests/tests_pytorch/strategies/test_model_parallel.py
+++ b/tests/tests_pytorch/strategies/test_model_parallel.py
@@ -22,6 +22,7 @@
import torch.nn as nn
from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3
from lightning.pytorch import LightningModule
from lightning.pytorch.plugins.environments import LightningEnvironment
from lightning.pytorch.strategies import ModelParallelStrategy
@@ -202,8 +203,11 @@ def test_set_timeout(init_process_group_mock, _):
process_group_backend = strategy._get_process_group_backend()
global_rank = strategy.cluster_environment.global_rank()
world_size = strategy.cluster_environment.world_size()
+ kwargs = {}
+ if _TORCH_GREATER_EQUAL_2_3:
+ kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None
init_process_group_mock.assert_called_with(
- process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
+ process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs
)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 1b883dda0282a..248852f4cf1f3 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -17,7 +17,7 @@
import operator
import os
import sys
-from contextlib import ExitStack, contextmanager, redirect_stdout
+from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout
from io import StringIO
from pathlib import Path
from typing import Callable, Optional, Union
@@ -98,7 +98,7 @@ class _UnkArgError(Exception):
def _raise():
raise _UnkArgError
- parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
+ parser = LightningArgumentParser(add_help=False)
parser.add_lightning_class_args(Trainer, None)
monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True)
@@ -487,6 +487,45 @@ def test_lightning_cli_print_config():
assert outval["ckpt_path"] is None
+class BoringCkptPathModel(BoringModel):  # BoringModel variant that saves its init args as hyperparameters
+    def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None:
+        super().__init__()
+        self.save_hyperparameters()  # the ckpt_path CLI test reads out_dim/hidden_dim back from hparams.yaml
+        self.layer = torch.nn.Linear(in_features=32, out_features=out_dim)  # hidden_dim is stored but unused here
+
+
+def test_lightning_cli_ckpt_path_argument_hparams(cleandir):
+    """Check that ``--ckpt_path`` makes the CLI reuse the hyperparameters saved by an earlier ``fit`` run."""
+
+    class CkptPathCLI(LightningCLI):
+        def add_arguments_to_parser(self, parser):
+            parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2)
+
+    # Stage 1: fit with out_dim=3; the link doubles it into hidden_dim, and both must be written to hparams.yaml.
+    with mock.patch("sys.argv", ["any.py", "fit", "--model.out_dim=3", "--trainer.max_epochs=1"]):
+        fit_cli = CkptPathCLI(BoringCkptPathModel)
+    assert fit_cli.config.fit.model.out_dim == 3
+    assert fit_cli.config.fit.model.hidden_dim == 6
+    hparams_file = Path(fit_cli.trainer.log_dir) / "hparams.yaml"
+    assert hparams_file.is_file()
+    saved_hparams = yaml.safe_load(hparams_file.read_text())
+    assert saved_hparams["out_dim"] == 3
+    assert saved_hparams["hidden_dim"] == 6
+    # Stage 2: predict with only the checkpoint path; parsed config and built model must match the fit run.
+    checkpoint_path = next(Path(fit_cli.trainer.log_dir, "checkpoints").glob("*.ckpt"))
+    predict_argv = ["any.py", "predict", f"--ckpt_path={checkpoint_path}"]
+    with mock.patch("sys.argv", list(predict_argv)):
+        predict_cli = CkptPathCLI(BoringCkptPathModel)
+    assert predict_cli.config.predict.model.out_dim == 3
+    assert predict_cli.config.predict.model.hidden_dim == 6
+    assert predict_cli.config_init.predict.model.layer.out_features == 3
+    # Stage 3: the same arguments with a plain BoringModel CLI must exit and report the failure on stderr.
+    stderr_buf = StringIO()
+    with mock.patch("sys.argv", list(predict_argv)), redirect_stderr(stderr_buf), pytest.raises(SystemExit):
+        LightningCLI(BoringModel)
+    assert "Parsing of ckpt_path hyperparameters failed" in stderr_buf.getvalue()
+
+
def test_lightning_cli_submodules(cleandir):
class MainModule(BoringModel):
def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index 69575a351b0a5..81352ebe256ef 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import glob
import logging
import math
import os
@@ -750,3 +751,52 @@ def __init__(self):
assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
"Gradients should differ significantly in exponential mode when using proper spacing"
)
+
+
+def test_lr_finder_checkpoint_cleanup_on_error(tmp_path):
+    """Ensure the LR finder removes its temporary checkpoint files even when training raises part-way through."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.learning_rate = 1e-3
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+
+        def training_step(self, batch, batch_idx):
+            # Count every call; from call number `fail_on_step` onwards, raise instead of training.
+            self.current_step += 1
+            if self.current_step < self.fail_on_step:
+                return super().training_step(batch, batch_idx)
+            raise RuntimeError("Intentional failure for testing cleanup")
+
+        def configure_optimizers(self):
+            sgd = torch.optim.SGD(self.parameters(), lr=self.learning_rate)
+            return [sgd], [torch.optim.lr_scheduler.StepLR(sgd, step_size=1)]
+
+    # Anything matching this pattern under tmp_path (the trainer's default_root_dir) is a leftover checkpoint.
+    leftover_pattern = os.path.join(tmp_path, ".lr_find_*.ckpt")
+    model = FailingModel()
+    trainer = Trainer(
+        callbacks=[LearningRateFinder(num_training_steps=5)],
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    # Precondition: the directory starts out without any lr_find checkpoint
+    lr_find_checkpoints = glob.glob(leftover_pattern)
+    assert not lr_find_checkpoints, "No lr_find checkpoint files should exist initially"
+
+    # The model's intentional error has to propagate out of fit()
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Postcondition: the failed run left no lr_find checkpoint behind
+    lr_find_checkpoints = glob.glob(leftover_pattern)
+    assert not lr_find_checkpoints, (
+        f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}"
+    )
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index e4ed533c6fa83..f0e5fbe6a3c49 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import glob
import logging
import os
from copy import deepcopy
@@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path):
assert trainer.num_val_batches[0] == len(trainer.val_dataloaders)
assert trainer.num_val_batches[0] != steps_per_trial
+
+
+def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path):
+    """Ensure the batch size finder removes its temporary checkpoint files even when training raises part-way."""
+
+    class FailingModel(BoringModel):
+        def __init__(self, fail_on_step=2):
+            super().__init__()
+            self.batch_size = 2
+            self.fail_on_step = fail_on_step
+            self.current_step = 0
+
+        def training_step(self, batch, batch_idx):
+            self.current_step += 1
+            if self.current_step < self.fail_on_step:
+                return super().training_step(batch, batch_idx)
+            raise RuntimeError("Intentional failure for testing cleanup")
+
+        def train_dataloader(self):
+            return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size)
+
+    leftover_pattern = os.path.join(tmp_path, ".scale_batch_size_*.ckpt")
+    model = FailingModel()
+    trainer = Trainer(
+        callbacks=[BatchSizeFinder(max_trials=3, steps_per_trial=2)],
+        default_root_dir=tmp_path,
+        max_epochs=1,
+        logger=False,
+        enable_checkpointing=False,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+    )
+
+    # Precondition: the directory starts out without any scale_batch_size checkpoint
+    scale_checkpoints = glob.glob(leftover_pattern)
+    assert not scale_checkpoints, "No scale_batch_size checkpoint files should exist initially"
+
+    # The model's intentional error has to propagate out of fit()
+    with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"):
+        trainer.fit(model)
+
+    # Postcondition: the failed run left no scale_batch_size checkpoint behind
+    scale_checkpoints = glob.glob(leftover_pattern)
+    assert not scale_checkpoints, (
+        f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}"
+    )