diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 8cceb5870af8d..6c37281fc04ec 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -57,28 +57,32 @@ jobs: steps: - uses: actions/checkout@v5 - - uses: actions/setup-python@v5 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v6 with: - # Python version here needs to be supported by all PL versions listed in back-compatible-versions.txt. python-version: "3.9" + # TODO: Avoid activating environment like this + # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment + activate-environment: true + enable-cache: true - name: Install PL from source env: PACKAGE_NAME: pytorch FREEZE_REQUIREMENTS: 1 timeout-minutes: 20 - run: pip install . --extra-index-url="${TORCH_URL}" + run: uv pip install . --extra-index-url="${TORCH_URL}" if: inputs.pl_version == '' - name: Install PL version timeout-minutes: 20 - run: pip install "pytorch-lightning==${{ inputs.pl_version }}" --extra-index-url="${TORCH_URL}" + run: uv pip install "pytorch-lightning==${{ inputs.pl_version }}" --extra-index-url="${TORCH_URL}" if: inputs.pl_version != '' - name: Adjust tests -> PL if: ${{ matrix.pkg-name != 'lightning' }} run: | - pip install -q -r .actions/requirements.txt + uv pip install -q -r .actions/requirements.txt python .actions/assistant.py copy_replace_imports --source_dir="./tests" \ --source_import="lightning.fabric,lightning.pytorch" \ --target_import="lightning_fabric,pytorch_lightning" @@ -115,7 +119,7 @@ jobs: # export to env bool if secrets.AWS_REGION is not empty run: echo "WITH_SECRETS=$([ -n '${{ secrets.AWS_REGION }}' ] && echo 1 || echo 0)" >> $GITHUB_ENV - - run: pip install -r requirements/ci.txt + - run: uv pip install -r requirements/ci.txt - name: Upload checkpoints to S3 if: ${{ env.WITH_SECRETS == '1' }} working-directory: ${{ env.LEGACY_FOLDER }} diff --git 
a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index f4c66f425cc71..615633c5311bf 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -38,44 +38,30 @@ jobs: strategy: fail-fast: false matrix: - include: + os: [macOS-14, ubuntu-22.04, windows-2022] + config: # only run PyTorch latest - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" 
} + - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } + - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } + - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } - # "oldest" versions tests, only on minimum Python - - { os: "macOS-14", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } - - { os: "ubuntu-22.04", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } - - { os: "windows-2022", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } + - { pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } + # "fabric" installs the standalone package - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } + - { pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } + # adding recently cut Torch 2.7 - FUTURE - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } + - { pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } + + # "oldest" versions tests, only on minimum Python + - { pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } timeout-minutes: 25 # because of building grpcio on 
Mac env: - PACKAGE_NAME: ${{ matrix.pkg-name }} + PACKAGE_NAME: ${{ matrix.config.pkg-name }} FREEZE_REQUIREMENTS: ${{ ! (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} - PYPI_CACHE_DIR: "_pip-wheels" TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/" TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/" # TODO: Remove this - Enable running MPS tests on this platform @@ -83,68 +69,67 @@ jobs: steps: - uses: actions/checkout@v5 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v6 with: - python-version: ${{ matrix.python-version || '3.9' }} + python-version: ${{ matrix.config.python-version || '3.9' }} + # TODO: Avoid activating environment like this + # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment + activate-environment: true + enable-cache: true + + - name: Basic setup + run: uv pip install -q -r .actions/requirements.txt - - name: basic setup - run: pip install -q -r .actions/requirements.txt + - name: Append Env. vars for Linux + if: ${{ runner.os == 'Linux' }} + run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV + + - name: Append Env. vars for MacOS + if: ${{ runner.os == 'macOS' }} + run: echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV + + - name: Append Env. vars for Windows + if: ${{ runner.os == 'windows' }} + run: | + # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support" + echo "USE_LIBUV=0" >> $GITHUB_ENV - name: Set min. 
dependencies - if: ${{ matrix.requires == 'oldest' }} + if: ${{ matrix.config.requires == 'oldest' }} run: | cd requirements/fabric - pip install -U "lightning-utilities[cli]" + uv pip install -U "lightning-utilities[cli]" python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt', 'test.txt']" - pip install "cython<3.0" wheel - pip install "pyyaml==5.4" --no-build-isolation + uv pip install "cython<3.0" wheel + uv pip install "pyyaml==5.4" --no-build-isolation - name: Adjust PyTorch versions in requirements files - if: ${{ matrix.requires != 'oldest' }} + if: ${{ matrix.config.requires != 'oldest' }} run: | - pip install -q -r requirements/ci.txt + uv pip install -q -r requirements/ci.txt python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py for fpath in `ls requirements/**/*.txt`; do \ - python ./adjust-torch-versions.py $fpath ${{ matrix.pytorch-version }}; \ + python ./adjust-torch-versions.py $fpath ${{ matrix.config.pytorch-version }}; \ done - - name: pip wheels cache - uses: actions/cache/restore@v4 - with: - path: ${{ env.PYPI_CACHE_DIR }} - key: pypi_wheels - - run: | - mkdir -p $PYPI_CACHE_DIR - ls -lh $PYPI_CACHE_DIR - - name: Expand Env. 
variables run: | # Switch PyTorch URL between stable and test/future - python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.config.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope - python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV + python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.config.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage - python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'fabric-'))" >> $GITHUB_ENV - - name: Append Env. vars for MacOS - if: ${{ runner.os == 'macOS' }} - run: | - # trying to avoid "gloo" issue with SIGABRT - echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV - - name: Append Env. 
vars for Windows - if: ${{ runner.os == 'windows' }} - run: | - # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support" - echo "USE_LIBUV=0" >> $GITHUB_ENV + python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.config.pkg-name}}' != 'lightning' else 'fabric-'))" >> $GITHUB_ENV - name: Install package & dependencies timeout-minutes: 20 run: | - pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ - -U --upgrade-strategy=eager --prefer-binary \ - --extra-index-url="${TORCH_URL}" \ - --find-links="${PYPI_CACHE_DIR}" - pip list + uv pip install ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ + --upgrade \ + --find-links="${TORCH_URL}" + uv pip list + - name: Dump handy wheels if: github.event_name == 'push' && github.ref == 'refs/heads/master' continue-on-error: true @@ -155,7 +140,7 @@ jobs: cache-key: "pypi_wheels" - name: Adjust tests - if: ${{ matrix.pkg-name != 'lightning' }} + if: ${{ matrix.config.pkg-name != 'lightning' }} run: | python .actions/assistant.py copy_replace_imports --source_dir="./tests" \ --source_import="lightning.fabric" --target_import="lightning_fabric" @@ -188,10 +173,13 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} file: tests/tests_fabric/coverage.xml - flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest,python${{ matrix.python-version }} + flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest,python${{ matrix.config.python-version }} name: CPU-coverage fail_ci_if_error: false + - name: Minimize uv cache + run: uv cache prune --ci + fabric-cpu-guardian: runs-on: ubuntu-latest needs: fabric-cpu diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index e5e9ddd23c06e..41d895f14ae86 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -42,115 +42,106 @@ jobs: strategy: fail-fast: false matrix: - include: + os: [macOS-14, ubuntu-22.04, windows-2022] + config: # only run PyTorch 
latest - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.2.2" } + - { pkg-name: "lightning", python-version: "3.11", pytorch-version: "2.3" } + - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.4.1" } + - { pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } + # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - - { os: 
"macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } - # "oldest" versions tests, only on minimum Python - - { os: "macOS-14", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } - - { os: "windows-2022", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } + - { pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } + # "pytorch" installs the standalone package - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } + - { pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } + # adding recently cut Torch 2.7 - FUTURE - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } + - { pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } + + # "oldest" versions tests, only on minimum Python + - { pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } timeout-minutes: 50 env: - PACKAGE_NAME: ${{ matrix.pkg-name }} + PACKAGE_NAME: ${{ matrix.config.pkg-name }} TORCH_URL: "https://download.pytorch.org/whl/cpu/" TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/" TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/" FREEZE_REQUIREMENTS: ${{ ! 
(github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} - PYPI_CACHE_DIR: "_pip-wheels" # TODO: Remove this - Enable running MPS tests on this platform DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }} steps: - uses: actions/checkout@v5 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 + - name: Install uv and set Python version + uses: astral-sh/setup-uv@v6 with: - python-version: ${{ matrix.python-version || '3.9' }} + python-version: ${{ matrix.config.python-version || '3.9' }} + # TODO: Avoid activating environment like this + # see: https://github.com/astral-sh/setup-uv/tree/v6/?tab=readme-ov-file#activate-environment + activate-environment: true + enable-cache: true + + - name: Basic setup + run: uv pip install -q -r .actions/requirements.txt - - name: basic setup - run: pip install -q -r .actions/requirements.txt + - name: Append Env. vars for Linux + if: ${{ runner.os == 'Linux' }} + run: echo "GLOO_SOCKET_IFNAME=eth0" >> $GITHUB_ENV + + - name: Append Env. vars for MacOS + if: ${{ runner.os == 'macOS' }} + run: echo "GLOO_SOCKET_IFNAME=lo0" >> $GITHUB_ENV - name: Set min. 
dependencies - if: ${{ matrix.requires == 'oldest' }} + if: ${{ matrix.config.requires == 'oldest' }} run: | cd requirements/pytorch - pip install -U "lightning-utilities[cli]" + uv pip install -U "lightning-utilities[cli]" python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt', 'test.txt']" - pip install "cython<3.0" wheel - pip install "pyyaml==5.4" --no-build-isolation + uv pip install "cython<3.0" wheel + uv pip install "pyyaml==5.4" --no-build-isolation - name: Adjust PyTorch versions in requirements files - if: ${{ matrix.requires != 'oldest' }} + if: ${{ matrix.config.requires != 'oldest' }} run: | - pip install -q -r requirements/ci.txt + uv pip install -q -r requirements/ci.txt python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py for fpath in `ls requirements/**/*.txt`; do \ - python ./adjust-torch-versions.py $fpath ${{ matrix.pytorch-version }}; \ + python ./adjust-torch-versions.py $fpath ${{ matrix.config.pytorch-version }}; \ done cat requirements/pytorch/base.txt - - name: pip wheels cache - uses: actions/cache/restore@v4 - with: - path: ${{ env.PYPI_CACHE_DIR }} - key: pypi_wheels - - run: | - mkdir -p $PYPI_CACHE_DIR - ls -lh $PYPI_CACHE_DIR - - name: Env. 
variables run: | # Switch PyTorch URL between stable and test/future - python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.config.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope - python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV + python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.config.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage - python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV + python -c "print('EXTRA_PREFIX=' + str('' if '${{matrix.config.pkg-name}}' != 'lightning' else 'pytorch-'))" >> $GITHUB_ENV # Avoid issue on Windows with PyTorch 2.4: "RuntimeError: use_libuv was requested but PyTorch was build without libuv support" - python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV + python -c "print('USE_LIBUV=0' if '${{matrix.os}}' == 'windows-2022' and '${{matrix.config.pytorch-version}}' == '2.4' else '')" >> $GITHUB_ENV - name: Install package & dependencies timeout-minutes: 20 run: | - pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ - -U --upgrade-strategy=eager --prefer-binary \ + uv pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ + --upgrade \ -r requirements/_integrations/accelerators.txt \ - --extra-index-url="${TORCH_URL}" \ - --find-links="${PYPI_CACHE_DIR}" - pip list + --find-links="${TORCH_URL}" + uv pip list + - name: Drop LAI from extensions - if: ${{ matrix.pkg-name != 'lightning' }} + if: 
${{ matrix.config.pkg-name != 'lightning' }} # Lightning is dependency of Habana or other accelerators/integrations so in case we test PL we need to remove it - run: pip uninstall -y lightning + run: uv pip uninstall lightning + - name: Drop PL for LAI - if: ${{ matrix.pkg-name == 'lightning' }} - run: pip uninstall -y pytorch-lightning + if: ${{ matrix.config.pkg-name == 'lightning' }} + run: uv pip uninstall pytorch-lightning + - name: Dump handy wheels if: github.event_name == 'push' && github.ref == 'refs/heads/master' continue-on-error: true @@ -170,10 +161,10 @@ jobs: run: | set -e python requirements/pytorch/check-avail-extras.py - python -c "from torch import __version__ as ver; assert ver.startswith('${{ matrix.pytorch-version }}'), ver" + python -c "from torch import __version__ as ver; assert ver.startswith('${{ matrix.config.pytorch-version }}'), ver" - name: Adjust tests / env. -> PL - if: ${{ matrix.pkg-name != 'lightning' }} + if: ${{ matrix.config.pkg-name != 'lightning' }} run: | python .actions/assistant.py copy_replace_imports --source_dir="./tests" \ --source_import="lightning.fabric,lightning.pytorch" \ @@ -223,10 +214,13 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} file: tests/tests_pytorch/coverage.xml - flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }} + flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest-full,python${{ matrix.config.python-version }},pytorch${{ matrix.config.pytorch-version }} name: CPU-coverage fail_ci_if_error: false + - name: Minimize uv cache + run: uv cache prune --ci + pl-cpu-guardian: runs-on: ubuntu-latest needs: pl-cpu diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 69d35e605db5b..5fb961798ad22 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -165,12 +165,12 @@ jobs: path: docs/build/html/ - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v2 + uses: 
google-github-actions/auth@v3 with: credentials_json: ${{ secrets.GCS_SA_KEY }} - name: Setup gcloud - uses: google-github-actions/setup-gcloud@v2 + uses: google-github-actions/setup-gcloud@v3 with: project_id: ${{ secrets.GCS_PROJECT }} diff --git a/.lightning/workflows/fabric.yml b/.lightning/workflows/fabric.yml index 438f56ef7fe94..e6d97ca1b1887 100644 --- a/.lightning/workflows/fabric.yml +++ b/.lightning/workflows/fabric.yml @@ -4,20 +4,22 @@ trigger: pull_request: branches: ["master"] -timeout: "75" # minutes -machine: "L4_X_2" +timeout: "55" # minutes parametrize: matrix: {} include: - # note that this is setting also all oldest requirements which is linked to Torch == 2.0 + # note that this is setting also all oldest requirements which is linked to Torch == 2.1 - image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1" PACKAGE_NAME: "fabric" - - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" + machine: "A100_X_2" + - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8" PACKAGE_NAME: "fabric" + machine: "L4_X_2" # - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" # PACKAGE_NAME: "fabric" - - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" + - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8" PACKAGE_NAME: "lightning" + machine: "L4_X_2" exclude: [] env: @@ -30,6 +32,7 @@ run: | python --version pip --version pip install -q fire wget packaging + pip list set -ex CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda" @@ -40,12 +43,15 @@ run: | echo "Torch URL: ${TORCH_URL}" COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))') echo "collecting coverage for: ${COVERAGE_SOURCE}" + TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])") if [ "${TORCH_VER}" == "2.1" ]; then echo "Set oldest versions" - cd 
requirements/fabric + pip uninstall -y deepspeed pip install -U "lightning-utilities[cli]" + cd requirements/fabric python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'strategies.txt']" + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt cd ../.. pip install "cython<3.0" wheel # for compatibility fi @@ -92,6 +98,7 @@ run: | export PL_RUN_STANDALONE_TESTS=1 wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh bash ./run_standalone_tests.sh "tests_fabric" + export PL_RUN_STANDALONE_TESTS=0 # echo "Reporting coverage" # todo # python -m coverage report diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml index 5c92bf881d969..21551565b1fed 100644 --- a/.lightning/workflows/pytorch.yml +++ b/.lightning/workflows/pytorch.yml @@ -4,20 +4,22 @@ trigger: pull_request: branches: ["master"] -timeout: "75" # minutes -machine: "L4_X_2" +timeout: "55" # minutes parametrize: matrix: {} include: - # note that this is setting also all oldest requirements which is linked to Torch == 2.0 + # note that this is setting also all oldest requirements which is linked to Torch == 2.1 - image: "pytorchlightning/pytorch_lightning:base-cuda12.1.1-py3.10-torch2.1" PACKAGE_NAME: "pytorch" - - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" + machine: "A100_X_2" + - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8" PACKAGE_NAME: "pytorch" + machine: "L4_X_2" # - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" # PACKAGE_NAME: "pytorch" - - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.7" + - image: "pytorchlightning/pytorch_lightning:base-cuda12.6.3-py3.12-torch2.8" PACKAGE_NAME: "lightning" + machine: "L4_X_2" exclude: [] env: @@ -30,6 +32,7 @@ run: | python --version pip --version pip install -q fire wget packaging + pip list 
set -ex CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda" @@ -40,12 +43,15 @@ run: | echo "Torch URL: ${TORCH_URL}" COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="pytorch_lightning").get(n, n))') echo "collecting coverage for: ${COVERAGE_SOURCE}" + TORCH_VER=$(python -c "import torch; print(torch.__version__.rsplit('.', 1)[0])") if [ "${TORCH_VER}" == "2.1" ]; then - recho "Set oldest versions" - cd requirements/pytorch + echo "Set oldest versions" + pip uninstall -y deepspeed pip install -U "lightning-utilities[cli]" + cd requirements/pytorch python -m lightning_utilities.cli requirements set-oldest --req_files "['base.txt', 'extra.txt', 'strategies.txt', 'examples.txt']" + python -m lightning_utilities.cli requirements prune-pkgs --packages deepspeed --req_files strategies.txt cd ../.. pip install "cython<3.0" wheel # for compatibility fi @@ -108,6 +114,7 @@ run: | export PL_RUN_STANDALONE_TESTS=1 wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh bash ./run_standalone_tests.sh "tests_pytorch" + export PL_RUN_STANDALONE_TESTS=0 echo "Testing: PyTorch standalone tasks" cd tests_pytorch/ diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 2b6f48771c7f7..41faf0ca55113 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -13,7 +13,7 @@ # limitations under the License. ARG UBUNTU_VERSION=22.04 -ARG CUDA_VERSION=11.7.1 +ARG CUDA_VERSION=12.1.1 FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} diff --git a/docs/source-fabric/guide/index.rst b/docs/source-fabric/guide/index.rst index 669501ec09864..6444a6bbdd17a 100644 --- a/docs/source-fabric/guide/index.rst +++ b/docs/source-fabric/guide/index.rst @@ -78,7 +78,7 @@ Build your own Trainer
.. displayitem:: - :header: Organize your model code with with LightningModule + :header: Organize your model code with LightningModule :description: Organize your code in a LightningModule and use it with Fabric :button_link: lightning_module.html :col_css: col-md-4 diff --git a/docs/source-fabric/levels/intermediate.rst b/docs/source-fabric/levels/intermediate.rst index f21e7c96608ab..30f27bdd1d544 100644 --- a/docs/source-fabric/levels/intermediate.rst +++ b/docs/source-fabric/levels/intermediate.rst @@ -19,7 +19,7 @@ Intermediate skills
.. displayitem:: - :header: Organize your model code with with LightningModule + :header: Organize your model code with LightningModule :description: Organize your code in a LightningModule and use it with Fabric :button_link: ../guide/lightning_module.html :col_css: col-md-4 diff --git a/docs/source-pytorch/accelerators/gpu_faq.rst b/docs/source-pytorch/accelerators/gpu_faq.rst index 4cc05555bb559..346843e316bce 100644 --- a/docs/source-pytorch/accelerators/gpu_faq.rst +++ b/docs/source-pytorch/accelerators/gpu_faq.rst @@ -5,31 +5,71 @@ GPU training (FAQ) ================== -****************************************************************** -How should I adjust the learning rate when using multiple devices? -****************************************************************** +*************************************************************** +How should I adjust the batch size when using multiple devices? +*************************************************************** -When using distributed training make sure to modify your learning rate according to your effective -batch size. +Lightning automatically shards your data across multiple GPUs, meaning that each device only sees a unique subset of your +data, but the `batch_size` in your DataLoader remains the same. This means that the effective batch size e.g. the +total number of samples processed in one forward/backward pass is -Let's say you have a batch size of 7 in your dataloader. +.. math:: -.. testcode:: + \text{Effective Batch Size} = \text{DataLoader Batch Size} \times \text{Number of Devices} \times \text{Number of Nodes} - class LitModel(LightningModule): - def train_dataloader(self): - return Dataset(..., batch_size=7) - -Whenever you use multiple devices and/or nodes, your effective batch size will be 7 * devices * num_nodes. +A couple of examples to illustrate this: .. 
code-block:: python - # effective batch size = 7 * 8 + dataloader = DataLoader(..., batch_size=7) + + # Single GPU: effective batch size = 7 + Trainer(accelerator="gpu", devices=1) + + # Multi-GPU: effective batch size = 7 * 8 = 56 Trainer(accelerator="gpu", devices=8, strategy=...) - # effective batch size = 7 * 8 * 10 + # Multi-node: effective batch size = 7 * 8 * 10 = 560 Trainer(accelerator="gpu", devices=8, num_nodes=10, strategy=...) +In general you should be able to use the same `batch_size` in your DataLoader regardless of the number of devices you are +using. + +.. note:: + + If you want distributed training to work exactly the same as single GPU training, you need to set the `batch_size` + in your DataLoader to `original_batch_size / num_devices` to maintain the same effective batch size. However, this + can lead to poor GPU utilization. + +---- + +****************************************************************** +How should I adjust the learning rate when using multiple devices? +****************************************************************** + +Because the effective batch size is larger when using multiple devices, you need to adjust your learning rate +accordingly. Because the learning rate is a hyperparameter that controls how much to change the model in response to +the estimated error each time the model weights are updated, it is important to scale it with the effective batch size. + +In general, there are two common scaling rules: + +1. **Linear scaling**: Increase the learning rate linearly with the number of devices. + + .. code-block:: python + + # Example: Linear scaling + base_lr = 1e-3 + num_devices = 8 + scaled_lr = base_lr * num_devices # 8e-3 + +2. **Square root scaling**: Increase the learning rate by the square root of the number of devices. + + .. code-block:: python + + # Example: Square root scaling + base_lr = 1e-3 + num_devices = 8 + scaled_lr = base_lr * (num_devices ** 0.5) # 2.83e-3 .. 
note:: Huge batch sizes are actually really bad for convergence. Check out: `Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour `_ diff --git a/docs/source-pytorch/common/hooks.rst b/docs/source-pytorch/common/hooks.rst new file mode 100644 index 0000000000000..94531ab818d98 --- /dev/null +++ b/docs/source-pytorch/common/hooks.rst @@ -0,0 +1,325 @@ +########################## +Hooks in PyTorch Lightning +########################## + +Hooks in PyTorch Lightning allow you to customize the training, validation, and testing logic of your models. They +provide a way to insert custom behavior at specific points during the training process without modifying the core +training loop. There are several categories of hooks available in PyTorch Lightning: + +1. **Setup/Teardown Hooks**: Called at the beginning and end of training phases +2. **Training Hooks**: Called during the training loop +3. **Validation Hooks**: Called during validation +4. **Test Hooks**: Called during testing +5. **Prediction Hooks**: Called during prediction +6. **Optimizer Hooks**: Called around optimizer operations +7. **Checkpoint Hooks**: Called during checkpoint save/load operations +8. **Exception Hooks**: Called when exceptions occur + +Nearly all hooks can be implemented in three places within your code: + +- **LightningModule**: The main module where you define your model and training logic. +- **Callbacks**: Custom classes that can be passed to the Trainer to handle specific events. +- **Strategy**: Custom strategies for distributed training. + +Importantly, because logic can be placed in the same hook but in different places, the call order of hooks is +important to understand. The following order is always used: + +1. Callbacks, called in the order they are passed to the Trainer. +2. ``LightningModule`` +3. Strategy + +..
testcode:: + + from lightning.pytorch import Trainer + from lightning.pytorch.callbacks import Callback + from lightning.pytorch.demos import BoringModel + + class MyModel(BoringModel): + def on_train_start(self): + print("Model: Training is starting!") + + class MyCallback(Callback): + def on_train_start(self, trainer, pl_module): + print("Callback: Training is starting!") + + model = MyModel() + callback = MyCallback() + trainer = Trainer(callbacks=[callback], logger=False, max_epochs=1, enable_progress_bar=False) + trainer.fit(model) + +.. testoutput:: + :hide: + :options: +ELLIPSIS, +NORMALIZE_WHITESPACE + + Callback: Training is starting! + Model: Training is starting! + + +.. note:: + There are a few exceptions to this pattern: + + - **on_train_epoch_end**: Non-monitoring callbacks are called first, then ``LightningModule``, then monitoring callbacks + - **Optimizer hooks** (on_before_backward, on_after_backward, on_before_optimizer_step): Only callbacks and ``LightningModule`` are called + - Some internal hooks may only call ``LightningModule`` or Strategy + +************************ +Training Loop Hook Order +************************ + +The following diagram shows the execution order of hooks during a typical training loop e.g. calling `trainer.fit()`, +with the source of each hook indicated: + +.. 
code-block:: text + + Training Process Flow: + + trainer.fit() + │ + ├── setup(stage="fit") + │ ├── [LightningDataModule] + │ ├── [Callbacks] + │ ├── [LightningModule] + │ ├── [LightningModule.configure_shared_model()] + │ ├── [LightningModule.configure_model()] + │ ├── Strategy.restore_checkpoint_before_setup + │ │ ├── [LightningModule.on_load_checkpoint()] + │ │ ├── [LightningModule.load_state_dict()] + │ │ ├── [LightningDataModule.load_state_dict()] + │ │ ├── [Callbacks.on_load_checkpoint()] + │ │ └── [Callbacks.load_state_dict()] + │ └── [Strategy] + │ + ├── on_fit_start() + │ ├── [Callbacks] + │ └── [LightningModule] + │ + ├── Strategy.restore_checkpoint_after_setup + │ ├── [LightningModule.on_load_checkpoint()] + │ ├── [LightningModule.load_state_dict()] + │ ├── [LightningDataModule.load_state_dict()] + │ ├── [Callbacks.on_load_checkpoint()] + │ └── [Callbacks.load_state_dict()] + │ + ├── on_sanity_check_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ ├── on_validation_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ ├── on_validation_epoch_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ ├── [for each validation batch] + │ │ │ ├── on_validation_batch_start() + │ │ │ │ ├── [Callbacks] + │ │ │ │ ├── [LightningModule] + │ │ │ │ └── [Strategy] + │ │ │ └── on_validation_batch_end() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ └── [end validation batches] + │ ├── on_validation_epoch_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ └── on_validation_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + ├── on_sanity_check_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_train_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Training Epochs Loop] + │ │ + │ ├── on_train_epoch_start() + │ │ ├── [Callbacks] + │ │ 
└── [LightningModule] + │ │ + │ ├── [Training Batches Loop] + │ │ │ + │ │ ├── on_train_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ + │ │ ├── [Forward Pass - training_step()] + │ │ │ └── [Strategy only] + │ │ │ + │ │ ├── on_before_zero_grad() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── optimizer_zero_grad() + │ │ │ └── [LightningModule only - optimizer_zero_grad()] + │ │ │ + │ │ ├── [Backward Pass - Strategy.backward()] + │ │ │ ├── on_before_backward() + │ │ │ │ ├── [Callbacks] + │ │ │ │ └── [LightningModule] + │ │ │ ├── LightningModule.backward() + │ │ │ └── on_after_backward() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── on_before_optimizer_step() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ ├── [Optimizer Step] + │ │ │ └── [LightningModule only - optimizer_step()] + │ │ │ + │ │ └── on_train_batch_end() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ │ [Optional: Validation during training] + │ │ ├── on_validation_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ ├── on_validation_epoch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ ├── [for each validation batch] + │ │ │ │ ├── on_validation_batch_start() + │ │ │ │ │ ├── [Callbacks] + │ │ │ │ │ ├── [LightningModule] + │ │ │ │ │ └── [Strategy] + │ │ │ │ └── on_validation_batch_end() + │ │ │ │ ├── [Callbacks] + │ │ │ │ ├── [LightningModule] + │ │ │ │ └── [Strategy] + │ │ │ └── [end validation batches] + │ │ ├── on_validation_epoch_end() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ └── on_validation_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ └── on_train_epoch_end() **SPECIAL CASE** + │ ├── [Callbacks - Non-monitoring only] + │ ├── [LightningModule] + │ └── [Callbacks - Monitoring only] + │ + ├── [End 
Training Epochs] + │ + ├── on_train_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + └── teardown(stage="fit") + ├── [Strategy] + ├── on_fit_end() + │ ├── [Callbacks] + │ └── [LightningModule] + ├── [LightningDataModule] + ├── [Callbacks] + └── [LightningModule] + +*********************** +Testing Loop Hook Order +*********************** + +When running tests with ``trainer.test()``: + +.. code-block:: text + + trainer.test() + │ + ├── setup(stage="test") + │ └── [Callbacks only] + ├── on_test_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Test Epochs Loop] + │ │ + │ ├── on_test_epoch_start() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ ├── [Test Batches Loop] + │ │ │ + │ │ ├── on_test_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ ├── [LightningModule] + │ │ │ └── [Strategy] + │ │ │ + │ │ └── on_test_batch_end() + │ │ ├── [Callbacks] + │ │ ├── [LightningModule] + │ │ └── [Strategy] + │ │ + │ └── on_test_epoch_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── on_test_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + └── teardown(stage="test") + └── [Callbacks only] + +************************** +Prediction Loop Hook Order +************************** + +When running predictions with ``trainer.predict()``: + +.. 
code-block:: text + + trainer.predict() + │ + ├── setup(stage="predict") + │ └── [Callbacks only] + ├── on_predict_start() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + │ + ├── [Prediction Epochs Loop] + │ │ + │ ├── on_predict_epoch_start() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ ├── [Prediction Batches Loop] + │ │ │ + │ │ ├── on_predict_batch_start() + │ │ │ ├── [Callbacks] + │ │ │ └── [LightningModule] + │ │ │ + │ │ └── on_predict_batch_end() + │ │ ├── [Callbacks] + │ │ └── [LightningModule] + │ │ + │ └── on_predict_epoch_end() + │ ├── [Callbacks] + │ └── [LightningModule] + │ + ├── on_predict_end() + │ ├── [Callbacks] + │ ├── [LightningModule] + │ └── [Strategy] + └── teardown(stage="predict") + └── [Callbacks only] diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index 86ee52f41f0c9..a81934c104b81 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -510,6 +510,7 @@ limit_train_batches How much of training dataset to check. Useful when debugging or testing something that happens at the end of an epoch. +Value is per device. .. testcode:: @@ -535,7 +536,7 @@ limit_test_batches :width: 400 :muted: -How much of test dataset to check. +How much of test dataset to check. Value is per device. .. testcode:: @@ -560,6 +561,7 @@ limit_val_batches How much of validation dataset to check. Useful when debugging or testing something that happens at the end of an epoch. +Value is per device. .. testcode:: diff --git a/docs/source-pytorch/expertise_levels.rst b/docs/source-pytorch/expertise_levels.rst index 7c34123f5b3f7..672ccf48ab7eb 100644 --- a/docs/source-pytorch/expertise_levels.rst +++ b/docs/source-pytorch/expertise_levels.rst @@ -84,7 +84,7 @@ Learn to scale up your models and enable collaborative model development at acad .. Add callout items below this line .. 
displayitem:: - :header: Level 7: Interactive cloud development + :header: Level 7: Hardware acceleration :description: Learn how to access GPUs and TPUs on the cloud. :button_link: levels/intermediate_level_7.html :col_css: col-md-6 diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 45683c67c1708..c904932737b02 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -20,6 +20,7 @@ FSDP <../advanced/model_parallel/fsdp> GPU <../accelerators/gpu> Half precision <../common/precision> + Hooks <../common/hooks> HPU <../integrations/hpu/index> Inference <../deploy/production_intermediate> Lightning CLI <../cli/lightning_cli> @@ -179,6 +180,13 @@ Glossary :button_link: ../common/precision.html :height: 100 +.. displayitem:: + :header: Hooks + :description: How to customize the training, validation, and testing logic + :col_css: col-md-12 + :button_link: ../common/hooks.html + :height: 100 + .. displayitem:: :header: HPU :description: Habana Gaudi AI Processor Unit for faster training diff --git a/docs/source-pytorch/levels/intermediate.rst b/docs/source-pytorch/levels/intermediate.rst index 282ac7bc98c90..797cfbdf0c2e2 100644 --- a/docs/source-pytorch/levels/intermediate.rst +++ b/docs/source-pytorch/levels/intermediate.rst @@ -16,7 +16,7 @@ Learn to scale up your models and enable collaborative model development at acad .. Add callout items below this line .. displayitem:: - :header: Level 7: Interactive cloud development + :header: Level 7: Hardware acceleration :description: Learn how to access GPUs and TPUs on the cloud. 
:button_link: intermediate_level_7.html :col_css: col-md-6 diff --git a/docs/source-pytorch/levels/intermediate_level_7.rst b/docs/source-pytorch/levels/intermediate_level_7.rst index ef4122d0d3150..660fd9962c8d2 100644 --- a/docs/source-pytorch/levels/intermediate_level_7.rst +++ b/docs/source-pytorch/levels/intermediate_level_7.rst @@ -1,8 +1,8 @@ :orphan: -###################################### -Level 7: Interactive cloud development -###################################### +############################## +Level 7: Hardware acceleration +############################## Learn to develop models on cloud GPUs and TPUs. diff --git a/docs/source-pytorch/tuning/profiler_basic.rst b/docs/source-pytorch/tuning/profiler_basic.rst index 880381268fa78..01f4d8a51daaf 100644 --- a/docs/source-pytorch/tuning/profiler_basic.rst +++ b/docs/source-pytorch/tuning/profiler_basic.rst @@ -121,3 +121,22 @@ This can be measured with the :class:`~lightning.pytorch.callbacks.device_stats_ CPU metrics will be tracked by default on the CPU accelerator. To enable it for other accelerators set ``DeviceStatsMonitor(cpu_stats=True)``. To disable logging CPU metrics, you can specify ``DeviceStatsMonitor(cpu_stats=False)``. + +.. warning:: + + **Do not wrap** ``Trainer.fit()``, ``Trainer.validate()``, or other Trainer methods inside a manual + ``torch.profiler.profile`` context manager. This will cause unexpected crashes and cryptic errors due to + incompatibility between PyTorch Profiler's context management and Lightning's internal training loop. + Instead, always use the ``profiler`` argument in the ``Trainer`` constructor or the + :class:`~lightning.pytorch.profilers.pytorch.PyTorchProfiler` profiler class if you want to customize the profiling. + + Example: + + .. 
code-block:: python + + from lightning.pytorch import Trainer + from lightning.pytorch.profilers import PyTorchProfiler + + trainer = Trainer(profiler="pytorch") + # or + trainer = Trainer(profiler=PyTorchProfiler(dirpath=".", filename="perf_logs")) diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst index 4a04bd1534de9..296f8052e863d 100644 --- a/docs/source-pytorch/versioning.rst +++ b/docs/source-pytorch/versioning.rst @@ -67,6 +67,13 @@ PyTorch Lightning follows `NEP 29 =2.1.0, <2.9.0 fsspec[http] >=2022.5.0, <2025.8.0 packaging >=20.0, <=25.0 -typing-extensions >4.5.0, <4.15.0 +typing-extensions >4.5.0, <4.16.0 lightning-utilities >=0.10.0, <0.16.0 diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt index a9b4271cac2c3..222e9cf412b3f 100644 --- a/requirements/fabric/test.txt +++ b/requirements/fabric/test.txt @@ -1,9 +1,9 @@ -coverage ==7.10.5 +coverage ==7.10.6 numpy >=1.21.0, <1.27.0 pytest ==8.4.1 pytest-cov ==6.2.1 pytest-timeout ==2.4.0 -pytest-rerunfailures ==15.1 +pytest-rerunfailures ==16.0 pytest-random-order ==1.2.0 click ==8.1.8; python_version < "3.11" click ==8.2.1; python_version > "3.10" diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index ef798883c12ef..6e684d995c1b7 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >5.4, <6.1.0 fsspec[http] >=2022.5.0, <2025.8.0 torchmetrics >0.7.0, <1.9.0 packaging >=20.0, <=25.0 -typing-extensions >4.5.0, <4.15.0 +typing-extensions >4.5.0, <4.16.0 lightning-utilities >=0.10.0, <0.16.0 diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt index a3e2e88967f75..7ee6c8bb309cb 100644 --- a/requirements/pytorch/docs.txt +++ b/requirements/pytorch/docs.txt @@ -1,7 +1,7 @@ -r ../docs.txt nbformat # used for generate empty notebook -ipython[notebook] <9.5.0 +ipython[notebook] <9.6.0 setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is
invalid.` onnxscript >= 0.2.2, < 0.5.0 diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index 84ea80df6ff0c..c34309a0234f2 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -3,5 +3,5 @@ requests <2.33.0 torchvision >=0.16.0, <0.24.0 -ipython[all] <8.19.0 +ipython[all] >=8.0.0, <9.0.0 torchmetrics >=0.10.0, <1.9.0 diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 1fd3ec790055f..86b765c37237f 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -1,15 +1,15 @@ -coverage ==7.10.5 +coverage ==7.10.6 pytest ==8.4.1 pytest-cov ==6.2.1 pytest-timeout ==2.4.0 -pytest-rerunfailures ==15.1 +pytest-rerunfailures ==16.0 pytest-random-order ==1.2.0 # needed in tests cloudpickle >=1.3, <3.2.0 scikit-learn >0.22.1, <1.8.0 numpy >1.20.0, <1.27.0 -onnx >1.12.0, <1.19.0 +onnx >1.12.0, <1.20.0 onnxruntime >=1.12.0, <1.23.0 onnxscript >= 0.1.0, < 0.5.0 psutil <7.0.1 # for `DeviceStatsMonitor` diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index b1102cdce06b7..ae9ab0fae131f 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -6,11 +6,25 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
--- +## [2.5.5] - 2025-09-05 + +### Changed + +- Include `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060)) +- Let `_get_default_process_group_backend_for_device` support more hardware platforms ( + [#21057](https://github.com/Lightning-AI/pytorch-lightning/pull/21057), [#21093](https://github.com/Lightning-AI/pytorch-lightning/pull/21093)) + +### Fixed + +- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105)) +- Respecting `verbose=False` in `seed_everything` when no seed is provided ([#21161](https://github.com/Lightning-AI/pytorch-lightning/pull/21161)) + + ## [2.5.4] - 2025-08-29 ### Changed -- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/21119)) +- Added support for NVIDIA H200 GPUs in `get_available_flops` ([#21119](https://github.com/Lightning-AI/pytorch-lightning/pull/21119)) ## [2.5.3] - 2025-08-13 @@ -45,7 +59,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Removed -- Removed legacy support for `lightning run model`. 
Use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588)) +- Removed legacy support for `lightning run model`; use `fabric run` instead ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588)) ## [2.5.0] - 2024-12-19 diff --git a/src/lightning/fabric/strategies/ddp.py b/src/lightning/fabric/strategies/ddp.py index ce47e4e403c34..af182ad7f422f 100644 --- a/src/lightning/fabric/strategies/ddp.py +++ b/src/lightning/fabric/strategies/ddp.py @@ -41,6 +41,7 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from lightning.fabric.utilities.rank_zero import rank_zero_only _DDP_FORK_ALIASES = ( @@ -159,7 +160,17 @@ def barrier(self, *args: Any, **kwargs: Any) -> None: if torch.distributed.get_backend() == "nccl": torch.distributed.barrier(device_ids=self._determine_ddp_device_ids()) else: - torch.distributed.barrier() + # Handle PyTorch bug where barrier() fails on CPU with "PrivateUse1HooksInterface" error + try: + torch.distributed.barrier() + except RuntimeError as e: + if "PrivateUse1HooksInterface" in str(e): + # Fallback: Use all_reduce as barrier - all processes must participate + # This achieves the same synchronization effect as barrier() + dummy_tensor = torch.tensor(0.0, device=self.root_device) + torch.distributed.all_reduce(dummy_tensor) + else: + raise @override def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: @@ -212,7 +223,10 @@ def _setup_distributed(self) -> None: self._set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + 
_init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 322fa1899b0ee..d21182b525f66 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -99,6 +99,7 @@ def __init__( precision: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, + exclude_frozen_parameters: bool = False, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large billion parameter models. `For more information: https://pytorch- @@ -228,6 +229,8 @@ def __init__( when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. + exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints. + """ if not _DEEPSPEED_AVAILABLE: raise ImportError( @@ -288,6 +291,7 @@ def __init__( self.remote_device = remote_device self.load_full_weights = load_full_weights + self.exclude_frozen_parameters = exclude_frozen_parameters # default FP16 parameters. 
self.loss_scale = loss_scale @@ -444,7 +448,9 @@ def save_checkpoint( # there might be other stateful objects unrelated to the deepspeed engine - convert them to a state_dict state = self._convert_stateful_objects_in_state(state, filter={}) # use deepspeed's internal checkpointing function to handle partitioned weights across processes - engine.save_checkpoint(path, client_state=state, tag="checkpoint") + engine.save_checkpoint( + path, client_state=state, tag="checkpoint", exclude_frozen_parameters=self.exclude_frozen_parameters + ) @override def load_checkpoint( diff --git a/src/lightning/fabric/strategies/fsdp.py b/src/lightning/fabric/strategies/fsdp.py index 74af1dd0e8f43..25b50d6a67332 100644 --- a/src/lightning/fabric/strategies/fsdp.py +++ b/src/lightning/fabric/strategies/fsdp.py @@ -662,7 +662,10 @@ def _setup_distributed(self) -> None: self._set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index ace23a9c7a2c5..0d49ddf91a0bc 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -302,7 +302,10 @@ def _setup_distributed(self) -> None: self._set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - 
_init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/src/lightning/fabric/utilities/distributed.py b/src/lightning/fabric/utilities/distributed.py index ec4eb261f2d3e..500f3a3e2aa92 100644 --- a/src/lightning/fabric/utilities/distributed.py +++ b/src/lightning/fabric/utilities/distributed.py @@ -319,7 +319,11 @@ def _destroy_dist_connection() -> None: def _get_default_process_group_backend_for_device(device: torch.device) -> str: - return "nccl" if device.type == "cuda" else "gloo" + """Return corresponding distributed backend for a given device.""" + device_backend_map = torch.distributed.Backend.default_device_backend_map + if device.type in device_backend_map: + return device_backend_map[device.type] + return "gloo" class _DatasetSamplerWrapper(Dataset): diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index 534e5e3db653e..841fa195696a2 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -40,7 +40,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: env_seed = os.environ.get("PL_GLOBAL_SEED") if env_seed is None: seed = 0 - rank_zero_warn(f"No seed found, seed set to {seed}") + if verbose: + rank_zero_warn(f"No seed found, seed set to {seed}") else: try: seed = int(env_seed) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 01b64c38051b3..55519273e0b27 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ 
-6,6 +6,22 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). --- +## [2.5.5] - 2025-09-05 + +### Changed + +- Include `exclude_frozen_parameters` to `DeepSpeedStrategy` ([#21060](https://github.com/Lightning-AI/pytorch-lightning/pull/21060)) +- Include `PossibleUserWarning` that is raised if modules are in eval mode when training starts ([#21146](https://github.com/Lightning-AI/pytorch-lightning/pull/21146)) + +### Fixed + +- Fixed `LightningCLI` not using `ckpt_path` hyperparameters to instantiate classes ([#21116](https://github.com/Lightning-AI/pytorch-lightning/pull/21116)) +- Fixed callbacks by defer step/time-triggered `ModelCheckpoint` saves until validation metrics are available ([#21106](https://github.com/Lightning-AI/pytorch-lightning/pull/21106)) +- Fixed with adding a missing device id for pytorch 2.8 ([#21105](https://github.com/Lightning-AI/pytorch-lightning/pull/21105)) +- Fixed `TQDMProgressBar` not resetting correctly when using both a finite and iterable dataloader ([#21147](https://github.com/Lightning-AI/pytorch-lightning/pull/21147)) +- Fixed cleanup of temporary files from `Tuner` on crashes ([#21162](https://github.com/Lightning-AI/pytorch-lightning/pull/21162)) + + ## [2.5.4] - 2025-08-29 ### Fixed diff --git a/src/lightning/pytorch/callbacks/device_stats_monitor.py b/src/lightning/pytorch/callbacks/device_stats_monitor.py index 6279dd13be4af..873c4c05f5aed 100644 --- a/src/lightning/pytorch/callbacks/device_stats_monitor.py +++ b/src/lightning/pytorch/callbacks/device_stats_monitor.py @@ -34,6 +34,67 @@ class DeviceStatsMonitor(Callback): r"""Automatically monitors and logs device stats during training, validation and testing stage. ``DeviceStatsMonitor`` is a special callback as it requires a ``logger`` to passed as argument to the ``Trainer``. + **Logged Metrics** + + Logs device statistics with keys prefixed as ``DeviceStatsMonitor.{hook_name}/{base_metric_name}``. 
+ The actual metrics depend on the active accelerator and the ``cpu_stats`` flag. Below are an overview of the + possible available metrics and their meaning. + + - CPU (via ``psutil``) + + - ``cpu_percent`` — System-wide CPU utilization (%) + - ``cpu_vm_percent`` — System-wide virtual memory (RAM) utilization (%) + - ``cpu_swap_percent`` — System-wide swap memory utilization (%) + + - CUDA GPU (via ``torch.cuda.memory_stats``) + + Logs memory statistics from PyTorch caching allocator (all in bytes). + GPU compute utilization is not logged by default. + + - General Memory Usage: + + - ``allocated_bytes.all.current`` — Current allocated GPU memory + - ``allocated_bytes.all.peak`` — Peak allocated GPU memory + - ``reserved_bytes.all.current`` — Current reserved GPU memory (allocated + cached) + - ``reserved_bytes.all.peak`` — Peak reserved GPU memory + - ``active_bytes.all.current`` — Current GPU memory in active use + - ``active_bytes.all.peak`` — Peak GPU memory in active use + - ``inactive_split_bytes.all.current`` — Memory in inactive, splittable blocks + + - Allocator Pool Statistics* (for ``small_pool`` and ``large_pool``): + + - ``allocated_bytes.{pool_type}.current`` / ``allocated_bytes.{pool_type}.peak`` + - ``reserved_bytes.{pool_type}.current`` / ``reserved_bytes.{pool_type}.peak`` + - ``active_bytes.{pool_type}.current`` / ``active_bytes.{pool_type}.peak`` + + - Allocator Events: + + - ``num_ooms`` — Cumulative out-of-memory errors + - ``num_alloc_retries`` — Number of allocation retries + - ``num_device_alloc`` — Number of device allocations + - ``num_device_free`` — Number of device deallocations + + For a full list of CUDA memory stats, see the + `PyTorch documentation `_. 
+ + - TPU (via ``torch_xla``) + + - *Memory Metrics* (per device, e.g., ``xla:0``): + + - ``memory.free.xla:0`` — Free HBM memory (MB) + - ``memory.used.xla:0`` — Used HBM memory (MB) + - ``memory.percent.xla:0`` — Percentage of HBM memory used (%) + + - *XLA Operation Counters*: + + - ``CachedCompile.xla`` + - ``CreateXlaTensor.xla`` + - ``DeviceDataCacheMiss.xla`` + - ``UncachedCompile.xla`` + - ``xla::add.xla``, ``xla::addmm.xla``, etc. + + These counters can be retrieved using: ``torch_xla.debug.metrics.counter_names()`` + Args: cpu_stats: if ``None``, it will log CPU stats only if the accelerator is CPU. If ``True``, it will log CPU stats regardless of the accelerator. diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 68fed2ff82d31..452e8bdecbba3 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -260,6 +260,9 @@ def __init__( self.best_model_path = "" self.last_model_path = "" self._last_checkpoint_saved = "" + # When using step/time-based checkpointing with a validation-only monitored metric, + # defer the save until validation has produced the metric + self._defer_save_until_validation: bool = False self.kth_value: Tensor self.dirpath: Optional[_PATH] @@ -306,14 +309,17 @@ def on_train_batch_end( batch_idx: int, ) -> None: """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`""" - if self._should_skip_saving_checkpoint(trainer): - return + # Do not return early here because we may need to set deferral flags even + # if a save already happened at this global step. We'll enforce the skip + # just before actually saving below. 
+ skip_due_to_state = self._should_skip_saving_checkpoint(trainer) skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0) train_time_interval = self._train_time_interval skip_time = True now = time.monotonic() - if train_time_interval: + # Important: allow zero timedelta as a valid interval + if train_time_interval is not None: prev_time_check = self._last_time_checked skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds() # in case we have time differences across ranks @@ -326,6 +332,42 @@ def on_train_batch_end( self._last_time_checked = now monitor_candidates = self._monitor_candidates(trainer) + # If monitoring a metric that is not yet available (e.g., validation-only), + # defer saving until validation end so the metric is present. + if self.monitor is not None and self.monitor not in monitor_candidates: + # Defer both top-k and last to avoid blocking with `_last_global_step_saved` + self._defer_save_until_validation = True + return + + # Even if the monitored key exists, it could be stale from a previous validation. + # If validation is scheduled to run right after this batch (e.g., last batch of epoch) + # and we are not saving at train epoch end, defer to `on_validation_end` to use fresh metrics. + if ( + self.monitor is not None + and not self._should_save_on_train_epoch_end(trainer) + and getattr(trainer.fit_loop.epoch_loop.batch_progress, "is_last_batch", False) + ): + # Only defer if a validation loop is expected to run after this batch. 
+ will_run_val = False + if getattr(trainer, "enable_validation", False): + num_val_batches = ( + sum(trainer.num_val_batches) + if isinstance(trainer.num_val_batches, list) + else trainer.num_val_batches + ) + if num_val_batches and num_val_batches > 0: + cve = trainer.check_val_every_n_epoch + if cve is None or ((trainer.current_epoch + 1) % cve == 0): + will_run_val = True + + if will_run_val: + self._defer_save_until_validation = True + return + + # Only proceed to save if not skipping due to trainer/callback state + if skip_due_to_state: + return + self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) @@ -343,6 +385,14 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul """Save a checkpoint at the end of the validation stage.""" if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) + # If a step/time-triggered save was deferred due to a missing monitored metric, + # perform the save now that validation metrics are available. 
+ if self._defer_save_until_validation: + self._save_topk_checkpoint(trainer, monitor_candidates) + self._save_last_checkpoint(trainer, monitor_candidates) + self._defer_save_until_validation = False + return + if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 4ef260f00006d..4a0b0d67041ba 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -265,7 +265,9 @@ def on_train_start(self, *_: Any) -> None: def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: if self._leave: self.train_progress_bar = self.init_train_tqdm() - self.train_progress_bar.reset(convert_inf(self.total_train_batches)) + total = convert_inf(self.total_train_batches) + self.train_progress_bar.reset() + self.train_progress_bar.total = total self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") @@ -306,7 +308,9 @@ def on_validation_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.val_progress_bar.reset(convert_inf(self.total_val_batches_current_dataloader)) + total = convert_inf(self.total_val_batches_current_dataloader) + self.val_progress_bar.reset() + self.val_progress_bar.total = total self.val_progress_bar.initial = 0 desc = self.sanity_check_description if trainer.sanity_checking else self.validation_description self.val_progress_bar.set_description(f"{desc} DataLoader {dataloader_idx}") @@ -348,7 +352,9 @@ def on_test_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.test_progress_bar.reset(convert_inf(self.total_test_batches_current_dataloader)) + total = 
convert_inf(self.total_test_batches_current_dataloader) + self.test_progress_bar.reset() + self.test_progress_bar.total = total self.test_progress_bar.initial = 0 self.test_progress_bar.set_description(f"{self.test_description} DataLoader {dataloader_idx}") @@ -387,7 +393,9 @@ def on_predict_batch_start( if not self.has_dataloader_changed(dataloader_idx): return - self.predict_progress_bar.reset(convert_inf(self.total_predict_batches_current_dataloader)) + total = convert_inf(self.total_predict_batches_current_dataloader) + self.predict_progress_bar.reset() + self.predict_progress_bar.total = total self.predict_progress_bar.initial = 0 self.predict_progress_bar.set_description(f"{self.predict_description} DataLoader {dataloader_idx}") diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index 225296240674a..91247127f6c87 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -16,6 +16,7 @@ import sys from collections.abc import Iterable from functools import partial, update_wrapper +from pathlib import Path from types import MethodType from typing import Any, Callable, Optional, TypeVar, Union @@ -397,6 +398,7 @@ def __init__( main_kwargs, subparser_kwargs = self._setup_parser_kwargs(self.parser_kwargs) self.setup_parser(run, main_kwargs, subparser_kwargs) self.parse_arguments(self.parser, args) + self._parse_ckpt_path() self.subcommand = self.config["subcommand"] if run else None @@ -551,6 +553,24 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No else: self.config = parser.parse_args(args) + def _parse_ckpt_path(self) -> None: + """If a checkpoint path is given, parse the hyperparameters from the checkpoint and update the config.""" + if not self.config.get("subcommand"): + return + ckpt_path = self.config[self.config.subcommand].get("ckpt_path") + if ckpt_path and Path(ckpt_path).is_file(): + ckpt = torch.load(ckpt_path, weights_only=True, map_location="cpu") + hparams = 
ckpt.get("hyper_parameters", {}) + hparams.pop("_instantiator", None) + if not hparams: + return + hparams = {self.config.subcommand: {"model": hparams}} + try: + self.config = self.parser.parse_object(hparams, self.config) + except SystemExit: + sys.stderr.write("Parsing of ckpt_path hyperparameters failed!\n") + raise + def _dump_config(self) -> None: if hasattr(self, "config_dump"): return diff --git a/src/lightning/pytorch/loops/fit_loop.py b/src/lightning/pytorch/loops/fit_loop.py index 31d6724a043a3..f25c33359a78a 100644 --- a/src/lightning/pytorch/loops/fit_loop.py +++ b/src/lightning/pytorch/loops/fit_loop.py @@ -414,6 +414,9 @@ def on_run_start(self) -> None: self.epoch_loop.val_loop.setup_data() trainer.training = True + # Check for modules in eval mode at training start + self._warn_if_modules_in_eval_mode() + call._call_callback_hooks(trainer, "on_train_start") call._call_lightning_module_hook(trainer, "on_train_start") call._call_strategy_hook(trainer, "on_train_start") @@ -515,6 +518,19 @@ def on_load_checkpoint(self, state_dict: dict) -> None: self._combined_loader_states_to_load = state_dict.get("combined_loader", []) super().on_load_checkpoint(state_dict) + def _warn_if_modules_in_eval_mode(self) -> None: + """Warn if any modules are in eval mode at the start of training.""" + model = self.trainer.lightning_module + eval_modules = [name for name, module in model.named_modules() if not module.training] + + if eval_modules: + rank_zero_warn( + f"Found {len(eval_modules)} module(s) in eval mode at the start of training." + " This may lead to unexpected behavior during training. 
If this is intentional," + " you can ignore this warning.", + category=PossibleUserWarning, + ) + def _should_accumulate(self) -> bool: """Whether the gradients should be accumulated.""" return self.epoch_loop._should_accumulate() diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py index 75e792af46b90..5ea62233e1f69 100644 --- a/src/lightning/pytorch/plugins/precision/amp.py +++ b/src/lightning/pytorch/plugins/precision/amp.py @@ -112,7 +112,8 @@ def clip_gradients( super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm) def autocast_context_manager(self) -> torch.autocast: - return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half)) + dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.half + return torch.autocast(self.device, dtype=dtype, cache_enabled=False) @override @contextmanager diff --git a/src/lightning/pytorch/strategies/ddp.py b/src/lightning/pytorch/strategies/ddp.py index fd3f66ef42471..92206e1accc31 100644 --- a/src/lightning/pytorch/strategies/ddp.py +++ b/src/lightning/pytorch/strategies/ddp.py @@ -36,7 +36,7 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_3 from lightning.fabric.utilities.optimizer import _optimizers_to_device from lightning.fabric.utilities.seed import reset_seed from lightning.fabric.utilities.types import ReduceOp @@ -200,7 +200,10 @@ def setup_distributed(self) -> None: self.set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if 
_TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index dabfde70242b9..c5253f77cdedb 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -122,16 +122,16 @@ def __init__( precision_plugin: Optional[Precision] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, + exclude_frozen_parameters: bool = False, ) -> None: """Provides capabilities to run training using the DeepSpeed library, with training optimizations for large - billion parameter models. `For more information: https://pytorch- - lightning.readthedocs.io/en/stable/advanced/model_parallel.html#deepspeed`. + billion parameter models. *For more information:* :ref:`deepspeed_advanced`. .. warning:: This is an :ref:`experimental ` feature. Defaults have been set to enable ZeRO-Offload and some have been taken from the link below. These defaults have been set generally, but may require tuning for optimum performance based on your model size. - `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`. + *For more information:* https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training. Arguments: @@ -253,6 +253,8 @@ def __init__( when using ZeRO Stage 3. This differs from the DeepSpeed checkpoint which contains shards per worker. + exclude_frozen_parameters: Exclude frozen parameters when saving checkpoints. 
+ """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -311,6 +313,7 @@ def __init__( self.remote_device = remote_device self.load_full_weights = load_full_weights + self.exclude_frozen_parameters = exclude_frozen_parameters # default FP16 parameters. self.loss_scale = loss_scale @@ -648,7 +651,12 @@ def save_checkpoint(self, checkpoint: dict, filepath: _PATH, storage_options: Op # dump states as a checkpoint dictionary object _exclude_keys = ["state_dict", "optimizer_states"] checkpoint = {k: v for k, v in checkpoint.items() if k not in _exclude_keys} - self.deepspeed_engine.save_checkpoint(filepath, client_state=checkpoint, tag="checkpoint") + self.deepspeed_engine.save_checkpoint( + filepath, + client_state=checkpoint, + tag="checkpoint", + exclude_frozen_parameters=self.exclude_frozen_parameters, + ) @override def load_checkpoint(self, checkpoint_path: _PATH) -> dict[str, Any]: diff --git a/src/lightning/pytorch/strategies/fsdp.py b/src/lightning/pytorch/strategies/fsdp.py index 55ea354a5cb60..3fbd0f9cd5f0a 100644 --- a/src/lightning/pytorch/strategies/fsdp.py +++ b/src/lightning/pytorch/strategies/fsdp.py @@ -61,7 +61,7 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3 from lightning.fabric.utilities.init import _has_meta_device_parameters_or_buffers from lightning.fabric.utilities.load import _lazy_load, _materialize_tensors from lightning.fabric.utilities.optimizer import _optimizers_to_device @@ -260,7 +260,10 @@ def setup_environment(self) -> None: self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if 
_TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) # if 'device_mesh' in the `kwargs` is provided as a tuple, update it into the `DeviceMesh` object here if isinstance(self.kwargs.get("device_mesh"), tuple): diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index 82fec205af731..e0286dbe2e0e6 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -39,7 +39,7 @@ _sync_ddp_if_available, ) from lightning.fabric.utilities.distributed import group as _group -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3, _TORCH_GREATER_EQUAL_2_4 from lightning.fabric.utilities.init import _materialize_distributed_module from lightning.fabric.utilities.load import _METADATA_FILENAME from lightning.fabric.utilities.optimizer import _optimizers_to_device @@ -350,7 +350,10 @@ def _setup_distributed(self) -> None: self.set_world_ranks() self._process_group_backend = self._get_process_group_backend() assert self.cluster_environment is not None - _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + kwargs: dict[str, Any] = {"timeout": self._timeout} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = self.root_device if self.root_device.type != "cpu" else None + _init_dist_connection(self.cluster_environment, self._process_group_backend, **kwargs) def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index b78843990af30..f3d45a17d4a3b 100644 --- 
a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -184,16 +184,16 @@ def __init__( :class:`datetime.timedelta`. limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches). - Default: ``1.0``. + Value is per device. Default: ``1.0``. limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches). - Default: ``1.0``. + Value is per device. Default: ``1.0``. limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches). - Default: ``1.0``. + Value is per device. Default: ``1.0``. limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches). - Default: ``1.0``. + Value is per device. Default: ``1.0``. overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: ``0.0``. diff --git a/src/lightning/pytorch/tuner/batch_size_scaling.py b/src/lightning/pytorch/tuner/batch_size_scaling.py index 99badd84bb8ad..78d2aa52f5725 100644 --- a/src/lightning/pytorch/tuner/batch_size_scaling.py +++ b/src/lightning/pytorch/tuner/batch_size_scaling.py @@ -76,24 +76,27 @@ def _scale_batch_size( if trainer.progress_bar_callback: trainer.progress_bar_callback.disable() - new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) - - if mode == "power": - new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params) - elif mode == "binsearch": - new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params) + try: + new_size, _ = _adjust_batch_size(trainer, batch_arg_name, value=init_val) - garbage_collection_cuda() + if mode == "power": + new_size = _run_power_scaling(trainer, new_size, batch_arg_name, max_trials, params) + elif mode == "binsearch": + new_size = _run_binary_scaling(trainer, new_size, batch_arg_name, max_trials, params) - log.info(f"Finished batch size finder, will continue with 
full run using batch size {new_size}") + garbage_collection_cuda() - __scale_batch_restore_params(trainer, params) + log.info(f"Finished batch size finder, will continue with full run using batch size {new_size}") + except Exception as ex: + raise ex + finally: + __scale_batch_restore_params(trainer, params) - if trainer.progress_bar_callback: - trainer.progress_bar_callback.enable() + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() - trainer._checkpoint_connector.restore(ckpt_path) - trainer.strategy.remove_checkpoint(ckpt_path) + trainer._checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) return new_size diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index b4b61d5cf0f93..5ef35dcd6d992 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -257,40 +257,45 @@ def _lr_find( # Initialize lr finder object (stores results) lr_finder = _LRFinder(mode, min_lr, max_lr, num_training) - # Configure optimizer and scheduler - lr_finder._exchange_scheduler(trainer) - - # Fit, lr & loss logged in callback - _try_loop_run(trainer, params) - - # Prompt if we stopped early - if trainer.global_step != num_training + start_steps: - log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.") - - # Transfer results from callback to lr finder object - lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses}) - lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose - - __lr_finder_restore_params(trainer, params) - - if trainer.progress_bar_callback: - trainer.progress_bar_callback.enable() - - # Update results across ranks - lr_finder.results = trainer.strategy.broadcast(lr_finder.results) - - # Restore initial state of model (this will also restore the original optimizer state) - trainer._checkpoint_connector.restore(ckpt_path) 
- trainer.strategy.remove_checkpoint(ckpt_path) - trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True - trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True - trainer.fit_loop.epoch_loop.val_loop._combined_loader = None - trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit - trainer.fit_loop.setup_data() + lr_finder_finished = False + try: + # Configure optimizer and scheduler + lr_finder._exchange_scheduler(trainer) + + # Fit, lr & loss logged in callback + _try_loop_run(trainer, params) + + # Prompt if we stopped early + if trainer.global_step != num_training + start_steps: + log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.") + + # Transfer results from callback to lr finder object + lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses}) + lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx # for debug purpose + + __lr_finder_restore_params(trainer, params) + + if trainer.progress_bar_callback: + trainer.progress_bar_callback.enable() + + # Update results across ranks + lr_finder.results = trainer.strategy.broadcast(lr_finder.results) + lr_finder_finished = True + except Exception as ex: + raise ex + finally: + # Restore initial state of model (this will also restore the original optimizer state) + trainer._checkpoint_connector.restore(ckpt_path) + trainer.strategy.remove_checkpoint(ckpt_path) + trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True + trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True + trainer.fit_loop.epoch_loop.val_loop._combined_loader = None + trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit + trainer.fit_loop.setup_data() # Apply LR suggestion after 
restoring so it persists for the real training run # When used as a callback, the suggestion would otherwise be lost due to checkpoint restore - if update_attr: + if update_attr and lr_finder_finished: lr = lr_finder.suggestion() if lr is not None: # update the attribute on the LightningModule (e.g., lr or learning_rate) diff --git a/src/version.info b/src/version.info index fe16b348d97f7..0cadbc1e33b27 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.5.4 +2.5.5 diff --git a/tests/legacy/back-compatible-versions.txt b/tests/legacy/back-compatible-versions.txt index 9032e5c13cc54..885dd6f85b36a 100644 --- a/tests/legacy/back-compatible-versions.txt +++ b/tests/legacy/back-compatible-versions.txt @@ -107,3 +107,4 @@ 2.5.1 2.5.2 2.5.3 +2.5.4 diff --git a/tests/legacy/generate_checkpoints.sh b/tests/legacy/generate_checkpoints.sh index 1d083a2a8e052..d3cfa693c2906 100644 --- a/tests/legacy/generate_checkpoints.sh +++ b/tests/legacy/generate_checkpoints.sh @@ -7,18 +7,17 @@ set -e LEGACY_FOLDER=$(cd $(dirname $0); pwd -P) -printf "LEGACY_FOLDER: $LEGACY_FOLDER" +printf "LEGACY_FOLDER: $LEGACY_FOLDER\n" TESTS_FOLDER=$(dirname $LEGACY_FOLDER) -ENV_PATH=$LEGACY_FOLDER/vEnv -printf "ENV_PATH: $ENV_PATH" +ENV_PATH=$LEGACY_FOLDER/.venv +printf "ENV_PATH: $ENV_PATH\n" export PYTHONPATH=$TESTS_FOLDER # for `import tests_pytorch` -printf "PYTHONPATH: $PYTHONPATH" +printf "PYTHONPATH: $PYTHONPATH\n" rm -rf $ENV_PATH function create_and_save_checkpoint { - python --version - python -m pip --version - python -m pip list + uv --version + uv pip list python $LEGACY_FOLDER/simple_classif_training.py $pl_ver @@ -33,10 +32,10 @@ do printf "\n\n processing version: $pl_ver\n" # Don't install/update anything before activating venv to avoid breaking any existing environment. 
- python -m venv $ENV_PATH + uv venv $ENV_PATH source $ENV_PATH/bin/activate - python -m pip install "pytorch_lightning==$pl_ver" \ + uv pip install "pytorch_lightning==$pl_ver" \ -r $LEGACY_FOLDER/requirements.txt \ -r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \ -f https://download.pytorch.org/whl/cpu/torch_stable.html @@ -52,7 +51,7 @@ done if [[ -z "$@" ]]; then printf "\n\n processing local version\n" - python -m pip install \ + uv pip install \ -r $LEGACY_FOLDER/requirements.txt \ -r "$(dirname $TESTS_FOLDER)/requirements/pytorch/test.txt" \ -f https://download.pytorch.org/whl/cpu/torch_stable.html diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index fa5c975228a5e..f302da5d1bc4f 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -25,6 +25,7 @@ from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from tests_fabric.helpers.runif import RunIf @@ -168,6 +169,52 @@ def test_set_timeout(init_process_group_mock): process_group_backend = strategy._get_process_group_backend() global_rank = strategy.cluster_environment.global_rank() world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs + ) + + +@mock.patch("torch.distributed.init_process_group") +def test_device_id_passed_for_cuda_devices(init_process_group_mock): + """Test that device_id is passed to init_process_group for CUDA 
devices but not for CPU.""" + # Test with CPU device - device_id should be None + cpu_strategy = DDPStrategy(parallel_devices=[torch.device("cpu")]) + cpu_strategy.cluster_environment = LightningEnvironment() + cpu_strategy.accelerator = Mock() + cpu_strategy.setup_environment() + + process_group_backend = cpu_strategy._get_process_group_backend() + global_rank = cpu_strategy.cluster_environment.global_rank() + world_size = cpu_strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = cpu_strategy.root_device if cpu_strategy.root_device.type != "cpu" else None + init_process_group_mock.assert_called_with( + process_group_backend, rank=global_rank, world_size=world_size, timeout=cpu_strategy._timeout, **kwargs + ) + + init_process_group_mock.reset_mock() + + # Test with CUDA device - device_id should be the device + cuda_device = torch.device("cuda", 0) + cuda_strategy = DDPStrategy(parallel_devices=[cuda_device]) + cuda_strategy.cluster_environment = LightningEnvironment() + cuda_strategy.accelerator = Mock() + cuda_strategy.setup_environment() + + process_group_backend = cuda_strategy._get_process_group_backend() + global_rank = cuda_strategy.cluster_environment.global_rank() + world_size = cuda_strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = cuda_strategy.root_device if cuda_strategy.root_device.type != "cpu" else None + init_process_group_mock.assert_called_with( + process_group_backend, + rank=global_rank, + world_size=world_size, + timeout=cuda_strategy._timeout, + **kwargs, ) diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 032ee63cd4721..817121c9e2f59 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -193,7 +193,9 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path): 
model.modules.return_value = [model] strategy.save_checkpoint(path=tmp_path, state={"model": model, "test": "data"}) # the client_state should not contain any deepspeed engine or deepspeed optimizer - model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint") + model.save_checkpoint.assert_called_with( + tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False + ) # Model and optimizer optimizer = Mock() @@ -201,7 +203,9 @@ def test_deepspeed_save_checkpoint_client_state_separation(tmp_path): model.modules.return_value = [model] strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "test": "data"}) # the client_state should not contain any deepspeed engine or deepspeed optimizer - model.save_checkpoint.assert_called_with(tmp_path, client_state={"test": "data"}, tag="checkpoint") + model.save_checkpoint.assert_called_with( + tmp_path, client_state={"test": "data"}, tag="checkpoint", exclude_frozen_parameters=False + ) @RunIf(deepspeed=True) @@ -218,6 +222,27 @@ def test_deepspeed_save_checkpoint_warn_colliding_keys(tmp_path): strategy.save_checkpoint(path=tmp_path, state={"model": model, "optimizer": optimizer, "mp_world_size": 2}) +@RunIf(deepspeed=True) +@pytest.mark.parametrize("exclude_frozen_parameters", [True, False]) +def test_deepspeed_save_checkpoint_exclude_frozen_parameters(exclude_frozen_parameters): + """Test that the DeepSpeed strategy can save checkpoints with the `exclude_frozen_parameters` argument.""" + from deepspeed import DeepSpeedEngine + + strategy = DeepSpeedStrategy(exclude_frozen_parameters=exclude_frozen_parameters) + assert strategy.exclude_frozen_parameters is exclude_frozen_parameters + + model = Mock(spec=DeepSpeedEngine, optimizer=None) + model.modules.return_value = [model] + strategy.save_checkpoint(path="test_path", state={"model": model, "extra": "data"}) + + model.save_checkpoint.assert_called_with( + "test_path", + 
client_state={"extra": "data"}, + tag="checkpoint", + exclude_frozen_parameters=exclude_frozen_parameters, + ) + + @RunIf(deepspeed=True) def test_deepspeed_load_checkpoint_validate_path(tmp_path): """Test that we validate the checkpoint path for a DeepSpeed checkpoint and give suggestions for user error.""" diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index d5f82752a9176..6be379d36582c 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -31,7 +31,7 @@ _get_full_state_dict_context, _is_sharded_checkpoint, ) -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3 def test_custom_mixed_precision(): @@ -381,8 +381,11 @@ def test_set_timeout(init_process_group_mock): process_group_backend = strategy._get_process_group_backend() global_rank = strategy.cluster_environment.global_rank() world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs ) diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py index 78622adf66fa6..4b6cb6f8fad85 100644 --- a/tests/tests_fabric/strategies/test_model_parallel.py +++ b/tests/tests_fabric/strategies/test_model_parallel.py @@ -25,6 +25,7 @@ from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl +from lightning.fabric.utilities.imports 
import _TORCH_GREATER_EQUAL_2_3 from tests_fabric.helpers.runif import RunIf @@ -316,8 +317,11 @@ def test_set_timeout(init_process_group_mock, _): process_group_backend = strategy._get_process_group_backend() global_rank = strategy.cluster_environment.global_rank() world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs ) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index d65eaa810ff4d..51c4b320d5525 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -17,6 +17,7 @@ from lightning.fabric.utilities.distributed import ( _destroy_dist_connection, _gather_all_tensors, + _get_default_process_group_backend_for_device, _InfiniteBarrier, _init_dist_connection, _is_dtensor, @@ -243,6 +244,27 @@ def test_init_dist_connection_registers_destruction_handler(_, atexit_mock): atexit_mock.register.assert_not_called() +def test_get_default_process_group_backend_for_device(): + """Test that each device type maps to its correct default process group backend.""" + # register a custom backend for test + torch.utils.rename_privateuse1_backend("pcu") + + def mock_backend(store, group_rank, group_size, timeout): + pass + + torch.distributed.Backend.register_backend( + "pccl", + lambda store, group_rank, group_size, timeout: mock_backend(store, group_rank, group_size, timeout), + devices=["pcu"], + ) + + # test that the default backend is correctly set for each device + devices = [torch.device("cpu"), torch.device("cuda:0"), torch.device("pcu:0")] + backends = ["gloo", "nccl", "pccl"] + for 
device, backend in zip(devices, backends): + assert _get_default_process_group_backend_for_device(device) == backend + + @RunIf(min_torch="2.4") def test_is_dtensor(monkeypatch): from torch.distributed._tensor import DTensor diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index 2700213747f9a..81fde5aae3ef5 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -72,6 +72,14 @@ def test_seed_everything_accepts_valid_seed_from_env(): assert seed_everything() == 17 +@mock.patch.dict(os.environ, {}, clear=True) +def test_seed_everything_non_verbose_no_warning(): + """Ensure that no warning is emitted when verbose is False and no seed is provided.""" + with warnings.catch_warnings(record=True) as caught: + seed_everything(verbose=False) + assert caught == [] + + def test_reset_seed_no_op(): """Test that the reset_seed function is a no-op when seed_everything() was not used.""" assert "PL_GLOBAL_SEED" not in os.environ diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index e96a5f77df384..e7f0bedb8e9e9 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -30,7 +30,7 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): ) -@pytest.mark.flaky(max_runs=3) +@pytest.mark.flaky(reruns=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), # NOTE FOR ALL FOLLOWING TESTS: diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d93bf1cf60e9c..89d2fa73e3b6b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -18,7 +18,7 @@ from collections import defaultdict from typing import Union from unittest import mock -from 
unittest.mock import ANY, Mock, PropertyMock, call +from unittest.mock import ANY, Mock, PropertyMock, call, patch import pytest import torch @@ -801,3 +801,50 @@ def test_tqdm_leave(leave, tmp_path): ) trainer.fit(model) assert pbar.init_train_tqdm.call_count == (4 if leave else 1) + + +@patch("lightning.pytorch.callbacks.progress.rich_progress._RICH_AVAILABLE", False) +def test_tqdm_progress_bar_reset_behavior(tmp_path): + """Test that progress bars call reset() without parameters and set total separately.""" + model = BoringModel() + + class ResetTrackingTqdm(MockTqdm): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.reset_calls_with_params = [] + + def reset(self, total=None): + self.reset_calls_with_params.append(total) + super().reset(total) + + trainer = Trainer( + default_root_dir=tmp_path, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=1, + logger=False, + enable_checkpointing=False, + ) + + pbar = trainer.progress_bar_callback + + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", ResetTrackingTqdm): + trainer.fit(model) + + train_bar = pbar.train_progress_bar + assert None in train_bar.reset_calls_with_params, ( + f"train reset() should be called without parameters, got calls: {train_bar.reset_calls_with_params}" + ) + # Verify that total was set separately to the expected value + assert 2 in train_bar.total_values, ( + f"train total should be set to 2 after reset(), got total_values: {train_bar.total_values}" + ) + # Verify that validation progress bar reset() was called without parameters + val_bar = pbar.val_progress_bar + assert None in val_bar.reset_calls_with_params, ( + f"validation reset() should be called without parameters, got calls: {val_bar.reset_calls_with_params}" + ) + # Verify that total was set separately to the expected value + assert 2 in val_bar.total_values, ( + f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}" + ) 
diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py new file mode 100644 index 0000000000000..fba9e865debfd --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_additional_cases.py @@ -0,0 +1,208 @@ +import math +from datetime import timedelta + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 4): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +class TrainMetricModule(LightningModule): + def __init__(self): + super().__init__() + self.layer = nn.Linear(1, 1) + self._counter = 0.0 + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + # strictly increasing train metric per step + self._counter += 1.0 + self.log("train_score", torch.tensor(self._counter), on_step=True, on_epoch=False, prog_bar=False, logger=True) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +def _make_loaders(n=4): + ds = TinyDataset(n=n) + train_loader = DataLoader(ds, batch_size=2, shuffle=False) + val_loader = DataLoader(ds, batch_size=2, shuffle=False) + return train_loader, val_loader + + +def test_model_checkpoint_every_n_train_steps_with_train_metric_saves_at_step(tmp_path): + """When monitoring a train-step metric, step-interval checkpointing should save at the step boundary (no deferral) + and best_model_score should match the last train metric value.""" 
+ seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + model = TrainMetricModule() + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="train_score", + mode="max", + save_top_k=1, + every_n_train_steps=1, + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + # 2 batches/epoch, run 2 epochs to have multiple step saves + trainer = Trainer( + max_epochs=2, + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=0, # no validation needed for this test + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + # 2 epochs * 2 steps/epoch = 4 steps total; metric increments by 1 each step + expected = 4.0 + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +@pytest.mark.parametrize("val_scores", [[0.2, 0.4, 0.9]]) +def test_model_checkpoint_time_interval_with_val_metric_defers_until_validation(tmp_path, val_scores): + """With time-interval-based checkpointing, and a validation-only metric, ensure we don't save using stale metrics + at step boundaries; saving should occur at validation end.""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=0, # disable step-based + train_time_interval=timedelta(seconds=0), # trigger as often as possible + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + 
limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +class ValMetricModule(LightningModule): + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + score = self._val_scores[self.current_epoch] + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0, 3.0]]) +def test_model_checkpoint_defer_until_next_validation_when_val_every_2_epochs(tmp_path, val_scores): + """With validation running every 2 epochs, step-triggered saves at the end of non-validation epochs should be + deferred and then performed at the next validation end when the metric is available.""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=2, # end of each epoch + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + 
enable_checkpointing=True, + enable_model_summary=False, + logger=False, + check_val_every_n_epoch=2, # only validate every 2 epochs + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) # last/maximum value occurs at final validation epoch + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py new file mode 100644 index 0000000000000..a265f8bc5f194 --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_edge_cases.py @@ -0,0 +1,174 @@ +import math +from datetime import timedelta + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 8): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +def _make_loaders(n=8, batch_size=2): + ds = TinyDataset(n=n) + train_loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + val_loader = DataLoader(ds, batch_size=batch_size, shuffle=False) + return train_loader, val_loader + + +class MultiValPerEpochModule(LightningModule): + """Logs a validation metric on every validation run, even if validation is run multiple times per epoch.""" + + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + self._val_call_idx = 0 + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = 
self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + score = self._val_scores[self._val_call_idx] + self._val_call_idx += 1 + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +class ValOnceEveryTwoEpochsModule(LightningModule): + """Logs a validation metric only when validation runs (e.g., every 2 epochs), indexed by current_epoch.""" + + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + pass + + def on_validation_epoch_end(self): + # current_epoch indexes into provided scores; only called when validation runs + score = self._val_scores[self.current_epoch] + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.9]]) +def test_checkpoint_defers_with_mid_epoch_validation(tmp_path, val_scores): + """With val_check_interval=0.5 (validation mid-epoch and at epoch end), and step-based checkpointing, saves must be + deferred until each validation end so monitored validation metrics are fresh.""" + seed_everything(123) + + # 4 train batches per epoch (batch_size=2 over n=8), so two validations: after 2 batches and after 4 batches + train_loader, val_loader = _make_loaders(n=8, batch_size=2) + + model = MultiValPerEpochModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=1, # would 
trigger every step, but must defer to validation + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=1, + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=4, # ensure exactly 4 steps => two validations at 0.5 and 1.0 + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + val_check_interval=0.5, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) + + +@pytest.mark.parametrize("val_scores", [[0.2, 0.6]]) +def test_time_interval_defers_across_epoch_until_first_validation(tmp_path, val_scores): + """With time-interval saving and validation only every 2 epochs, ensure no save uses stale/missing validation + metrics; the first save should happen at the first validation end (epoch 2).""" + seed_everything(123) + + train_loader, val_loader = _make_loaders(n=4, batch_size=2) + + model = ValOnceEveryTwoEpochsModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + every_n_train_steps=0, # disable step-based + train_time_interval=timedelta(seconds=0), # trigger frequently + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=2, + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + check_val_every_n_epoch=2, # first validation only after 2nd epoch + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert 
ckpt.best_model_score is not None + expected = val_scores[1] # validation runs only once at epoch 2, logging index 1 + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6) diff --git a/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py new file mode 100644 index 0000000000000..c3fa0bfcd2e38 --- /dev/null +++ b/tests/tests_pytorch/callbacks/test_model_checkpoint_step_interval_val_metric.py @@ -0,0 +1,106 @@ +import math + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.data import DataLoader, Dataset + +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import ModelCheckpoint + + +class TinyDataset(Dataset): + def __init__(self, n: int = 4): + self.x = torch.arange(n, dtype=torch.float32).view(-1, 1) + self.y = self.x.clone() + + def __len__(self): + return len(self.x) + + def __getitem__(self, idx): + return self.x[idx], self.y[idx] + + +class ValMetricModule(LightningModule): + def __init__(self, val_scores: list[float]): + super().__init__() + self.layer = nn.Linear(1, 1) + self._val_scores = [float(s) for s in val_scores] + + # LightningModule API (minimal) + def training_step(self, batch, batch_idx): + x, y = batch + y_hat = self.layer(x) + loss = F.mse_loss(y_hat, y) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + # do nothing per-step; we log at epoch end + pass + + def on_validation_epoch_end(self): + # Log a validation metric only at validation epoch end + # Values increase across epochs; best should be the last epoch + score = self._val_scores[self.current_epoch] + # use logger=True so it lands in trainer.callback_metrics + self.log("auroc", torch.tensor(score, dtype=torch.float32), prog_bar=False, logger=True) + + def configure_optimizers(self): + return 
torch.optim.SGD(self.parameters(), lr=0.01) + + +@pytest.mark.parametrize("val_scores", [[0.1, 0.5, 1.0]]) +def test_model_checkpoint_every_n_train_steps_with_val_metric_saves_after_val(tmp_path, val_scores): + """Reproduces #20919: Using every_n_train_steps with a validation-only metric should save the best checkpoint only + after the metric is computed at validation, not earlier at the train-step boundary. + + Expectation: best_model_score equals the last (max) val score. + + """ + seed_everything(123) + + # 2 train batches per epoch (so checkpoint triggers at the epoch boundary) + ds = TinyDataset(n=4) + train_loader = DataLoader(ds, batch_size=2, shuffle=False) + val_loader = DataLoader(ds, batch_size=2, shuffle=False) + + model = ValMetricModule(val_scores=val_scores) + + ckpt = ModelCheckpoint( + dirpath=tmp_path, + monitor="auroc", + mode="max", + save_top_k=1, + # critical: trigger on train steps, not on epoch end + every_n_train_steps=2, # equal to number of train batches per epoch + train_time_interval=None, + every_n_epochs=0, + save_on_train_epoch_end=False, + save_weights_only=True, + ) + + trainer = Trainer( + max_epochs=len(val_scores), + accelerator="cpu", + devices=1, + callbacks=[ckpt], + num_sanity_val_steps=0, + log_every_n_steps=1, + limit_train_batches=2, + limit_val_batches=1, + enable_checkpointing=True, + enable_model_summary=False, + logger=False, + ) + + trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader) + + assert ckpt.best_model_score is not None + # Should equal the last (max) validation score + expected = max(val_scores) + actual = float(ckpt.best_model_score) + assert math.isclose(actual, expected, rel_tol=0, abs_tol=1e-6), ( + f"best_model_score should be {expected} (last/maximum val score), got {actual}.\n" + f"This indicates the checkpoint was saved before the validation metric was computed." 
+ ) diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index f61a6c59ca9db..86e3ac88e93cf 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -48,7 +48,7 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx) -@pytest.mark.flaky(max_runs=3) +@pytest.mark.flaky(reruns=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), # NOTE FOR ALL FOLLOWING TESTS: diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index e3a4c37f6a284..f5aaa18095fc5 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -13,12 +13,14 @@ # limitations under the License. import itertools import logging +import warnings from unittest.mock import Mock import pytest import torch from torch.utils.data import DataLoader +from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _FitLoop @@ -277,3 +279,29 @@ def __iter__(self): # assert progress bar callback uses correct total steps assert pbar.train_progress_bar.total == max_steps + + +@pytest.mark.parametrize("warn", [True, False]) +def test_eval_mode_warning(tmp_path, warn): + """Test that a warning is raised if any module is in eval mode at the start of training.""" + model = BoringModel() + if warn: + model.some_eval_module = torch.nn.Linear(32, 16) + model.some_eval_module.eval() + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + ) + + if warn: + with pytest.warns(PossibleUserWarning): + trainer.fit(model) + else: + with warnings.catch_warnings(record=True) as warning_list: + 
warnings.simplefilter("always") + trainer.fit(model) + eval_warnings = [ + w for w in warning_list if issubclass(w.category, PossibleUserWarning) and "eval mode" in str(w.message) + ] + assert len(eval_warnings) == 0, "Expected no eval mode warnings" diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index cb061c540b2be..3894c4256e0b8 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,6 +14,8 @@ from unittest.mock import Mock import pytest +import torch +from torch import nn from torch.optim import Optimizer from lightning.pytorch.plugins import MixedPrecision @@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method(): with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"): precision.clip_gradients(optimizer, clip_val=1.0) + + +def test_amp_with_no_grad(): + """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient + tracking.""" + layer = nn.Linear(2, 1) + x = torch.randn(1, 2) + amp = MixedPrecision(precision="bf16-mixed", device="cpu") + + with amp.autocast_context_manager(): + with torch.no_grad(): + _ = layer(x) + + loss = layer(x).mean() + loss.backward() + assert loss.grad_fn is not None diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index 915e57440b40f..823d77d0d5848 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -20,6 +20,7 @@ from torch.nn.parallel import DistributedDataParallel from lightning.fabric.plugins.environments import LightningEnvironment +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import 
DoublePrecision, HalfPrecision, Precision @@ -132,8 +133,41 @@ def test_set_timeout(mock_init_process_group): process_group_backend = trainer.strategy._get_process_group_backend() global_rank = trainer.strategy.cluster_environment.global_rank() world_size = trainer.strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None mock_init_process_group.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs + ) + + +@mock.patch("torch.distributed.init_process_group") +def test_device_id_passed_for_cuda_devices_pytorch(mock_init_process_group): + """Test that device_id is passed to init_process_group for CUDA devices but not for CPU.""" + # Test with CPU device - device_id should be None + model = BoringModel() + ddp_strategy = DDPStrategy() + trainer = Trainer( + max_epochs=1, + accelerator="cpu", + strategy=ddp_strategy, + ) + trainer.strategy.connect(model) + trainer.lightning_module.trainer = trainer + trainer.strategy.setup_environment() + + process_group_backend = trainer.strategy._get_process_group_backend() + global_rank = trainer.strategy.cluster_environment.global_rank() + world_size = trainer.strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = trainer.strategy.root_device if trainer.strategy.root_device.type != "cpu" else None + mock_init_process_group.assert_called_with( + process_group_backend, + rank=global_rank, + world_size=world_size, + timeout=trainer.strategy._timeout, + **kwargs, ) diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py index 048403366ebc7..fc3a8cfebbac0 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration.py +++ 
b/tests/tests_pytorch/strategies/test_ddp_integration.py @@ -66,7 +66,7 @@ def test_multi_gpu_model_ddp_fit_test(tmp_path): assert out["test_acc"] > 0.7 -@RunIf(skip_windows=True) +@RunIf(skip_windows=True, max_torch="2.7") @mock.patch("torch.cuda.set_device") @mock.patch("lightning.pytorch.accelerators.cuda._check_cuda_matmul_precision") @mock.patch("lightning.pytorch.accelerators.cuda._clear_cuda_memory") diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 7e7d2eacd0617..503d1ea0e630b 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -562,6 +562,46 @@ def test_deepspeed_multigpu_single_file(tmp_path): trainer.test(model, ckpt_path=checkpoint_path) +@RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) +def test_deepspeed_strategy_exclude_frozen_parameters_integration(tmp_path): + """Test end-to-end integration of exclude_frozen_parameters with actual model training and checkpointing.""" + + class TestModelWithFrozenParams(BoringModel): + def __init__(self): + super().__init__() + self.frozen_layer = torch.nn.Linear(32, 32) + + def configure_model(self) -> None: + super().configure_model() + # Freeze the additional layer parameters + for param in self.frozen_layer.parameters(): + param.requires_grad = False + + def forward(self, x): + x = self.frozen_layer(x) + return super().forward(x) + + model = TestModelWithFrozenParams() + + trainer = Trainer( + default_root_dir=tmp_path, + strategy=DeepSpeedStrategy(exclude_frozen_parameters=True), + accelerator="gpu", + devices=1, + fast_dev_run=True, + precision="16-mixed", + enable_progress_bar=False, + enable_model_summary=False, + ) + + trainer.fit(model) + checkpoint_path = os.path.join(tmp_path, "checkpoint_exclude_frozen.ckpt") + trainer.save_checkpoint(checkpoint_path) + + # Verify checkpoint was created + assert os.path.exists(checkpoint_path) + + class 
ModelParallelClassificationModel(LightningModule): def __init__(self, lr: float = 0.01, num_blocks: int = 5): super().__init__() diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 560ab19f823ca..f7c15b5930be8 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -18,7 +18,7 @@ from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_3 from lightning.fabric.utilities.load import _load_distributed_checkpoint from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint @@ -532,8 +532,11 @@ def test_set_timeout(init_process_group_mock): process_group_backend = strategy._get_process_group_backend() global_rank = strategy.cluster_environment.global_rank() world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs ) diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 86a95944ac20d..c803c10afa4b4 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -22,6 +22,7 @@ import torch.nn as nn from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 from lightning.pytorch import 
LightningModule from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import ModelParallelStrategy @@ -202,8 +203,11 @@ def test_set_timeout(init_process_group_mock, _): process_group_backend = strategy._get_process_group_backend() global_rank = strategy.cluster_environment.global_rank() world_size = strategy.cluster_environment.world_size() + kwargs = {} + if _TORCH_GREATER_EQUAL_2_3: + kwargs["device_id"] = strategy.root_device if strategy.root_device.type != "cpu" else None init_process_group_mock.assert_called_with( - process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta, **kwargs ) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 1b883dda0282a..248852f4cf1f3 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -17,7 +17,7 @@ import operator import os import sys -from contextlib import ExitStack, contextmanager, redirect_stdout +from contextlib import ExitStack, contextmanager, redirect_stderr, redirect_stdout from io import StringIO from pathlib import Path from typing import Callable, Optional, Union @@ -98,7 +98,7 @@ class _UnkArgError(Exception): def _raise(): raise _UnkArgError - parser = LightningArgumentParser(add_help=False, parse_as_dict=False) + parser = LightningArgumentParser(add_help=False) parser.add_lightning_class_args(Trainer, None) monkeypatch.setattr(parser, "exit", lambda *args: _raise(), raising=True) @@ -487,6 +487,45 @@ def test_lightning_cli_print_config(): assert outval["ckpt_path"] is None +class BoringCkptPathModel(BoringModel): + def __init__(self, out_dim: int = 2, hidden_dim: int = 2) -> None: + super().__init__() + self.save_hyperparameters() + self.layer = torch.nn.Linear(32, out_dim) + + +def test_lightning_cli_ckpt_path_argument_hparams(cleandir): + class CkptPathCLI(LightningCLI): + def 
add_arguments_to_parser(self, parser): + parser.link_arguments("model.out_dim", "model.hidden_dim", compute_fn=lambda x: x * 2) + + cli_args = ["fit", "--model.out_dim=3", "--trainer.max_epochs=1"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = CkptPathCLI(BoringCkptPathModel) + + assert cli.config.fit.model.out_dim == 3 + assert cli.config.fit.model.hidden_dim == 6 + hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml" + assert hparams_path.is_file() + hparams = yaml.safe_load(hparams_path.read_text()) + assert hparams["out_dim"] == 3 + assert hparams["hidden_dim"] == 6 + + checkpoint_path = next(Path(cli.trainer.log_dir, "checkpoints").glob("*.ckpt")) + cli_args = ["predict", f"--ckpt_path={checkpoint_path}"] + with mock.patch("sys.argv", ["any.py"] + cli_args): + cli = CkptPathCLI(BoringCkptPathModel) + + assert cli.config.predict.model.out_dim == 3 + assert cli.config.predict.model.hidden_dim == 6 + assert cli.config_init.predict.model.layer.out_features == 3 + + err = StringIO() + with mock.patch("sys.argv", ["any.py"] + cli_args), redirect_stderr(err), pytest.raises(SystemExit): + cli = LightningCLI(BoringModel) + assert "Parsing of ckpt_path hyperparameters failed" in err.getvalue() + + def test_lightning_cli_submodules(cleandir): class MainModule(BoringModel): def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1): diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py index 69575a351b0a5..81352ebe256ef 100644 --- a/tests/tests_pytorch/tuner/test_lr_finder.py +++ b/tests/tests_pytorch/tuner/test_lr_finder.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import glob import logging import math import os @@ -750,3 +751,52 @@ def __init__(self): assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), ( "Gradients should differ significantly in exponential mode when using proper spacing" ) + + +def test_lr_finder_checkpoint_cleanup_on_error(tmp_path): + """Test that temporary checkpoint files are cleaned up even when an error occurs during lr finding.""" + + class FailingModel(BoringModel): + def __init__(self, fail_on_step=2): + super().__init__() + self.fail_on_step = fail_on_step + self.current_step = 0 + self.learning_rate = 1e-3 + + def training_step(self, batch, batch_idx): + self.current_step += 1 + if self.current_step >= self.fail_on_step: + raise RuntimeError("Intentional failure for testing cleanup") + return super().training_step(batch, batch_idx) + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + model = FailingModel() + lr_finder = LearningRateFinder(num_training_steps=5) + + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + callbacks=[lr_finder], + ) + + # Check no lr_find checkpoint files exist initially + lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt")) + assert len(lr_find_checkpoints) == 0, "No lr_find checkpoint files should exist initially" + + # Run lr finder and expect it to fail + with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"): + trainer.fit(model) + + # Check that no lr_find checkpoint files are left behind + lr_find_checkpoints = glob.glob(os.path.join(tmp_path, ".lr_find_*.ckpt")) + assert len(lr_find_checkpoints) == 0, ( + f"lr_find checkpoint files should be cleaned up, but found: {lr_find_checkpoints}" + ) diff --git 
a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py index e4ed533c6fa83..f0e5fbe6a3c49 100644 --- a/tests/tests_pytorch/tuner/test_scale_batch_size.py +++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import glob import logging import os from copy import deepcopy @@ -486,3 +487,49 @@ def test_batch_size_finder_callback_val_batches(tmp_path): assert trainer.num_val_batches[0] == len(trainer.val_dataloaders) assert trainer.num_val_batches[0] != steps_per_trial + + +def test_scale_batch_size_checkpoint_cleanup_on_error(tmp_path): + """Test that temporary checkpoint files are cleaned up even when an error occurs during batch size scaling.""" + + class FailingModel(BoringModel): + def __init__(self, fail_on_step=2): + super().__init__() + self.fail_on_step = fail_on_step + self.current_step = 0 + self.batch_size = 2 + + def training_step(self, batch, batch_idx): + self.current_step += 1 + if self.current_step >= self.fail_on_step: + raise RuntimeError("Intentional failure for testing cleanup") + return super().training_step(batch, batch_idx) + + def train_dataloader(self): + return DataLoader(RandomDataset(32, 64), batch_size=self.batch_size) + + model = FailingModel() + batch_size_finder = BatchSizeFinder(max_trials=3, steps_per_trial=2) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + enable_checkpointing=False, + enable_progress_bar=False, + enable_model_summary=False, + logger=False, + callbacks=[batch_size_finder], + ) + + # Check no scale_batch_size checkpoint files exist initially + scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt")) + assert len(scale_checkpoints) == 0, "No scale_batch_size checkpoint files should exist initially" + + # Run batch size scaler 
and expect it to fail + with pytest.raises(RuntimeError, match="Intentional failure for testing cleanup"): + trainer.fit(model) + + # Check that no scale_batch_size checkpoint files are left behind + scale_checkpoints = glob.glob(os.path.join(tmp_path, ".scale_batch_size_*.ckpt")) + assert len(scale_checkpoints) == 0, ( + f"scale_batch_size checkpoint files should be cleaned up, but found: {scale_checkpoints}" + )