diff --git a/.actions/requirements.txt b/.actions/requirements.txt index 44548c1e0a370..0742eeb8dc353 100644 --- a/.actions/requirements.txt +++ b/.actions/requirements.txt @@ -1,3 +1,3 @@ -jsonargparse >=4.16.0, <=4.35.0 +jsonargparse requests packaging diff --git a/.azure/gpu-benchmarks.yml b/.azure/gpu-benchmarks.yml index b77dbfc4f792a..045c0cd45ccb9 100644 --- a/.azure/gpu-benchmarks.yml +++ b/.azure/gpu-benchmarks.yml @@ -76,8 +76,13 @@ jobs: displayName: "Image info & NVIDIA" - bash: | - pip install -e .[dev] --find-links ${TORCH_URL} - pip install setuptools==75.6.0 jsonargparse==4.35.0 + pip install -U -q -r .actions/requirements.txt + python .actions/assistant.py copy_replace_imports --source_dir="./tests" \ + --source_import="lightning.fabric,lightning.pytorch" \ + --target_import="lightning_fabric,pytorch_lightning" + displayName: "Adjust tests" + + - bash: pip install -e .[dev] --find-links ${TORCH_URL} env: FREEZE_REQUIREMENTS: "1" displayName: "Install package" @@ -88,17 +93,10 @@ jobs: python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'" displayName: "Env details" - - bash: | - pip install -q -r .actions/requirements.txt - python .actions/assistant.py copy_replace_imports --source_dir="./tests" \ - --source_import="lightning.fabric,lightning.pytorch" \ - --target_import="lightning_fabric,pytorch_lightning" - displayName: "Adjust tests" - - bash: python -m pytest parity_$(PACKAGE_NAME) -v --durations=0 env: PL_RUNNING_BENCHMARKS: "1" - PL_RUN_CUDA_TESTS: "1" + RUN_ONLY_CUDA_TESTS: "1" workingDirectory: tests/ displayName: "Testing: benchmarks" @@ -107,7 +105,7 @@ jobs: # without succeeded this could run even if the job has already failed condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) env: - PL_RUN_CUDA_TESTS: "1" + RUN_ONLY_CUDA_TESTS: "1" PL_RUN_STANDALONE_TESTS: "1" displayName: "Testing: fabric standalone tasks" timeoutInMinutes: "10" diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index 583451fa0cdfa..c584f5bcbd3a2 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -48,7 +48,7 @@ jobs: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) FREEZE_REQUIREMENTS: "1" PIP_CACHE_DIR: "/var/tmp/pip" - PL_RUN_CUDA_TESTS: "1" + RUN_ONLY_CUDA_TESTS: "1" container: image: $(image) # default shm size is 64m. 
Increase it to avoid: @@ -60,13 +60,13 @@ jobs: image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" PACKAGE_NAME: "fabric" "Fabric | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3" PACKAGE_NAME: "fabric" #"Fabric | future": # image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" # PACKAGE_NAME: "fabric" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3" PACKAGE_NAME: "lightning" workspace: clean: all @@ -78,8 +78,6 @@ jobs: echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html" scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))') echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope" - python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')") - echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver" displayName: "set env. vars" - bash: | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}" @@ -100,6 +98,12 @@ jobs: pip list displayName: "Image info & NVIDIA" + - bash: | + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel # for compatibility + condition: contains(variables['Agent.JobName'], 'oldest') + displayName: "setting oldest dependencies" + - bash: | PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])") pip install -q wget packaging @@ -109,11 +113,22 @@ jobs: done displayName: "Adjust dependencies" + - bash: | + pip install -U -q -r .actions/requirements.txt + python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_fabric" \ + --source_import="lightning.fabric" \ + --target_import="lightning_fabric" + python .actions/assistant.py copy_replace_imports --source_dir="./examples/fabric" \ + --source_import="lightning.fabric" \ + --target_import="lightning_fabric" + # without succeeded this could run even if the job has already failed + condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) + displayName: "Adjust tests & examples" + - bash: | set -e extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}" - pip install setuptools==75.6.0 jsonargparse==4.35.0 + pip install -e ".[${extra}dev]" -U --upgrade-strategy=eager --extra-index-url="${TORCH_URL}" displayName: "Install package & dependencies" - bash: | @@ -130,18 +145,6 @@ jobs: condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) displayName: "Testing: Fabric doctests" - - bash: | - pip install -q -r .actions/requirements.txt - python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_fabric" \ - --source_import="lightning.fabric" \ - --target_import="lightning_fabric" - python .actions/assistant.py copy_replace_imports --source_dir="./examples/fabric" \ - --source_import="lightning.fabric" \ - --target_import="lightning_fabric" - # without succeeded this could run even if the job has already failed - condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) - displayName: "Adjust tests & examples" - - bash: python -m coverage run --source ${COVERAGE_SOURCE} -m 
pytest tests_fabric/ -v --durations=50 workingDirectory: tests/ displayName: "Testing: fabric standard" diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index eb76cd49e3f94..16ac6beb34841 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -53,20 +53,20 @@ jobs: image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" PACKAGE_NAME: "pytorch" "PyTorch | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3" PACKAGE_NAME: "pytorch" #"PyTorch | future": # image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" # PACKAGE_NAME: "pytorch" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.8-cuda12.6.3" PACKAGE_NAME: "lightning" pool: lit-rtx-3090 variables: DEVICES: $( python -c 'print("$(Agent.Name)".split("_")[-1])' ) FREEZE_REQUIREMENTS: "1" PIP_CACHE_DIR: "/var/tmp/pip" - PL_RUN_CUDA_TESTS: "1" + RUN_ONLY_CUDA_TESTS: "1" container: image: $(image) # default shm size is 64m. Increase it to avoid: @@ -82,8 +82,6 @@ jobs: echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/cu${cuda_ver}/torch_stable.html" scope=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))') echo "##vso[task.setvariable variable=COVERAGE_SOURCE]$scope" - python_ver=$(python -c "import sys; print(f'{sys.version_info.major}{sys.version_info.minor}')") - echo "##vso[task.setvariable variable=PYTHON_VERSION_MM]$python_ver" displayName: "set env. vars" - bash: | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}" @@ -104,6 +102,12 @@ jobs: pip list displayName: "Image info & NVIDIA" + - bash: | + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel # for compatibility + condition: contains(variables['Agent.JobName'], 'oldest') + displayName: "setting oldest dependencies" + - bash: | PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])") pip install -q wget packaging @@ -113,10 +117,22 @@ jobs: done displayName: "Adjust dependencies" + - bash: | + pip install -U -q -r .actions/requirements.txt + python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_pytorch" \ + --source_import="lightning.fabric,lightning.pytorch" \ + --target_import="lightning_fabric,pytorch_lightning" + python .actions/assistant.py copy_replace_imports --source_dir="./examples/pytorch/basics" \ + --source_import="lightning.fabric,lightning.pytorch" \ + --target_import="lightning_fabric,pytorch_lightning" + # without succeeded this could run even if the job has already failed + condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'pytorch')) + displayName: "Adjust tests & examples" + - bash: | set -e extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}" + pip install -e ".[${extra}dev]" -U --upgrade-strategy=eager --extra-index-url="${TORCH_URL}" displayName: "Install package & dependencies" - bash: pip uninstall -y lightning @@ -143,17 +159,6 @@ jobs: condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'pytorch')) displayName: "Testing: PyTorch doctests" - - bash: | - python .actions/assistant.py 
copy_replace_imports --source_dir="./tests/tests_pytorch" \ - --source_import="lightning.fabric,lightning.pytorch" \ - --target_import="lightning_fabric,pytorch_lightning" - python .actions/assistant.py copy_replace_imports --source_dir="./examples/pytorch/basics" \ - --source_import="lightning.fabric,lightning.pytorch" \ - --target_import="lightning_fabric,pytorch_lightning" - # without succeeded this could run even if the job has already failed - condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'pytorch')) - displayName: "Adjust tests & examples" - - bash: | bash .actions/pull_legacy_checkpoints.sh cd tests/legacy diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index cfb03d220c99c..99d354faf9f76 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -3,7 +3,7 @@ Welcome to the PyTorch Lightning community! We're building the most advanced research platform on the planet to implement the latest, best practices and integrations that the amazing PyTorch team and other research organization rolls out! -If you are new to open source, check out [this blog to get started with your first Open Source contribution](https://devblog.pytorchlightning.ai/quick-contribution-guide-86d977171b3a). +If you are new to open source, check out [this blog to get started with your first Open Source contribution](https://medium.com/pytorch-lightning/quick-contribution-guide-86d977171b3a). ## Main Core Value: One less thing to remember @@ -109,6 +109,50 @@ ______________________________________________________________________ ## Guidelines +### Development environment + +To set up a local development environment, we recommend using `uv`, which can be installed following their [instructions](https://docs.astral.sh/uv/getting-started/installation/). + +Once `uv` has been installed, begin by cloning the forked repository: + +```bash +git clone https://github.com/{YOUR_GITHUB_USERNAME}/pytorch-lightning.git +cd pytorch-lightning +``` + +> If you're using [Lightning Studio](https://lightning.ai) or already have your `uv venv` activated, you can quickly set up the project by running: + +```bash +make setup +``` + +This will: + +- Install all required dependencies. +- Perform an editable install of the `pytorch-lightning` project. +- Install and configure `pre-commit`. + +#### Manual Setup (Optional) + +If you prefer more fine-grained control over the dependencies, you can set up the environment manually: + +```bash +uv venv +# uv venv --python 3.11 # use this instead if you need a specific python version + +source .venv/bin/activate # command may differ based on your shell +uv pip install ".[dev, examples]" +``` + +Once the dependencies have been installed, install pre-commit and set up the git hook scripts: + +```bash +uv pip install pre-commit +pre-commit install +``` + +If you would like more information regarding the uv commands, please refer to uv's documentation on their [pip interface](https://docs.astral.sh/uv/pip/).
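To confirm the setup works end to end, you can run a small slice of the test suite and the hooks directly; the paths below are only illustrative examples, so point them at whatever area you are changing:

```bash
# run a focused subset of the Fabric tests as a quick smoke test
python -m pytest tests/tests_fabric/utilities -x -q

# run the pre-commit checks against specific files without making a commit
pre-commit run --files src/lightning/fabric/cli.py
```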
+ ### Developments scripts To build the documentation locally, simply execute the following commands from project root (only for Unix): diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index 5564879229f08..78695257e2884 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -71,7 +71,7 @@ subprojects: # paths: # # tpu CI availability is very limited, so we only require tpu tests # # to pass when their configurations are modified - # - ".github/workflows/tpu-tests.yml" + # - ".github/workflows/tpu-tests.yml.disabled" # - "tests/tests_pytorch/run_tpu_tests.sh" # checks: # - "test-on-tpus (pytorch, pjrt, v4-8)" @@ -135,7 +135,8 @@ subprojects: - "build-pl (3.11, 2.4, 12.1.1)" - "build-pl (3.12, 2.5, 12.1.1)" - "build-pl (3.12, 2.6, 12.4.1)" - - "build-pl (3.12, 2.7, 12.6.3, true)" + - "build-pl (3.12, 2.7, 12.6.3)" + - "build-pl (3.12, 2.8, 12.6.3, true)" # SECTION: lightning_fabric @@ -181,7 +182,7 @@ subprojects: # paths: # # tpu CI availability is very limited, so we only require tpu tests # # to pass when their configurations are modified - # - ".github/workflows/tpu-tests.yml" + # - ".github/workflows/tpu-tests.yml.disabled" # - "tests/tests_fabric/run_tpu_tests.sh" # checks: # - "test-on-tpus (pytorch, pjrt, v4-8)" diff --git a/.github/markdown-links-config.json b/.github/markdown-links-config.json index bc9721da2c587..c8182939c97cf 100644 --- a/.github/markdown-links-config.json +++ b/.github/markdown-links-config.json @@ -22,5 +22,9 @@ "Accept-Encoding": "zstd, br, gzip, deflate" } } - ] + ], + "timeout": "20s", + "retryOn429": true, + "retryCount": 5, + "fallbackRetryDelay": "20s" } diff --git a/.github/workflows/_build-packages.yml b/.github/workflows/_build-packages.yml index cf6ed5379801b..78035470059d1 100644 --- a/.github/workflows/_build-packages.yml +++ b/.github/workflows/_build-packages.yml @@ -51,7 +51,7 @@ jobs: needs: build-packages runs-on: ubuntu-22.04 steps: - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: # download all build artifacts pattern: ${{ inputs.artifact-name }}-* merge-multiple: true diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 6cf29c0bd19ad..03fadc1247f16 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -149,8 +149,7 @@ jobs: title: Adding test for legacy checkpoint created with ${{ env.PL_VERSION }} committer: GitHub author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com> - commit-message: "update tutorials to `${{ env.PL_VERSION }}`" - body: "**This is automated addition of created checkpoints with the latest `lightning` release!**" + body: "**This is automated addition of created checkpoint with the latest `${{ env.PL_VERSION }}` `lightning` release!**" delete-branch: true token: ${{ secrets.PAT_GHOST }} labels: | @@ -158,4 +157,3 @@ jobs: tests pl assignees: borda - reviewers: borda diff --git a/.github/workflows/call-clear-cache.yml b/.github/workflows/call-clear-cache.yml index 6422e856e09ff..06c69fda66794 100644 --- a/.github/workflows/call-clear-cache.yml +++ b/.github/workflows/call-clear-cache.yml @@ -23,18 +23,18 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.3 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.15.2 with: - scripts-ref: v0.14.3 + scripts-ref: v0.15.2 dry-run: ${{ github.event_name == 
'pull_request' }} pattern: "latest|docs" age-days: 7 direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.3 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.15.2 with: - scripts-ref: v0.14.3 + scripts-ref: v0.15.2 dry-run: ${{ github.event_name == 'pull_request' }} pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/ci-pkg-install.yml b/.github/workflows/ci-pkg-install.yml index 6e38c26f4174e..7a7f4bfdcf955 100644 --- a/.github/workflows/ci-pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -50,7 +50,7 @@ jobs: - uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: dist-packages-${{ github.sha }} path: dist diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index fe8cbfbc7ddb4..6594e08158483 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,7 +8,7 @@ on: jobs: check: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.3 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.15.2 with: # skip azure due to the wrong schema file by MSFT # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607 diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index 4e4d2c9eed3cb..c8b6d1e71a910 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -64,13 +64,13 @@ jobs: - { os: "ubuntu-22.04", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } - { os: "windows-2022", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } # "fabric" installs the standalone package - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } + - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.7" } # adding recently cut Torch 2.7 - FUTURE - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } + - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.8" } timeout-minutes: 25 # because of building grpcio on Mac env: PACKAGE_NAME: ${{ matrix.pkg-name }} @@ -138,7 +138,8 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ + -U 
--upgrade-strategy=eager --prefer-binary \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list - name: Dump handy wheels @@ -167,7 +168,7 @@ jobs: run: | echo $GITHUB_RUN_ID python -m coverage run --source ${{ env.COVERAGE_SCOPE }} \ - -m pytest -v --timeout=30 --durations=50 --random-order-seed=$GITHUB_RUN_ID \ + -m pytest -v --timeout=60 --durations=50 --random-order-seed=$GITHUB_RUN_ID \ --junitxml=junit.xml -o junit_family=legacy # NOTE: for Codecov's test results - name: Statistics diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index d295d5475942a..72a966812397e 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -68,13 +68,13 @@ jobs: - { os: "ubuntu-22.04", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } - { os: "windows-2022", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } # "pytorch" installs the standalone package - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } + - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.7" } # adding recently cut Torch 2.7 - FUTURE - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } + - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.8" } timeout-minutes: 50 env: PACKAGE_NAME: ${{ matrix.pkg-name }} @@ -136,7 +136,8 @@ jobs: - name: Install package & dependencies timeout-minutes: 20 run: | - pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ + pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" \ + -U --upgrade-strategy=eager --prefer-binary \ -r requirements/_integrations/accelerators.txt \ --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list @@ -196,7 +197,7 @@ jobs: run: | echo $GITHUB_RUN_ID python -m coverage run --source ${{ env.COVERAGE_SCOPE }} \ - -m pytest . -v --timeout=60 --durations=50 --random-order-seed=$GITHUB_RUN_ID \ + -m pytest . 
-v --timeout=90 --durations=50 --random-order-seed=$GITHUB_RUN_ID \ --junitxml=junit.xml -o junit_family=legacy # NOTE: for Codecov's test results - name: Statistics diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index bb98466f30f72..93ff401f60f5f 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -49,7 +49,8 @@ jobs: - { python_version: "3.11", pytorch_version: "2.4", cuda_version: "12.1.1" } - { python_version: "3.12", pytorch_version: "2.5", cuda_version: "12.1.1" } - { python_version: "3.12", pytorch_version: "2.6", cuda_version: "12.4.1" } - - { python_version: "3.12", pytorch_version: "2.7", cuda_version: "12.6.3", latest: "true" } + - { python_version: "3.12", pytorch_version: "2.7", cuda_version: "12.6.3" } + - { python_version: "3.12", pytorch_version: "2.8", cuda_version: "12.6.3", latest: "true" } steps: - uses: actions/checkout@v4 with: @@ -97,7 +98,7 @@ jobs: # adding some more images as Thunder mainly using python 3.10, # and we need to support integrations as for example LitGPT python_version: ["3.10"] - pytorch_version: ["2.6.0", "2.7.1"] + pytorch_version: ["2.7.1", "2.8.0"] cuda_version: ["12.6.3"] include: # These are the base images for PL release docker images. @@ -109,6 +110,7 @@ jobs: - { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.1" } - { python_version: "3.12", pytorch_version: "2.6.0", cuda_version: "12.4.1" } - { python_version: "3.12", pytorch_version: "2.7.1", cuda_version: "12.6.3" } + - { python_version: "3.12", pytorch_version: "2.8.0", cuda_version: "12.6.3" } steps: - uses: actions/checkout@v4 - uses: docker/setup-buildx-action@v3 @@ -129,9 +131,10 @@ jobs: PYTHON_VERSION=${{ matrix.python_version }} PYTORCH_VERSION=${{ matrix.pytorch_version }} CUDA_VERSION=${{ matrix.cuda_version }} + MAKE_FLAGS="-j2" file: dockers/base-cuda/Dockerfile push: ${{ env.PUSH_NIGHTLY }} - tags: "pytorchlightning/pytorch_lightning:base-cuda-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}-cuda${{ matrix.cuda_version }}" + tags: "pytorchlightning/pytorch_lightning:base-cuda${{ matrix.cuda_version }}-py${{ matrix.python_version }}-torch${{ env.PT_VERSION }}" timeout-minutes: 95 - uses: ravsamhq/notify-slack-action@v2 if: failure() && env.PUSH_NIGHTLY == 'true' @@ -157,6 +160,8 @@ jobs: continue-on-error: true uses: docker/build-push-action@v6 with: + build-args: | + PYTORCH_VERSION="25.04" file: dockers/nvidia/Dockerfile push: false timeout-minutes: 55 diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index ec23d282710b5..8f6deeb189773 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -159,7 +159,7 @@ jobs: # use input if dispatch or git tag VERSION: ${{ inputs.version || github.ref_name }} steps: - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: docs-${{ matrix.pkg-name }}-${{ github.sha }} path: docs/build/html/ diff --git a/.github/workflows/docs-tutorials.yml b/.github/workflows/docs-tutorials.yml index 6a768c6b9ffd9..e6e1f755484fd 100644 --- a/.github/workflows/docs-tutorials.yml +++ b/.github/workflows/docs-tutorials.yml @@ -55,7 +55,6 @@ jobs: author: ${{ github.actor }} <${{ github.actor }}@users.noreply.github.com> token: ${{ secrets.PAT_GHOST }} add-paths: _notebooks - commit-message: "update tutorials to `${{ env.SHA_LATEST }}`" branch: "docs/update-tutorials" # Delete the branch when closing pull requests, and when undeleted after merging. 
delete-branch: true @@ -72,4 +71,3 @@ docs examples assignees: borda - reviewers: borda diff --git a/.github/workflows/probot-check-group.yml b/.github/workflows/probot-check-group.yml index eef41cfcf777b..2a4fbcf0c87f7 100644 --- a/.github/workflows/probot-check-group.yml +++ b/.github/workflows/probot-check-group.yml @@ -14,7 +14,7 @@ jobs: if: github.event.pull_request.draft == false timeout-minutes: 61 # in case something is wrong with the internal timeout steps: - - uses: Lightning-AI/probot@v5.4 + - uses: Lightning-AI/probot@v5.5 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index 396e485b90065..24d1a07f9abbc 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -54,7 +54,7 @@ jobs: PKG_NAME: "lightning" steps: - uses: actions/checkout@v4 # needed to use local composite action - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: nightly-packages-${{ github.sha }} path: dist diff --git a/.github/workflows/release-pkg.yml b/.github/workflows/release-pkg.yml index 348d4ce753117..fa2a499f4abe2 100644 --- a/.github/workflows/release-pkg.yml +++ b/.github/workflows/release-pkg.yml @@ -38,7 +38,7 @@ jobs: if: github.event_name == 'release' steps: - uses: actions/checkout@v4 - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: dist-packages-${{ github.sha }} path: dist @@ -140,7 +140,7 @@ jobs: name: ["FABRIC", "PYTORCH", "LIGHTNING"] steps: - uses: actions/checkout@v4 # needed for local action below - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: dist-packages-${{ github.sha }} path: dist @@ -165,7 +165,7 @@ jobs: name: ["FABRIC", "PYTORCH", "LIGHTNING"] steps: - uses: actions/checkout@v4 # needed for local action below - - uses: actions/download-artifact@v4 + - uses: actions/download-artifact@v5 with: name: dist-packages-${{ github.sha }} path: dist diff --git a/.github/workflows/tpu-tests.yml b/.github/workflows/tpu-tests.yml.disabled similarity index 99% rename from .github/workflows/tpu-tests.yml rename to .github/workflows/tpu-tests.yml.disabled index 8d9d277cac26b..50d269dfb9f5f 100644 --- a/.github/workflows/tpu-tests.yml +++ b/.github/workflows/tpu-tests.yml.disabled @@ -165,7 +165,7 @@ jobs: gcloud compute tpus tpu-vm list - name: Upload coverage to Codecov - uses: codecov/codecov-action@v4 + uses: codecov/codecov-action@v5 continue-on-error: true with: token: ${{ secrets.CODECOV_TOKEN }} diff --git a/.lightning/workflows/fabric.yml b/.lightning/workflows/fabric.yml new file mode 100644 index 0000000000000..edaf0837fe79e --- /dev/null +++ b/.lightning/workflows/fabric.yml @@ -0,0 +1,110 @@ +trigger: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +timeout: "75" # minutes +machine: "L4_X_2" +parametrize: + matrix: {} + include: + # note that this also sets all the oldest requirements, which are tied to Torch == 2.1 + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" + PACKAGE_NAME: "fabric" + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + PACKAGE_NAME: "fabric" + # - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + # PACKAGE_NAME: "fabric" + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + PACKAGE_NAME: "lightning" + exclude: [] + +env: + FREEZE_REQUIREMENTS:
"1" + RUN_ONLY_CUDA_TESTS: "1" + +run: | + whereis nvidia + nvidia-smi + python --version + pip --version + pip install -q fire wget packaging + set -ex + + CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda" + echo "Using CUDA version: ${CUDA_VERSION}" + CUDA_VERSION_M_M="${cuda_version%.*}" # Get major.minor by removing the last dot and everything after + CUDA_VERSION_MM="${CUDA_VERSION_M_M//'.'/''}" + TORCH_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html" + echo "Torch URL: ${TORCH_URL}" + COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(fabric="lightning_fabric").get(n, n))') + echo "collecting coverage for: ${COVERAGE_SOURCE}" + + if [ "${TORCH_VER}" == "2.1" ]; then + echo "Set oldest versions" + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel # for compatibility + fi + + echo "Adjust torch versions in requirements files" + PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])") + pip install -q wget packaging + python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py + for fpath in `ls requirements/**/*.txt`; do \ + python ./adjust-torch-versions.py $fpath ${PYTORCH_VERSION}; \ + done + + if [ "${PACKAGE_NAME}" == "fabric" ]; then + echo "Replaced PL imports" + pip install -U -q -r .actions/requirements.txt + python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_fabric" \ + --source_import="lightning.fabric" \ + --target_import="lightning_fabric" + python .actions/assistant.py copy_replace_imports --source_dir="./examples/fabric" \ + --source_import="lightning.fabric" \ + --target_import="lightning_fabric" + fi + + extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))") + pip install -e ".[${extra}dev]" -U --upgrade-strategy=eager --extra-index-url="${TORCH_URL}" + + python requirements/collect_env_details.py + python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu >= 2, f'GPU: {mgpu}'" + python requirements/pytorch/check-avail-extras.py + python -c "import bitsandbytes" + + echo "Testing: Fabric doctests" + if [ "${PACKAGE_NAME}" == "fabric" ]; then + cd src/ + python -m pytest lightning_fabric + cd .. + fi + + cd tests/ + echo "Testing: fabric standard" + python -m coverage run --source ${COVERAGE_SOURCE} -m pytest tests_fabric/ -v --durations=50 + + echo "Testing: fabric standalone" + export PL_RUN_STANDALONE_TESTS=1 + wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh + bash ./run_standalone_tests.sh "tests_fabric" + + # echo "Reporting coverage" # todo + # python -m coverage report + # python -m coverage xml + # python -m coverage html + + # TODO: enable coverage + # # https://docs.codecov.com/docs/codecov-uploader + # curl -Os https://uploader.codecov.io/latest/linux/codecov + # chmod +x codecov + # ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + # --flags=gpu,pytest,${COVERAGE_SOURCE} --name="GPU-coverage" --env=linux,azure + # ls -l + cd .. 
+ + echo "Testing: fabric examples" + cd examples/ + bash run_fabric_examples.sh --accelerator=cuda --devices=1 + bash run_fabric_examples.sh --accelerator=cuda --devices=2 --strategy ddp diff --git a/.lightning/workflows/pytorch.yml b/.lightning/workflows/pytorch.yml new file mode 100644 index 0000000000000..81063c3699769 --- /dev/null +++ b/.lightning/workflows/pytorch.yml @@ -0,0 +1,131 @@ +trigger: + push: + branches: ["master"] + pull_request: + branches: ["master"] + +timeout: "75" # minutes +machine: "L4_X_2" +parametrize: + matrix: {} + include: + # note that this also sets all the oldest requirements, which are tied to Torch == 2.1 + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" + PACKAGE_NAME: "pytorch" + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + PACKAGE_NAME: "pytorch" + # - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + # PACKAGE_NAME: "pytorch" + - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.7-cuda12.6.3" + PACKAGE_NAME: "lightning" + exclude: [] + +env: + FREEZE_REQUIREMENTS: "1" + RUN_ONLY_CUDA_TESTS: "1" + +run: | + whereis nvidia + nvidia-smi + python --version + pip --version + pip install -q fire wget packaging + set -ex + + CUDA_VERSION="${image##*cuda}" # Remove everything up to and including "cuda" + echo "Using CUDA version: ${CUDA_VERSION}" + CUDA_VERSION_M_M="${CUDA_VERSION%.*}" # Get major.minor by removing the last dot and everything after + CUDA_VERSION_MM="${CUDA_VERSION_M_M//'.'/''}" + TORCH_URL="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html" + echo "Torch URL: ${TORCH_URL}" + COVERAGE_SOURCE=$(python -c 'n = "$(PACKAGE_NAME)" ; print(dict(pytorch="pytorch_lightning").get(n, n))') + echo "collecting coverage for: ${COVERAGE_SOURCE}" + + if [ "${TORCH_VER}" == "2.1" ]; then + echo "Set oldest versions" + python .actions/assistant.py replace_oldest_ver + pip install "cython<3.0" wheel # for compatibility + fi + + echo "Adjust torch versions in requirements files" + PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])") + pip install -q wget packaging + python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/adjust-torch-versions.py + for fpath in `ls requirements/**/*.txt`; do \ + python ./adjust-torch-versions.py $fpath ${PYTORCH_VERSION}; \ + done + + if [ "${PACKAGE_NAME}" == "pytorch" ]; then + echo "Adjust PL imports" + pip install -U -q -r .actions/requirements.txt + python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_pytorch" \ + --source_import="lightning.fabric,lightning.pytorch" \ + --target_import="lightning_fabric,pytorch_lightning" + python .actions/assistant.py copy_replace_imports --source_dir="./examples/pytorch/basics" \ + --source_import="lightning.fabric,lightning.pytorch" \ + --target_import="lightning_fabric,pytorch_lightning" + fi + + extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") + pip install -e ".[${extra}dev]" -U --upgrade-strategy=eager --extra-index-url="${TORCH_URL}" + + if [ "${PACKAGE_NAME}" == "pytorch" ]; then + echo "uninstall lightning to have just single package" + pip uninstall -y lightning + elif [ "${PACKAGE_NAME}" == "lightning" ]; then + echo "uninstall PL to have just single package" + pip uninstall -y pytorch-lightning + fi + + python requirements/collect_env_details.py + python -c "import torch ; mgpu = torch.cuda.device_count() ; assert
mgpu >= 2, f'GPU: {mgpu}'" + python requirements/pytorch/check-avail-extras.py + python -c "import bitsandbytes" + + echo "Testing: PyTorch doctests" + if [ "${PACKAGE_NAME}" == "pytorch" ]; then + cd src/ + python -m pytest pytorch_lightning + cd .. + fi + + echo "Get legacy checkpoints" + bash .actions/pull_legacy_checkpoints.sh + cd tests/legacy + # bash generate_checkpoints.sh + ls -lh checkpoints/ + cd ../.. + + cd tests/ + echo "Testing: pytorch standard" + python -m coverage run --source ${COVERAGE_SOURCE} -m pytest tests_pytorch/ -v --durations=50 + + echo "Testing: pytorch standalone" + export PL_USE_MOCKED_MNIST=1 + export PL_RUN_STANDALONE_TESTS=1 + wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh + bash ./run_standalone_tests.sh "tests_pytorch" + + echo "Testing: PyTorch standalone tasks" + cd tests_pytorch/ + bash run_standalone_tasks.sh + + # echo "Reporting coverage" # todo + # python -m coverage report + # python -m coverage xml + # python -m coverage html + + # TODO: enable coverage + # # https://docs.codecov.com/docs/codecov-uploader + # curl -Os https://uploader.codecov.io/latest/linux/codecov + # chmod +x codecov + # ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ + # --flags=gpu,pytest,${COVERAGE_SOURCE} --name="GPU-coverage" --env=linux,azure + # ls -l + cd ../.. + + echo "Testing: PyTorch examples" + cd examples/ + bash run_pl_examples.sh --trainer.accelerator=gpu --trainer.devices=1 + bash run_pl_examples.sh --trainer.accelerator=gpu --trainer.devices=2 --trainer.strategy=ddp + bash run_pl_examples.sh --trainer.accelerator=gpu --trainer.devices=2 --trainer.strategy=ddp --trainer.precision=16 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4723638fc5e4a..9b9057d794ce1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -58,7 +58,7 @@ repos: #args: ["--write-changes"] # uncomment if you want to get automatic fixing - repo: https://github.com/PyCQA/docformatter - rev: 06907d0267368b49b9180eed423fae5697c1e909 # todo: fix for docformatter after last 1.7.5 + rev: v1.7.7 hooks: - id: docformatter additional_dependencies: [tomli] @@ -70,7 +70,7 @@ repos: - id: sphinx-lint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.4 + rev: v0.12.2 hooks: # try to fix what is possible - id: ruff diff --git a/Makefile b/Makefile index 426c18042994c..4b077a65a7a02 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: test clean docs +.PHONY: test clean docs setup # to imitate SLURM set only single node export SLURM_LOCALID=0 @@ -7,6 +7,23 @@ export SPHINX_MOCK_REQUIREMENTS=1 # install only Lightning Trainer packages export PACKAGE_NAME=pytorch +setup: + uv pip install -r requirements.txt \ + -r requirements/pytorch/base.txt \ + -r requirements/pytorch/test.txt \ + -r requirements/pytorch/extra.txt \ + -r requirements/pytorch/strategies.txt \ + -r requirements/fabric/base.txt \ + -r requirements/fabric/test.txt \ + -r requirements/fabric/strategies.txt \ + -r requirements/typing.txt \ + -e ".[all]" \ + pre-commit + pre-commit install + @echo "-----------------------------" + @echo "✅ Environment setup complete. Ready to Contribute ⚡️!" + + clean: # clean all temp runs rm -rf $(shell find . -name "mlruns") diff --git a/README.md b/README.md index dd5f0fe43e0c7..546458bf80bc9 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,12 @@ ______________________________________________________________________ &nbsp; +# Why PyTorch Lightning?
+ +Training models in plain PyTorch is tedious and error-prone: you have to manually handle things like backprop, mixed precision, multi-GPU, and distributed training, often rewriting code for every new project. PyTorch Lightning organizes PyTorch code to automate those complexities so you can focus on your model and data, while scaling from CPU to multi-node without changing your core code. And if you do want to manage those details yourself, you can still opt into a more DIY approach. + +Fun analogy: if PyTorch is JavaScript, PyTorch Lightning is ReactJS or NextJS. + # Lightning has 2 core packages [PyTorch Lightning: Train and deploy PyTorch at scale](#why-pytorch-lightning). diff --git a/_notebooks b/_notebooks index fd70f5114b21f..69112e6fe73b5 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit fd70f5114b21f7f970bd5587b1d3def689507069 +Subproject commit 69112e6fe73b50d159c4b7add6ddea412b458691 diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 2fe1e57e95a77..2b6f48771c7f7 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -19,8 +19,9 @@ ARG CUDA_VERSION=11.7.1 FROM nvidia/cuda:${CUDA_VERSION}-runtime-ubuntu${UBUNTU_VERSION} ARG PYTHON_VERSION=3.10 -ARG PYTORCH_VERSION=2.1 +ARG PYTORCH_VERSION=2.8 ARG MAX_ALLOWED_NCCL=2.22.3 +ARG MAKE_FLAGS="-j$(nproc)" SHELL ["/bin/bash", "-c"] # https://techoverflow.net/2019/05/18/how-to-fix-configuring-tzdata-interactive-input-when-building-docker-images/ @@ -30,8 +31,7 @@ ENV \ PATH="$PATH:/root/.local/bin" \ CUDA_TOOLKIT_ROOT_DIR="/usr/local/cuda" \ MKL_THREADING_LAYER="GNU" \ - # MAKEFLAGS="-j$(nproc)" - MAKEFLAGS="-j2" + MAKEFLAGS=${MAKE_FLAGS} RUN \ CUDA_VERSION_MM=${CUDA_VERSION%.*} && \ diff --git a/dockers/nvidia/Dockerfile b/dockers/nvidia/Dockerfile index 511913542482b..ef329fd56433b 100644 --- a/dockers/nvidia/Dockerfile +++ b/dockers/nvidia/Dockerfile @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -ARG PYTORCH_VERSION=22.09 +ARG PYTORCH_VERSION=24.05 # https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes FROM nvcr.io/nvidia/pytorch:${PYTORCH_VERSION}-py3 diff --git a/dockers/release/Dockerfile b/dockers/release/Dockerfile index d791d8875e8bc..34ff8099afdb1 100644 --- a/dockers/release/Dockerfile +++ b/dockers/release/Dockerfile @@ -13,8 +13,8 @@ # limitations under the License. ARG PYTHON_VERSION=3.10 -ARG PYTORCH_VERSION=2.0 -ARG CUDA_VERSION=11.8.0 +ARG PYTORCH_VERSION=2.8 +ARG CUDA_VERSION=12.6.3 FROM pytorchlightning/pytorch_lightning:base-cuda-py${PYTHON_VERSION}-torch${PYTORCH_VERSION}-cuda${CUDA_VERSION} diff --git a/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst b/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst index 454ebdacbb9d9..ed1c0c90bbafd 100644 --- a/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst +++ b/docs/source-fabric/advanced/model_parallel/tp_fsdp.rst @@ -9,7 +9,7 @@ The :doc:`Tensor Parallelism documentation ` and a general understanding of .. raw:: html - + Open In Studio diff --git a/docs/source-pytorch/advanced/model_parallel/tp.rst b/docs/source-pytorch/advanced/model_parallel/tp.rst index e857f1f974828..6d3c8b5be9664 100644 --- a/docs/source-pytorch/advanced/model_parallel/tp.rst +++ b/docs/source-pytorch/advanced/model_parallel/tp.rst @@ -8,7 +8,7 @@ This method is most effective for models with very large layers, significantly e ..
raw:: html - + Open In Studio diff --git a/docs/source-pytorch/advanced/speed.rst b/docs/source-pytorch/advanced/speed.rst index 69843bd6aa8a5..53f2938ab099e 100644 --- a/docs/source-pytorch/advanced/speed.rst +++ b/docs/source-pytorch/advanced/speed.rst @@ -464,7 +464,7 @@ takes a great deal of care to be optimized for this. Clear Cache =========== -Don't call :func:`torch.cuda.empty_cache` unnecessarily! Every time you call this, ALL your GPUs have to wait to sync. +Don't call ``torch.cuda.empty_cache`` unnecessarily! Every time you call this, ALL your GPUs have to wait to sync. Transferring Tensors to Device ============================== diff --git a/docs/source-pytorch/advanced/transfer_learning.rst b/docs/source-pytorch/advanced/transfer_learning.rst index 7f6af6ad5a56d..50a65870b1572 100644 --- a/docs/source-pytorch/advanced/transfer_learning.rst +++ b/docs/source-pytorch/advanced/transfer_learning.rst @@ -32,7 +32,7 @@ Let's use the `AutoEncoder` as a feature extractor in a separate model. class CIFAR10Classifier(LightningModule): def __init__(self): # init the pretrained LightningModule - self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH) + self.feature_extractor = AutoEncoder.load_from_checkpoint(PATH).encoder self.feature_extractor.freeze() # the autoencoder outputs a 100-dim representation and CIFAR-10 has 10 classes diff --git a/docs/source-pytorch/common/checkpointing_basic.rst b/docs/source-pytorch/common/checkpointing_basic.rst index 1026e972849ef..9966c360a95e8 100644 --- a/docs/source-pytorch/common/checkpointing_basic.rst +++ b/docs/source-pytorch/common/checkpointing_basic.rst @@ -111,7 +111,7 @@ The LightningModule also has access to the Hyperparameters .. code-block:: python model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt") - print(model.learning_rate) + print(model.hparams.learning_rate) ---- diff --git a/docs/source-pytorch/common/precision_intermediate.rst b/docs/source-pytorch/common/precision_intermediate.rst index eff5805497c2d..70f86b1c09bb2 100644 --- a/docs/source-pytorch/common/precision_intermediate.rst +++ b/docs/source-pytorch/common/precision_intermediate.rst @@ -165,7 +165,7 @@ Under the hood, we use `transformer_engine.pytorch.fp8_autocast `__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights. +`bitsandbytes `__ (BNB) is a library that supports quantizing :class:`torch.nn.Linear` weights. Both 4-bit (`paper reference `__) and 8-bit (`paper reference `__) quantization is supported. Specifically, we support the following modes: diff --git a/docs/source-pytorch/common/trainer.rst b/docs/source-pytorch/common/trainer.rst index c40ad1fcf92e2..86ee52f41f0c9 100644 --- a/docs/source-pytorch/common/trainer.rst +++ b/docs/source-pytorch/common/trainer.rst @@ -759,6 +759,9 @@ overfit_batches Uses this much data of the training & validation set. If the training & validation dataloaders have ``shuffle=True``, Lightning will automatically disable it. +* When set to a value > 0, sequential sampling (no shuffling) is used +* Consistent batches are used for both training and validation across epochs, but training and validation use different sets of data + Useful for quickly debugging or trying to overfit on purpose. .. testcode:: @@ -769,11 +772,11 @@ Useful for quickly debugging or trying to overfit on purpose. 
# use only 1% of the train & val set trainer = Trainer(overfit_batches=0.01) - # overfit on 10 of the same batches + # overfit on 10 consistent train batches & 10 consistent val batches trainer = Trainer(overfit_batches=10) + # debug using a single consistent train batch and a single consistent val batch + trainer = Trainer(overfit_batches=1) plugins ^^^^^^^ :ref:`Plugins` allow you to connect arbitrary backends, precision libraries, clusters etc. For example: @@ -895,7 +898,7 @@ DataSource can be a ``LightningModule`` or a ``LightningDataModule``. # if 0 (default) train_loader = model.train_dataloader() - # or if using data module: datamodule.train_dataloader() + # or if using data module: datamodule.train_dataloaders() for epoch in epochs: for batch in train_loader: ... diff --git a/docs/source-pytorch/glossary/index.rst b/docs/source-pytorch/glossary/index.rst index 6b5e4b12b307f..45683c67c1708 100644 --- a/docs/source-pytorch/glossary/index.rst +++ b/docs/source-pytorch/glossary/index.rst @@ -209,7 +209,7 @@ Glossary .. displayitem:: :header: LightningModule - :description: A base class organizug your neural network module + :description: A base class organizing your neural network module :col_css: col-md-12 :button_link: ../common/lightning_module.html :height: 100 diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst index 10c6ec2fdf8e5..948986d5699ed 100644 --- a/docs/source-pytorch/versioning.rst +++ b/docs/source-pytorch/versioning.rst @@ -79,10 +79,16 @@ The table below indicates the coverage of tested versions in our CI. Versions ou - ``torch`` - ``torchmetrics`` - Python + * - 2.5 + - 2.5 + - 2.5 + - ≥2.1, ≤2.7 + - ≥0.7.0 + - ≥3.9, ≤3.12 * - 2.4 - 2.4 - 2.4 - - ≥2.1, ≤2.4 + - ≥2.1, ≤2.6 - ≥0.7.0 - ≥3.9, ≤3.12 * - 2.3 diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index d207595e9d2ba..955e66fad19b7 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -158,7 +158,7 @@ def run(hparams): # When using distributed training, use `fabric.save` # to ensure the current process is allowed to save a checkpoint if hparams.save_model: - fabric.save(model.state_dict(), "mnist_cnn.pt") + fabric.save(path="mnist_cnn.pt", state=model.state_dict()) if __name__ == "__main__": diff --git a/examples/fabric/kfold_cv/train_fabric.py b/examples/fabric/kfold_cv/train_fabric.py index 4d0bf5f6048da..79b2bd6eb9f8c 100644 --- a/examples/fabric/kfold_cv/train_fabric.py +++ b/examples/fabric/kfold_cv/train_fabric.py @@ -161,7 +161,7 @@ def run(hparams): # When using distributed training, use `fabric.save` # to ensure the current process is allowed to save a checkpoint if hparams.save_model: - fabric.save(model.state_dict(), "mnist_cnn.pt") + fabric.save(path="mnist_cnn.pt", state=model.state_dict()) if __name__ == "__main__": diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 35ee9074f18a8..2543702cc3450 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -67,7 +67,7 @@ def train(): # See `fabric consolidate --help` if you need to convert the checkpoint to a single file fabric.print("Saving a (distributed) checkpoint ...") state = {"model": model, "optimizer": optimizer, "iteration": i} - fabric.save("checkpoint.pt", state) + fabric.save(path="checkpoint.pt", state=state) fabric.print("Training successfully completed!") fabric.print(f"Peak memory usage:
{torch.cuda.max_memory_allocated() / 1e9:.02f} GB") diff --git a/pyproject.toml b/pyproject.toml index b45f60489c6fe..a63da5f246392 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,6 +180,7 @@ markers = [ ] filterwarnings = [ "error::FutureWarning", + "ignore::FutureWarning:onnxscript", # Temporary ignore until onnxscript is updated ] xfail_strict = true junit_duration_report = "call" diff --git a/requirements/ci.txt b/requirements/ci.txt index 6b879f4f3fbb1..02c0b0e9c105e 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -1,6 +1,6 @@ setuptools <80.9.1 wheel <0.46.0 -awscli >=1.30.0, <1.41.0 +awscli >=1.30.0, <1.43.0 twine ==6.1.0 importlib-metadata <9.0.0 wget diff --git a/requirements/docs.txt b/requirements/docs.txt index 1acd55018df8c..9fa72085df6b5 100644 --- a/requirements/docs.txt +++ b/requirements/docs.txt @@ -3,7 +3,7 @@ myst-parser >=0.18.1, <4.0.0 nbsphinx >=0.8.5, <=0.9.7 nbconvert >7.14, <7.17 pandoc >=1.0, <=2.4 -docutils >=0.16, <0.22 +docutils>=0.18.1,<=0.22 sphinxcontrib-fulltoc >=1.0, <=1.2.0 sphinxcontrib-mockautodoc sphinx-autobuild @@ -17,7 +17,7 @@ sphinx-rtd-dark-mode sphinxcontrib-video ==0.4.1 jinja2 <3.2.0 -lightning-utilities >=0.11.1, <0.15.0 +lightning-utilities >=0.11.1, <0.16.0 # installed from S3 location and fetched in advance lai-sphinx-theme diff --git a/requirements/doctests.txt b/requirements/doctests.txt index 919ee2d3a2672..91b1217b65584 100644 --- a/requirements/doctests.txt +++ b/requirements/doctests.txt @@ -1,2 +1,2 @@ -pytest ==8.4.0 +pytest ==8.4.1 pytest-doctestplus ==1.4.0 diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index cb70b24ae26c4..b740cf2d7cdd6 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -1,8 +1,8 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.8.0 -fsspec[http] >=2022.5.0, <2025.6.0 +torch >=2.1.0, <2.9.0 +fsspec[http] >=2022.5.0, <2025.8.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.15.0 -lightning-utilities >=0.10.0, <0.15.0 +typing-extensions >4.5.0, <4.15.0 +lightning-utilities >=0.10.0, <0.16.0 diff --git a/requirements/fabric/examples.txt b/requirements/fabric/examples.txt index 72eabb238f3bb..ab6ffb8b137df 100644 --- a/requirements/fabric/examples.txt +++ b/requirements/fabric/examples.txt @@ -1,6 +1,5 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torchvision >=0.16.0, <0.23.0 -torchmetrics >=0.10.0, <1.8.0 -lightning-utilities >=0.8.0, <0.15.0 +torchvision >=0.16.0, <0.24.0 +torchmetrics >=0.10.0, <1.9.0 diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt index 5be2eed05284c..bea30b37fa5f8 100644 --- a/requirements/fabric/strategies.txt +++ b/requirements/fabric/strategies.txt @@ -5,5 +5,5 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.9.3, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict bitsandbytes >=0.45.2,<0.47.0; 
platform_system != "Darwin" diff --git a/requirements/fabric/test.txt b/requirements/fabric/test.txt index 7a58fac8189b6..d8884253eab80 100644 --- a/requirements/fabric/test.txt +++ b/requirements/fabric/test.txt @@ -1,9 +1,9 @@ -coverage ==7.9.1 +coverage ==7.10.3 numpy >=1.17.2, <1.27.0 -pytest ==8.4.0 +pytest ==8.4.1 pytest-cov ==6.2.1 pytest-timeout ==2.4.0 pytest-rerunfailures ==15.1 -pytest-random-order ==1.1.1 +pytest-random-order ==1.2.0 click ==8.1.8 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index e77ecc8a1baeb..ef798883c12ef 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -1,11 +1,11 @@ # NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment -torch >=2.1.0, <2.8.0 +torch >=2.1.0, <2.9.0 tqdm >=4.57.0, <4.68.0 -PyYAML >=5.4, <6.1.0 -fsspec[http] >=2022.5.0, <2025.6.0 -torchmetrics >=0.7.0, <1.8.0 +PyYAML >5.4, <6.1.0 +fsspec[http] >=2022.5.0, <2025.8.0 +torchmetrics >0.7.0, <1.9.0 packaging >=20.0, <=25.0 -typing-extensions >=4.4.0, <4.15.0 -lightning-utilities >=0.10.0, <0.15.0 +typing-extensions >4.5.0, <4.15.0 +lightning-utilities >=0.10.0, <0.16.0 diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt index 1f4e0cb8031c4..35cc6234ae5d2 100644 --- a/requirements/pytorch/docs.txt +++ b/requirements/pytorch/docs.txt @@ -4,4 +4,6 @@ nbformat # used for generate empty notebook ipython[notebook] <8.19.0 setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is invalid.` --r ../../_notebooks/.actions/requires.txt +onnxscript >= 0.2.2, <0.4.0 + +#-r ../../_notebooks/.actions/requires.txt diff --git a/requirements/pytorch/examples.txt b/requirements/pytorch/examples.txt index d9ad8150693b9..84ea80df6ff0c 100644 --- a/requirements/pytorch/examples.txt +++ b/requirements/pytorch/examples.txt @@ -2,7 +2,6 @@ # in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment requests <2.33.0 -torchvision >=0.16.0, <0.23.0 +torchvision >=0.16.0, <0.24.0 ipython[all] <8.19.0 -torchmetrics >=0.10.0, <1.8.0 -lightning-utilities >=0.8.0, <0.15.0 +torchmetrics >=0.10.0, <1.9.0 diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 0f11b19c23431..ab3a36f7dad3b 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -6,6 +6,6 @@ matplotlib>3.1, <3.10.0 omegaconf >=2.2.3, <2.4.0 hydra-core >=1.2.0, <1.4.0 jsonargparse[signatures,jsonnet] >=4.39.0, <4.41.0 -rich >=12.3.0, <14.1.0 +rich >=12.3.0, <14.2.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute bitsandbytes >=0.45.2,<0.47.0; platform_system != "Darwin" diff --git a/requirements/pytorch/strategies.txt b/requirements/pytorch/strategies.txt index 8d3af408a98fe..1f7296798b551 100644 --- a/requirements/pytorch/strategies.txt +++ b/requirements/pytorch/strategies.txt @@ -3,4 +3,4 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 -deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict +deepspeed >=0.9.3, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict 
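One convention worth spelling out, since it is repeated across the requirement files above: the upper bounds exist only for CI stability and are dropped at install time, unless a line carries the `# strict` marker (as `deepspeed` does here). A rough illustration of the idea with plain `sed` (the real logic lives in `.actions/assistant.py`; this one-liner is a simplification, not the actual implementation):

```bash
# strip "<" / "<=" caps from every requirement line not marked "# strict",
# e.g. "torch >=2.1.0, <2.9.0"  ->  "torch >=2.1.0"
sed -E '/# *strict/! s/,? *<=? *[0-9][^,;# ]*//g' requirements/fabric/base.txt
```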
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 865109c87b140..a1fbdec222c7f 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -1,9 +1,9 @@ -coverage ==7.9.1 -pytest ==8.4.0 +coverage ==7.10.3 +pytest ==8.4.1 pytest-cov ==6.2.1 pytest-timeout ==2.4.0 pytest-rerunfailures ==15.1 -pytest-random-order ==1.1.1 +pytest-random-order ==1.2.0 # needed in tests cloudpickle >=1.3, <3.2.0 @@ -11,9 +11,10 @@ scikit-learn >0.22.1, <1.7.0 numpy >=1.17.2, <1.27.0 onnx >=1.12.0, <1.19.0 onnxruntime >=1.12.0, <1.21.0 +onnxscript >= 0.2.2, <0.4.0 psutil <7.0.1 # for `DeviceStatsMonitor` pandas >2.0, <2.4.0 # needed in benchmarks fastapi # for `ServableModuleValidator` # not setting version as re-defined in App uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App -tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` +tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` diff --git a/requirements/typing.txt b/requirements/typing.txt index 940534fc729bb..7e0c34e2ac3fa 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ -mypy==1.16.0 -torch==2.7.1 +mypy==1.17.1 +torch==2.7.1 # todo: update typing in separate PR types-Markdown types-PyYAML diff --git a/src/lightning/__about__.py b/src/lightning/__about__.py index 42be5087642a5..a35adb9cbb5cc 100644 --- a/src/lightning/__about__.py +++ b/src/lightning/__about__.py @@ -15,7 +15,7 @@ import time __author__ = "Lightning AI et al." -__author_email__ = "pytorch@lightning.ai" +__author_email__ = "developer@lightning.ai" __license__ = "Apache-2.0" __copyright__ = f"Copyright (c) 2018-{time.strftime('%Y')}, {__author__}." __homepage__ = "https://github.com/Lightning-AI/lightning" diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 102bea4d1f4c4..0659105560bba 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -6,6 +6,19 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
--- +## [2.5.3] - 2025-08-DD + +### Changed + +- Enable "auto" for `devices` and `accelerator` as CLI arguments ([#20913](https://github.com/Lightning-AI/pytorch-lightning/pull/20913)) +- Raise a `ValueError` when the seed is out of bounds or cannot be cast to an int ([#21029](https://github.com/Lightning-AI/pytorch-lightning/pull/21029)) + +### Fixed + +- Fixed the XLA environment to use `global_ordinal`, `local_ordinal`, and `world_size`, which replace the deprecated methods ([#20852](https://github.com/Lightning-AI/pytorch-lightning/issues/20852)) +- Fixed the accelerator registry decorator by removing its extra `name` parameter ([#20975](https://github.com/Lightning-AI/pytorch-lightning/pull/20975)) + + ## [2.5.2] - 2025-06-20 ### Changed diff --git a/src/lightning/fabric/accelerators/registry.py b/src/lightning/fabric/accelerators/registry.py index 4959a0fb9426a..539b7aa8a01dc 100644 --- a/src/lightning/fabric/accelerators/registry.py +++ b/src/lightning/fabric/accelerators/registry.py @@ -73,14 +73,14 @@ def register( data["description"] = description data["init_params"] = init_params - def do_register(name: str, accelerator: Callable) -> Callable: + def do_register(accelerator: Callable) -> Callable: data["accelerator"] = accelerator data["accelerator_name"] = name self[name] = data return accelerator if accelerator is not None: - return do_register(name, accelerator) + return do_register(accelerator) return do_register diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 2268614abb97b..594bb46f4b362 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -25,7 +25,7 @@ from lightning.fabric.plugins.precision.precision import _PRECISION_INPUT_STR, _PRECISION_INPUT_STR_ALIAS from lightning.fabric.strategies import STRATEGY_REGISTRY from lightning.fabric.utilities.consolidate_checkpoint import _process_cli_args -from lightning.fabric.utilities.device_parser import _parse_gpu_ids +from lightning.fabric.utilities.device_parser import _parse_gpu_ids, _select_auto_accelerator from lightning.fabric.utilities.distributed import _suggested_max_num_threads from lightning.fabric.utilities.load import _load_distributed_checkpoint @@ -34,7 +34,7 @@ _CLICK_AVAILABLE = RequirementCache("click") _LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk") -_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu") +_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu", "auto") def _get_supported_strategies() -> list[str]: @@ -187,6 +187,14 @@ def _set_env_variables(args: Namespace) -> None: def _get_num_processes(accelerator: str, devices: str) -> int: """Parse the `devices` argument to determine how many processes need to be launched on the current machine.""" + + if accelerator == "auto" or accelerator is None: + accelerator = _select_auto_accelerator() + if devices == "auto": + if accelerator in ("cpu", "cuda", "mps"): + devices = "1" + else: + raise ValueError(f"Cannot default to '1' device for accelerator='{accelerator}'") if accelerator == "gpu": parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True) elif accelerator == "cuda": diff --git a/src/lightning/fabric/plugins/environments/xla.py b/src/lightning/fabric/plugins/environments/xla.py index a227d2322b9a3..b8350872f22d9 100644 --- a/src/lightning/fabric/plugins/environments/xla.py +++ b/src/lightning/fabric/plugins/environments/xla.py @@ -66,6 +66,11 @@ def world_size(self) -> int: The output is cached for performance.
""" + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.world_size() + import torch_xla.core.xla_model as xm return xm.xrt_world_size() @@ -82,6 +87,11 @@ def global_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.global_ordinal() + import torch_xla.core.xla_model as xm return xm.get_ordinal() @@ -98,6 +108,11 @@ def local_rank(self) -> int: The output is cached for performance. """ + if _XLA_GREATER_EQUAL_2_1: + from torch_xla import runtime as xr + + return xr.local_ordinal() + import torch_xla.core.xla_model as xm return xm.get_local_ordinal() diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index 8a71a25bb914f..4c648f2b97181 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -256,10 +256,12 @@ def quantize( if int8params.has_fp16_weights: int8params.data = B else: - if hasattr(bnb.functional, "double_quant"): + # bitsandbytes >= 0.45 supports an improved API + if hasattr(bnb.functional, "int8_vectorwise_quant"): + CB, SCB, _ = bnb.functional.int8_vectorwise_quant(B) + else: # old method is deprecated in 0.45, removed in 0.46+. CB, _, SCB, _, _ = bnb.functional.double_quant(B) - else: # for bitsandbytes versions ≥0.46 - CB, SCB = bnb.functional.int8_double_quant(B) + int8params.data = CB setattr(int8params, "CB", CB) setattr(int8params, "SCB", SCB) diff --git a/src/lightning/fabric/utilities/consolidate_checkpoint.py b/src/lightning/fabric/utilities/consolidate_checkpoint.py index 15d20d8d89ecc..68c956dbad997 100644 --- a/src/lightning/fabric/utilities/consolidate_checkpoint.py +++ b/src/lightning/fabric/utilities/consolidate_checkpoint.py @@ -1,4 +1,5 @@ import logging +import sys from argparse import ArgumentParser, Namespace from pathlib import Path @@ -40,23 +41,23 @@ def _parse_cli_args() -> Namespace: def _process_cli_args(args: Namespace) -> Namespace: if not _TORCH_GREATER_EQUAL_2_3: _log.error("Processing distributed checkpoints requires PyTorch >= 2.3.") - exit(1) + sys.exit(1) checkpoint_folder = Path(args.checkpoint_folder) if not checkpoint_folder.exists(): _log.error(f"The provided checkpoint folder does not exist: {checkpoint_folder}") - exit(1) + sys.exit(1) if not checkpoint_folder.is_dir(): _log.error( f"The provided checkpoint path must be a folder, containing the checkpoint shards: {checkpoint_folder}" ) - exit(1) + sys.exit(1) if not (checkpoint_folder / _METADATA_FILENAME).is_file(): _log.error( "Only FSDP-sharded checkpoints saved with Lightning are supported for consolidation. The provided folder" f" is not in that format: {checkpoint_folder}" ) - exit(1) + sys.exit(1) if args.output_file is None: output_file = checkpoint_folder.with_suffix(checkpoint_folder.suffix + ".consolidated") @@ -67,7 +68,7 @@ def _process_cli_args(args: Namespace) -> Namespace: "The path for the converted checkpoint already exists. 
Choose a different path by providing" f" `--output_file` or move/delete the file first: {output_file}" ) - exit(1) + sys.exit(1) return Namespace(checkpoint_folder=checkpoint_folder, output_file=output_file) diff --git a/src/lightning/fabric/utilities/device_parser.py b/src/lightning/fabric/utilities/device_parser.py index ff5bebd9b4516..8bdacc0f523f5 100644 --- a/src/lightning/fabric/utilities/device_parser.py +++ b/src/lightning/fabric/utilities/device_parser.py @@ -204,3 +204,18 @@ def _check_data_type(device_ids: object) -> None: raise TypeError(f"{msg} a sequence of {type(id_).__name__}.") elif type(device_ids) not in (int, str): raise TypeError(f"{msg} {device_ids!r}.") + + +def _select_auto_accelerator() -> str: + """Choose the accelerator type (str) based on availability.""" + from lightning.fabric.accelerators.cuda import CUDAAccelerator + from lightning.fabric.accelerators.mps import MPSAccelerator + from lightning.fabric.accelerators.xla import XLAAccelerator + + if XLAAccelerator.is_available(): + return "tpu" + if MPSAccelerator.is_available(): + return "mps" + if CUDAAccelerator.is_available(): + return "cuda" + return "cpu" diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a618371d7f2b4..70239baac0e6d 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -34,6 +34,7 @@ _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") +_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0") _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") - +_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/seed.py b/src/lightning/fabric/utilities/seed.py index f9c0ddeb86cf0..534e5e3db653e 100644 --- a/src/lightning/fabric/utilities/seed.py +++ b/src/lightning/fabric/utilities/seed.py @@ -27,7 +27,8 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: Args: seed: the integer value seed for global random state in Lightning. If ``None``, it will read the seed from ``PL_GLOBAL_SEED`` env variable. If ``None`` and the - ``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. + ``PL_GLOBAL_SEED`` env variable is not set, then the seed defaults to 0. If seed is + not in bounds or cannot be cast to int, a ValueError is raised. workers: if set to ``True``, will properly configure all dataloaders passed to the Trainer with a ``worker_init_fn``. If the user already provides such a function for their dataloaders, setting this argument will have no influence. 
See also: @@ -44,14 +45,12 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False, verbose: try: seed = int(env_seed) except ValueError: - seed = 0 - rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") + raise ValueError(f"Invalid seed specified via PL_GLOBAL_SEED: {repr(env_seed)}") elif not isinstance(seed, int): seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): - rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") - seed = 0 + raise ValueError(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") if verbose: log.info(rank_prefixed_message(f"Seed set to {seed}", _get_rank())) diff --git a/src/lightning/fabric/utilities/spike.py b/src/lightning/fabric/utilities/spike.py index e96eccd75b1a2..9c1b0a2a00572 100644 --- a/src/lightning/fabric/utilities/spike.py +++ b/src/lightning/fabric/utilities/spike.py @@ -1,19 +1,16 @@ import json -import operator import os import warnings from typing import TYPE_CHECKING, Any, Literal, Optional, Union import torch -from lightning_utilities.core.imports import compare_version +from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 from lightning.fabric.utilities.types import _PATH if TYPE_CHECKING: from lightning.fabric.fabric import Fabric -_TORCHMETRICS_GREATER_EQUAL_1_0_0 = compare_version("torchmetrics", operator.ge, "1.0.0") - class SpikeDetection: """Spike Detection Callback. diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py index 6f5d933f9dae3..ec980693b75f3 100644 --- a/src/lightning/fabric/utilities/testing/_runif.py +++ b/src/lightning/fabric/utilities/testing/_runif.py @@ -17,7 +17,7 @@ from typing import Optional import torch -from lightning_utilities.core.imports import RequirementCache, compare_version +from lightning_utilities.core.imports import compare_version from packaging.version import Version from lightning.fabric.accelerators import XLAAccelerator @@ -40,11 +40,12 @@ def _runif_reasons( standalone: bool = False, deepspeed: bool = False, dynamo: bool = False, + linux_only: bool = False, ) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: - min_cuda_gpus: Require this number of gpus and that the ``PL_RUN_CUDA_TESTS=1`` environment variable is set. + min_cuda_gpus: Require this number of gpus and that the ``RUN_ONLY_CUDA_TESTS=1`` environment variable is set. min_torch: Require that PyTorch is greater or equal than this version. max_torch: Require that PyTorch is less than this version. min_python: Require that Python is greater or equal than this version. 
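The `seed_everything` change above tightens the contract: an invalid or out-of-range seed now raises a `ValueError` instead of warning and silently resetting the seed to 0. A short sketch of the new behavior (the seed values are illustrative; the bounds are numpy's accepted seed range):

```python
import os

from lightning.fabric.utilities.seed import seed_everything

seed_everything(42)  # works as before

try:
    seed_everything(2**64)  # outside the range numpy accepts
except ValueError as err:
    print(err)  # "... is not in bounds, numpy accepts from ... to ..."

os.environ["PL_GLOBAL_SEED"] = "not-an-int"
try:
    seed_everything()  # reads PL_GLOBAL_SEED, which cannot be cast to int
except ValueError as err:
    print(err)  # "Invalid seed specified via PL_GLOBAL_SEED: 'not-an-int'"
```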
@@ -112,9 +113,7 @@ def _runif_reasons( reasons.append("Standalone execution") kwargs["standalone"] = True - if deepspeed and not ( - _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils") - ): + if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4): reasons.append("Deepspeed") if dynamo: @@ -123,4 +122,7 @@ def _runif_reasons( if not is_dynamo_supported(): reasons.append("torch.dynamo") + if linux_only and sys.platform != "linux": + reasons.append("only linux") + return reasons, kwargs diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 23325f868786c..6f5d7acbaed43 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -6,6 +6,30 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). --- +## [2.5.3] - 2025-08-DD + +### Changed + +- Added a `save_on_exception` option to the `ModelCheckpoint` callback ([#20916](https://github.com/Lightning-AI/pytorch-lightning/pull/20916)) +- Allow `dataloader_idx_` in log names when `add_dataloader_idx=False` ([#20987](https://github.com/Lightning-AI/pytorch-lightning/pull/20987)) +- Allow returning `ONNXProgram` when calling `to_onnx(dynamo=True)` ([#20811](https://github.com/Lightning-AI/pytorch-lightning/pull/20811)) +- Extended support for general mappings being returned from `training_step` when using manual optimization ([#21011](https://github.com/Lightning-AI/pytorch-lightning/pull/21011)) + +### Fixed + +- Fixed the `Trainer` to accept a `CUDAAccelerator` instance as `accelerator` with the FSDP strategy ([#20964](https://github.com/Lightning-AI/pytorch-lightning/pull/20964)) +- Fixed progress bar console clearing for Rich `14.1+` ([#21016](https://github.com/Lightning-AI/pytorch-lightning/pull/21016)) +- Fixed `AdvancedProfiler` to handle nested profiling actions for Python 3.12+ ([#20809](https://github.com/Lightning-AI/pytorch-lightning/pull/20809)) +- Fixed a rich progress bar error when resuming training ([#21000](https://github.com/Lightning-AI/pytorch-lightning/pull/21000)) +- Fixed a double-iteration bug when resuming from a checkpoint
([#20775](https://github.com/Lightning-AI/pytorch-lightning/pull/20775)) +- Fixed support for more dtypes in `ModelSummary` ([#21034](https://github.com/Lightning-AI/pytorch-lightning/pull/21034)) +- Fixed metrics in `RichProgressBar` being updated according to user provided `refresh_rate` ([#21032](https://github.com/Lightning-AI/pytorch-lightning/pull/21032)) +- Fixed `save_last` behavior in the absence of validation ([#20960](https://github.com/Lightning-AI/pytorch-lightning/pull/20960)) +- Fixed integration between `LearningRateFinder` and `EarlyStopping` ([#21056](https://github.com/Lightning-AI/pytorch-lightning/pull/21056)) +- Fixed gradient calculation in `lr_finder` for `mode="exponential"` ([#21055](https://github.com/Lightning-AI/pytorch-lightning/pull/21055)) +- Fixed `save_hyperparameters` crashing with `dataclasses` using `init=False` fields ([#21051](https://github.com/Lightning-AI/pytorch-lightning/pull/21051)) + + ## [2.5.2] - 2025-06-20 ### Changed diff --git a/src/lightning/pytorch/callbacks/lr_finder.py b/src/lightning/pytorch/callbacks/lr_finder.py index f667b5c501a10..aaadc3c38ed5e 100644 --- a/src/lightning/pytorch/callbacks/lr_finder.py +++ b/src/lightning/pytorch/callbacks/lr_finder.py @@ -106,7 +106,7 @@ def __init__( self._attr_name = attr_name self._early_exit = False - self.lr_finder: Optional[_LRFinder] = None + self.optimal_lr: Optional[_LRFinder] = None def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: with isolate_rng(): diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 6b7b2831a2e04..68fed2ff82d31 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -97,6 +97,7 @@ class ModelCheckpoint(Checkpoint): collisions unless ``enable_version_counter`` is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, and we recommend formatting the filename to include the monitored metric to avoid collisions. + save_on_exception: Whether to save a checkpoint when an exception is raised. Default: ``False``. mode: one of {min, max}. If ``save_top_k != 0``, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. @@ -133,9 +134,15 @@ class ModelCheckpoint(Checkpoint): will only save checkpoints at epochs 0 < E <= N where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E. save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch. - If this is ``False``, then the check runs at the end of the validation. + If ``True``, checkpoints are saved at the end of every training epoch. + If ``False``, checkpoints are saved at the end of validation. + If ``None`` (default), checkpointing behavior is determined based on training configuration. + If ``check_val_every_n_epoch != 1``, checkpointing will not be performed at the end of + every training epoch. If there are no validation batches of data, checkpointing will occur at the + end of the training epoch. If there is a non-default number of validation runs per training epoch + (``val_check_interval != 1``), checkpointing is performed after validation. enable_version_counter: Whether to append a version to the existing file name. - If this is ``False``, then the checkpoint files will be overwritten. + If ``False``, then the checkpoint files will be overwritten. 
Note: For extra customization, ModelCheckpoint includes the following attributes: @@ -224,6 +231,7 @@ def __init__( verbose: bool = False, save_last: Optional[Union[bool, Literal["link"]]] = None, save_top_k: int = 1, + save_on_exception: bool = False, save_weights_only: bool = False, mode: str = "min", auto_insert_metric_name: bool = True, @@ -238,6 +246,7 @@ def __init__( self.verbose = verbose self.save_last = save_last self.save_top_k = save_top_k + self.save_on_exception = save_on_exception self.save_weights_only = save_weights_only self.auto_insert_metric_name = auto_insert_metric_name self._save_on_train_epoch_end = save_on_train_epoch_end @@ -338,6 +347,27 @@ def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModul self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates) + @override + def on_exception(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", exception: BaseException) -> None: + """Save a checkpoint when an exception is raised.""" + if not self._should_save_on_exception(trainer): + return + monitor_candidates = self._monitor_candidates(trainer) + filepath = self.format_checkpoint_name(metrics=monitor_candidates) + self._save_checkpoint(trainer, filepath) + self._save_last_checkpoint(trainer, monitor_candidates) + rank_zero_info( + f"An {type(exception).__name__} was raised with message: {exception}," + f" saved checkpoint to {filepath}" + ) + + @override + def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """Ensure save_last=True is applied when training ends.""" + if self.save_last and not self._last_checkpoint_saved: + monitor_candidates = self._monitor_candidates(trainer) + self._save_last_checkpoint(trainer, monitor_candidates) + @override def state_dict(self) -> dict[str, Any]: return { @@ -426,6 +455,14 @@ def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool: or self._last_global_step_saved == trainer.global_step # already saved at the last step ) + def _should_save_on_exception(self, trainer: "pl.Trainer") -> bool: + return ( + self.save_on_exception + and not bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run + and not trainer.sanity_checking # don't save anything during sanity check + and self._last_global_step_saved != trainer.global_step # skip if already saved at this step + ) + def _should_save_on_train_epoch_end(self, trainer: "pl.Trainer") -> bool: if self._save_on_train_epoch_end is not None: return self._save_on_train_epoch_end @@ -538,7 +575,7 @@ def _format_checkpoint_name( self, filename: Optional[str], metrics: dict[str, Tensor], - prefix: str = "", + prefix: Optional[str] = None, auto_insert_metric_name: bool = True, ) -> str: if not filename: @@ -565,13 +602,17 @@ def _format_checkpoint_name( metrics[name] = torch.tensor(0) filename = filename.format(metrics) - if prefix: + if prefix is not None: filename = self.CHECKPOINT_JOIN_CHAR.join([prefix, filename]) return filename def format_checkpoint_name( - self, metrics: dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None + self, + metrics: dict[str, Tensor], + filename: Optional[str] = None, + ver: Optional[int] = None, + prefix: Optional[str] = None, ) -> str: """Generate a filename according to the defined template.
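Together, the `save_on_exception` flag, the `on_exception` hook, and `_should_save_on_exception` give `ModelCheckpoint` an opt-in crash-save path. A minimal usage sketch, where `dirpath` and the model are placeholders:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

# opt in to writing a checkpoint when fit() is interrupted by an exception
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",  # placeholder path
    save_on_exception=True,  # new option; defaults to False
    save_last=True,
)
trainer = Trainer(callbacks=[checkpoint_callback])
# trainer.fit(model)  # `model` is a placeholder LightningModule
```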
@@ -603,7 +644,9 @@ def format_checkpoint_name( """ filename = filename or self.filename - filename = self._format_checkpoint_name(filename, metrics, auto_insert_metric_name=self.auto_insert_metric_name) + filename = self._format_checkpoint_name( + filename, metrics, prefix=prefix, auto_insert_metric_name=self.auto_insert_metric_name + ) if ver is not None: filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}")) diff --git a/src/lightning/pytorch/callbacks/progress/rich_progress.py b/src/lightning/pytorch/callbacks/progress/rich_progress.py index 7bb98e8a9058c..644497cbb632f 100644 --- a/src/lightning/pytorch/callbacks/progress/rich_progress.py +++ b/src/lightning/pytorch/callbacks/progress/rich_progress.py @@ -184,7 +184,7 @@ def render(self, task: "Task") -> Text: def _generate_metrics_texts(self) -> Generator[str, None, None]: for name, value in self._metrics.items(): - if not isinstance(value, str): + if not isinstance(value, (str, int)): value = f"{value:{self._metrics_format}}" yield f"{name}: {value}" @@ -331,7 +331,19 @@ def _init_progress(self, trainer: "pl.Trainer") -> None: self._reset_progress_bar_ids() reconfigure(**self._console_kwargs) self._console = get_console() - self._console.clear_live() + + # Compatibility shim for Rich >= 14.1.0: + if hasattr(self._console, "_live_stack"): + # In recent Rich releases, the internal `_live` variable was replaced with `_live_stack` (a list) + # to support nested Live displays. This broke our original call to `clear_live()`, + # because it now only pops one Live instance instead of clearing them all. + # We check for `_live_stack` and clear it manually for compatibility across + # both old and new Rich versions. + if len(self._console._live_stack) > 0: + self._console.clear_live() + else: + self._console.clear_live() + self._metric_component = MetricsTextColumn( trainer, self.theme.metrics, @@ -447,6 +459,11 @@ def _add_task(self, total_batches: Union[int, float], description: str, visible: visible=visible, ) + def _initialize_train_progress_bar_id(self) -> None: + total_batches = self.total_train_batches + train_description = self._get_train_description(self.trainer.current_epoch) + self.train_progress_bar_id = self._add_task(total_batches, train_description) + def _update(self, progress_bar_id: Optional["TaskID"], current: int, visible: bool = True) -> None: if self.progress is not None and self.is_enabled: assert progress_bar_id is not None @@ -531,13 +548,16 @@ def on_train_batch_end( batch: Any, batch_idx: int, ) -> None: + if not self.is_disabled and self.train_progress_bar_id is None: + # can happen when resuming from a mid-epoch restart + self._initialize_train_progress_bar_id() self._update(self.train_progress_bar_id, batch_idx + 1) - self._update_metrics(trainer, pl_module) + self._update_metrics(trainer, pl_module, batch_idx + 1) self.refresh() @override def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: - self._update_metrics(trainer, pl_module) + self._update_metrics(trainer, pl_module, total_batches=True) @override def on_validation_batch_end( @@ -612,7 +632,21 @@ def _reset_progress_bar_ids(self) -> None: self.test_progress_bar_id = None self.predict_progress_bar_id = None - def _update_metrics(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + def _update_metrics( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + current: Optional[int] = None, + total_batches: bool = False, + ) -> None: + if not self.is_enabled or 
self._metric_component is None: + return + + if current is not None and not total_batches: + total = self.total_train_batches + if not self._should_update(current, total): + return + metrics = self.get_metrics(trainer, pl_module) if self._metric_component: self._metric_component.update(metrics) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 7df0cb7757f81..dd64da356e042 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -47,6 +47,7 @@ from lightning.fabric.utilities.apply_func import convert_to_tensors from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback @@ -60,7 +61,7 @@ from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype from lightning.pytorch.utilities import GradClipAlgorithmType from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 +from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature @@ -72,10 +73,17 @@ OptimizerLRScheduler, ) +_ONNX_AVAILABLE = RequirementCache("onnx") +_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript") + if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh -_ONNX_AVAILABLE = RequirementCache("onnx") + if _TORCH_GREATER_EQUAL_2_5: + if _TORCH_GREATER_EQUAL_2_6: + from torch.onnx import ONNXProgram + else: + from torch.onnx._internal.exporter import ONNXProgram # type: ignore[no-redef] warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -381,7 +389,7 @@ def log( logger: Optional[bool] = None, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = "mean", + reduce_fx: Union[str, Callable[[Any], Any]] = "mean", enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Optional[Any] = None, @@ -466,10 +474,10 @@ def log( ) # make sure user doesn't introduce logic for multi-dataloaders - if "/dataloader_idx_" in name: + if add_dataloader_idx and "/dataloader_idx_" in name: raise MisconfigurationException( f"You called `self.log` with the key `{name}`" - " but it should not contain information about `dataloader_idx`" + " but it should not contain information about `dataloader_idx` when `add_dataloader_idx=True`" ) value = apply_to_collection(value, (Tensor, numbers.Number), self.__to_tensor, name) @@ -546,7 +554,7 @@ def log_dict( logger: Optional[bool] = None, on_step: Optional[bool] = None, on_epoch: Optional[bool] = None, - reduce_fx: Union[str, Callable] = "mean", + reduce_fx: Union[str, Callable[[Any], Any]] = "mean", enable_graph: bool = False, sync_dist: bool = False, sync_dist_group: Optional[Any] = None, @@ -808,7 +816,22 @@ def validation_step(self, batch, batch_idx): # CASE 2: multiple validation dataloaders def validation_step(self, batch, batch_idx, dataloader_idx=0): # dataloader_idx tells you which 
dataset this is. - ... + x, y = batch + + # implement your own + out = self(x) + + if dataloader_idx == 0: + loss = self.loss0(out, y) + else: + loss = self.loss1(out, y) + + # calculate acc + labels_hat = torch.argmax(out, dim=1) + acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + + # log the outputs separately for each dataloader + self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc}) Note: If you don't need to validate you don't need to implement this method. @@ -875,7 +898,22 @@ def test_step(self, batch, batch_idx): # CASE 2: multiple test dataloaders def test_step(self, batch, batch_idx, dataloader_idx=0): # dataloader_idx tells you which dataset this is. - ... + x, y = batch + + # implement your own + out = self(x) + + if dataloader_idx == 0: + loss = self.loss0(out, y) + else: + loss = self.loss1(out, y) + + # calculate acc + labels_hat = torch.argmax(out, dim=1) + acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0) + + # log the outputs separately for each dataloader + self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc}) Note: If you don't need to test you don't need to implement this method. @@ -1386,12 +1424,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx( + self, + file_path: Union[str, Path, BytesIO, None] = None, + input_sample: Optional[Any] = None, + **kwargs: Any, + ) -> Optional["ONNXProgram"]: """Saves the model in ONNX format. Args: - file_path: The path of the file the onnx model should be saved to. + file_path: The path of the file the onnx model should be saved to. Default: None (no file saved). input_sample: An input for tracing. Default: None (Use self.example_input_array) + **kwargs: Will be passed to torch.onnx.export function. Example:: @@ -1412,6 +1456,12 @@ def forward(self, x): if not _ONNX_AVAILABLE: raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.") + if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5): + raise ModuleNotFoundError( + f"`{type(self).__name__}.to_onnx(dynamo=True)` " + "requires `onnxscript` and `torch>=2.5.0` to be installed." + ) + mode = self.training if input_sample is None: @@ -1428,8 +1478,9 @@ def forward(self, x): file_path = str(file_path) if isinstance(file_path, Path) else file_path # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but # BytesIO does work, too. 
- torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore + ret = torch.onnx.export(self, input_sample, file_path, **kwargs) # type: ignore self.train(mode) + return ret @torch.no_grad() def to_torchscript( diff --git a/src/lightning/pytorch/core/optimizer.py b/src/lightning/pytorch/core/optimizer.py index 46126e212378e..b85e9b2c10e5a 100644 --- a/src/lightning/pytorch/core/optimizer.py +++ b/src/lightning/pytorch/core/optimizer.py @@ -274,7 +274,7 @@ def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str] scheduler["reduce_on_plateau"] = scheduler.get( "reduce_on_plateau", isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau) ) - if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None: + if scheduler["reduce_on_plateau"] and scheduler.get("monitor") is None: raise MisconfigurationException( "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used." ' For example: {"optimizer": optimizer, "lr_scheduler":' diff --git a/src/lightning/pytorch/demos/transformer.py b/src/lightning/pytorch/demos/transformer.py index fefa073fbd310..13b5e05adc680 100644 --- a/src/lightning/pytorch/demos/transformer.py +++ b/src/lightning/pytorch/demos/transformer.py @@ -54,15 +54,24 @@ def __init__( self.ninp = ninp self.vocab_size = vocab_size - self.src_mask = None + self.src_mask: Optional[Tensor] = None + + def generate_square_subsequent_mask(self, size: int) -> Tensor: + """Generate a square mask for the sequence to prevent future tokens from being seen.""" + mask = torch.triu(torch.ones(size, size), diagonal=1) + mask = mask.float().masked_fill(mask == 1, float("-inf")).masked_fill(mask == 0, 0.0) + return mask def forward(self, inputs: Tensor, target: Tensor, mask: Optional[Tensor] = None) -> Tensor: _, t = inputs.shape - # we assume target is already shifted w.r.t. 
inputs + # Generate source mask to prevent future token leakage + if self.src_mask is None or self.src_mask.size(0) != t: + self.src_mask = self.generate_square_subsequent_mask(t).to(inputs.device) + + # Generate target mask if not provided if mask is None: - mask = torch.tril(torch.ones(t, t, device=inputs.device)) == 1 - mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, 0.0) + mask = self.generate_square_subsequent_mask(t).to(inputs.device) src = self.pos_encoder(self.embedding(inputs) * math.sqrt(self.ninp)) target = self.pos_encoder(self.embedding(target) * math.sqrt(self.ninp)) diff --git a/src/lightning/pytorch/loops/loop.py b/src/lightning/pytorch/loops/loop.py index daad309cd75d4..f4324c003f7a9 100644 --- a/src/lightning/pytorch/loops/loop.py +++ b/src/lightning/pytorch/loops/loop.py @@ -23,6 +23,7 @@ class _Loop: def __init__(self, trainer: "pl.Trainer") -> None: self._restarting = False self._loaded_from_state_dict = False + self._resuming_from_checkpoint = False self.trainer = trainer @property @@ -38,6 +39,11 @@ def restarting(self, restarting: bool) -> None: if isinstance(loop, _Loop): loop.restarting = restarting + @property + def is_resuming(self) -> bool: + """Indicates whether training is being resumed from a checkpoint.""" + return self._resuming_from_checkpoint + def reset_restart_stage(self) -> None: pass @@ -87,6 +93,7 @@ def load_state_dict( v.load_state_dict(state_dict.copy(), prefix + k + ".") self.restarting = True self._loaded_from_state_dict = True + self._resuming_from_checkpoint = True def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: for k, v in self.__dict__.items(): @@ -102,4 +109,5 @@ def _load_from_state_dict(self, state_dict: dict, prefix: str) -> None: def on_iteration_done(self) -> None: self._restarting = False self._loaded_from_state_dict = False + self._resuming_from_checkpoint = False self.reset_restart_stage() diff --git a/src/lightning/pytorch/loops/optimization/manual.py b/src/lightning/pytorch/loops/optimization/manual.py index e1aabcbf42976..10bd5b8b1c666 100644 --- a/src/lightning/pytorch/loops/optimization/manual.py +++ b/src/lightning/pytorch/loops/optimization/manual.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import OrderedDict +from collections.abc import Mapping from contextlib import suppress from dataclasses import dataclass, field from typing import Any @@ -45,7 +46,7 @@ class ManualResult(OutputResult): @classmethod def from_training_step_output(cls, training_step_output: STEP_OUTPUT) -> "ManualResult": extra = {} - if isinstance(training_step_output, dict): + if isinstance(training_step_output, Mapping): extra = training_step_output.copy() elif isinstance(training_step_output, Tensor): extra = {"loss": training_step_output} diff --git a/src/lightning/pytorch/loops/training_epoch_loop.py b/src/lightning/pytorch/loops/training_epoch_loop.py index 599eccdc8ca91..c0a57ae12c4d1 100644 --- a/src/lightning/pytorch/loops/training_epoch_loop.py +++ b/src/lightning/pytorch/loops/training_epoch_loop.py @@ -237,7 +237,11 @@ def reset(self) -> None: def on_run_start(self, data_fetcher: _DataFetcher) -> None: # `iter()` was called once in `FitLoop.setup_data()` already - if self.trainer.current_epoch > 0 and not self.restarting: + # Call `iter()` again only when: + # 1. Not restarting + # 2. Not resuming from checkpoint (not is_resuming) + # 3. 
Past first epoch (current_epoch > 0) + if self.trainer.current_epoch > 0 and not self.trainer.fit_loop.is_resuming and not self.restarting: iter(data_fetcher) # creates the iterator inside the fetcher # add the previous `fetched` value to properly track `is_last_batch` with no prefetching diff --git a/src/lightning/pytorch/profilers/advanced.py b/src/lightning/pytorch/profilers/advanced.py index 41681fbd239f3..c0b4b9953cc33 100644 --- a/src/lightning/pytorch/profilers/advanced.py +++ b/src/lightning/pytorch/profilers/advanced.py @@ -19,6 +19,7 @@ import os import pstats import tempfile +from collections import defaultdict from pathlib import Path from typing import Optional, Union @@ -66,14 +67,15 @@ def __init__( If you attempt to stop recording an action which was never started. """ super().__init__(dirpath=dirpath, filename=filename) - self.profiled_actions: dict[str, cProfile.Profile] = {} + self.profiled_actions: dict[str, cProfile.Profile] = defaultdict(cProfile.Profile) self.line_count_restriction = line_count_restriction self.dump_stats = dump_stats @override def start(self, action_name: str) -> None: - if action_name not in self.profiled_actions: - self.profiled_actions[action_name] = cProfile.Profile() + # Disable all profilers before starting a new one + for pr in self.profiled_actions.values(): + pr.disable() self.profiled_actions[action_name].enable() @override @@ -114,7 +116,7 @@ def summary(self) -> str: @override def teardown(self, stage: Optional[str]) -> None: super().teardown(stage=stage) - self.profiled_actions = {} + self.profiled_actions.clear() def __reduce__(self) -> tuple: # avoids `TypeError: cannot pickle 'cProfile.Profile' object` diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index b5354eb2b08dd..77536cdc16b33 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -13,6 +13,7 @@ # limitations under the License. 
import logging import signal +import sys from copy import deepcopy from typing import Any, Callable, Optional, Union @@ -62,7 +63,7 @@ def _call_and_handle_interrupt(trainer: "pl.Trainer", trainer_fn: Callable, *arg launcher = trainer.strategy.launcher if isinstance(launcher, _SubprocessScriptLauncher): launcher.kill(_get_sigkill_signal()) - exit(1) + sys.exit(1) except BaseException as exception: _interrupt(trainer, exception) diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 1423c1aeeafe4..7f44de0589938 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -29,7 +29,7 @@ SLURMEnvironment, TorchElasticEnvironment, ) -from lightning.fabric.utilities.device_parser import _determine_root_gpu_device +from lightning.fabric.utilities.device_parser import _determine_root_gpu_device, _select_auto_accelerator from lightning.fabric.utilities.imports import _IS_INTERACTIVE from lightning.pytorch.accelerators import AcceleratorRegistry from lightning.pytorch.accelerators.accelerator import Accelerator @@ -332,18 +332,12 @@ def _check_device_config_and_set_final_flags(self, devices: Union[list[int], str @staticmethod def _choose_auto_accelerator() -> str: """Choose the accelerator type (str) based on availability.""" - if XLAAccelerator.is_available(): - return "tpu" if _habana_available_and_importable(): from lightning_habana import HPUAccelerator if HPUAccelerator.is_available(): return "hpu" - if MPSAccelerator.is_available(): - return "mps" - if CUDAAccelerator.is_available(): - return "cuda" - return "cpu" + return _select_auto_accelerator() @staticmethod def _choose_gpu_accelerator_backend() -> str: @@ -459,10 +453,11 @@ def _check_strategy_and_fallback(self) -> None: if ( strategy_flag in FSDPStrategy.get_registered_strategies() or type(self._strategy_flag) is FSDPStrategy - ) and self._accelerator_flag not in ("cuda", "gpu"): + ) and not (self._accelerator_flag in ("cuda", "gpu") or isinstance(self._accelerator_flag, CUDAAccelerator)): raise ValueError( - f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but got:" - f" {self._accelerator_flag}" + f"The strategy `{FSDPStrategy.strategy_name}` requires a GPU accelerator, but received " + f"`accelerator={self._accelerator_flag!r}`. Please set `accelerator='cuda'`, `accelerator='gpu'`," + " or pass a `CUDAAccelerator()` instance to use FSDP." ) if strategy_flag in _DDP_FORK_ALIASES and "fork" not in torch.multiprocessing.get_all_start_methods(): raise ValueError( diff --git a/src/lightning/pytorch/trainer/connectors/data_connector.py b/src/lightning/pytorch/trainer/connectors/data_connector.py index 3e5273085ed2b..841d78b457d48 100644 --- a/src/lightning/pytorch/trainer/connectors/data_connector.py +++ b/src/lightning/pytorch/trainer/connectors/data_connector.py @@ -244,15 +244,23 @@ def _get_distributed_sampler( def _resolve_overfit_batches(combined_loader: CombinedLoader, mode: RunningStage) -> None: + """Resolve overfit batches by disabling shuffling. + + When overfit_batches > 0, this function ensures that sequential sampling is used without shuffling for consistent + batches across epochs. Training and validation use different sets of data. 
+ + """ all_have_sequential_sampler = all( isinstance(dl.sampler, SequentialSampler) for dl in combined_loader.flattened if hasattr(dl, "sampler") ) if all_have_sequential_sampler: return + rank_zero_warn( f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling." f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." ) + updated = [ _update_dataloader(dl, sampler=SequentialSampler(dl.dataset), mode=mode) if hasattr(dl, "dataset") else dl for dl in combined_loader.flattened diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index 90ae28bb8c7ee..1a4aa7c401960 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -25,9 +25,9 @@ from lightning.fabric.utilities import move_data_to_device from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars from lightning.fabric.utilities.distributed import _distributed_is_initialized +from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 from lightning.pytorch.utilities.data import extract_batch_size from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 from lightning.pytorch.utilities.memory import recursive_detach from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn from lightning.pytorch.utilities.warnings import PossibleUserWarning diff --git a/src/lightning/pytorch/tuner/lr_finder.py b/src/lightning/pytorch/tuner/lr_finder.py index b50bedb10d53f..a5d758f7fff19 100644 --- a/src/lightning/pytorch/tuner/lr_finder.py +++ b/src/lightning/pytorch/tuner/lr_finder.py @@ -71,26 +71,10 @@ class _LRFinder: Args: mode: either `linear` or `exponential`, how to increase lr after each step - lr_min: lr to start search from - lr_max: lr to stop search - num_training: number of steps to take between lr_min and lr_max - Example:: - # Run lr finder - lr_finder = trainer.lr_find(model) - - # Results stored in - lr_finder.results - - # Plot using - lr_finder.plot() - - # Get suggestion - lr = lr_finder.suggestion() - """ def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int) -> None: @@ -138,10 +122,9 @@ def plot( """Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point - show: if True, will show figure - ax: Axes object to which the plot is to be drawn. If not provided, a new figure is created. 
+ """ if not _MATPLOTLIB_AVAILABLE: raise MisconfigurationException( @@ -190,7 +173,10 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float] """ losses = torch.tensor(self.results["loss"][skip_begin:-skip_end]) - losses = losses[torch.isfinite(losses)] + lrs = torch.tensor(self.results["lr"][skip_begin:-skip_end]) + is_finite = torch.isfinite(losses) + losses = losses[is_finite] + lrs = lrs[is_finite] if len(losses) < 2: # computing torch.gradient requires at least 2 points @@ -201,12 +187,12 @@ def suggestion(self, skip_begin: int = 10, skip_end: int = 1) -> Optional[float] self._optimal_idx = None return None - # TODO: When computing the argmin here, and some losses are non-finite, the expected indices could be - # incorrectly shifted by an offset - gradients = torch.gradient(losses)[0] # Unpack the tuple + gradients = torch.gradient(losses, spacing=[lrs])[0] # Compute the gradient of losses w.r.t. learning rates min_grad = torch.argmin(gradients).item() - - self._optimal_idx = min_grad + skip_begin + all_losses_idx = torch.arange(len(self.results["loss"])) + idx_non_skipped = all_losses_idx[skip_begin:-skip_end] + idx_finite = idx_non_skipped[is_finite] + self._optimal_idx = idx_finite[min_grad].item() # type: ignore return self.results["lr"][self._optimal_idx] @@ -306,7 +292,8 @@ def _lr_find( trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True trainer.fit_loop.epoch_loop.val_loop._combined_loader = None - + trainer.fit_loop._combined_loader = None # reset data fetcher to avoid issues with the next fit + trainer.fit_loop.setup_data() return lr_finder diff --git a/src/lightning/pytorch/utilities/imports.py b/src/lightning/pytorch/utilities/imports.py index 6c0815a6af9dc..5572f1d20d3d6 100644 --- a/src/lightning/pytorch/utilities/imports.py +++ b/src/lightning/pytorch/utilities/imports.py @@ -14,17 +14,20 @@ """General utilities.""" import functools +import operator import sys -from lightning_utilities.core.imports import RequirementCache, package_available +from lightning_utilities.core.imports import RequirementCache, compare_version, package_available from lightning.pytorch.utilities.rank_zero import rank_zero_warn _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) +_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0") _TORCHMETRICS_GREATER_EQUAL_0_8_0 = RequirementCache("torchmetrics>=0.8.0") _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0") +_TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0") _OMEGACONF_AVAILABLE = package_available("omegaconf") _TORCHVISION_AVAILABLE = RequirementCache("torchvision") diff --git a/src/lightning/pytorch/utilities/model_summary/model_summary.py b/src/lightning/pytorch/utilities/model_summary/model_summary.py index 6a5baf2c1e04a..8cd66c7089c94 100644 --- a/src/lightning/pytorch/utilities/model_summary/model_summary.py +++ b/src/lightning/pytorch/utilities/model_summary/model_summary.py @@ -25,6 +25,7 @@ from torch.utils.hooks import RemovableHandle import lightning.pytorch as pl +from lightning.fabric.utilities import rank_zero_warn from lightning.fabric.utilities.distributed import _is_dtensor from 
lightning.pytorch.utilities.model_helpers import _ModuleMode from lightning.pytorch.utilities.rank_zero import WarningCache @@ -216,7 +217,22 @@ def __init__(self, model: "pl.LightningModule", max_depth: int = 1) -> None: self._layer_summary = self.summarize() # 1 byte -> 8 bits # TODO: how do we compute precision_megabytes in case of mixed precision? - precision_to_bits = {"64": 64, "32": 32, "16": 16, "bf16": 16} + precision_to_bits = { + "64": 64, + "32": 32, + "16": 16, + "bf16": 16, + "16-true": 16, + "bf16-true": 16, + "32-true": 32, + "64-true": 64, + } + if self._model._trainer and self._model.trainer.precision not in precision_to_bits: + rank_zero_warn( + f"Precision {self._model.trainer.precision} is not supported by the model summary. " + " Estimated model size in MB will not be accurate. Using 32 bits instead.", + category=UserWarning, + ) precision = precision_to_bits.get(self._model.trainer.precision, 32) if self._model._trainer else 32 self._precision_megabytes = (precision / 8.0) * 1e-6 diff --git a/src/lightning/pytorch/utilities/parsing.py b/src/lightning/pytorch/utilities/parsing.py index 16eef555291bd..829cc7a994b93 100644 --- a/src/lightning/pytorch/utilities/parsing.py +++ b/src/lightning/pytorch/utilities/parsing.py @@ -167,7 +167,8 @@ def save_hyperparameters( if given_hparams is not None: init_args = given_hparams elif is_dataclass(obj): - init_args = {f.name: getattr(obj, f.name) for f in fields(obj)} + obj_fields = fields(obj) + init_args = {f.name: getattr(obj, f.name) for f in obj_fields if f.init} else: init_args = {} diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913681143..0d25cfd1b86ee 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -15,10 +15,10 @@ from lightning_utilities.core.imports import RequirementCache -from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if +from lightning.fabric.utilities.testing import _runif_reasons as _fabric_run_if from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE -from lightning.pytorch.core.module import _ONNX_AVAILABLE +from lightning.pytorch.core.module import _ONNX_AVAILABLE, _ONNXSCRIPT_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") @@ -42,11 +42,13 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, + linux_only: bool = False, + onnxscript: bool = False, ) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. Args: - min_cuda_gpus: Require this number of gpus and that the ``PL_RUN_CUDA_TESTS=1`` environment variable is set. + min_cuda_gpus: Require this number of gpus and that the ``RUN_ONLY_CUDA_TESTS=1`` environment variable is set. min_torch: Require that PyTorch is greater or equal than this version. max_torch: Require that PyTorch is less than this version. min_python: Require that Python is greater or equal than this version. @@ -64,10 +66,11 @@ def _runif_reasons( psutil: Require that psutil is installed. sklearn: Require that scikit-learn is installed. onnx: Require that onnx is installed. + onnxscript: Require that onnxscript is installed. 
""" - reasons, kwargs = fabric_run_if( + reasons, kwargs = _fabric_run_if( min_cuda_gpus=min_cuda_gpus, min_torch=min_torch, max_torch=max_torch, @@ -79,6 +82,7 @@ def _runif_reasons( standalone=standalone, deepspeed=deepspeed, dynamo=dynamo, + linux_only=linux_only, ) if rich and not _RICH_AVAILABLE: @@ -96,4 +100,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if onnxscript and not _ONNXSCRIPT_AVAILABLE: + reasons.append("onnxscript") + return reasons, kwargs diff --git a/src/lightning/pytorch/utilities/upgrade_checkpoint.py b/src/lightning/pytorch/utilities/upgrade_checkpoint.py index 04cf000283d77..8aa091ab90f18 100644 --- a/src/lightning/pytorch/utilities/upgrade_checkpoint.py +++ b/src/lightning/pytorch/utilities/upgrade_checkpoint.py @@ -13,6 +13,7 @@ # limitations under the License. import glob import logging +import sys from argparse import ArgumentParser, Namespace from pathlib import Path from shutil import copyfile @@ -35,7 +36,7 @@ def _upgrade(args: Namespace) -> None: f"The path {path} does not exist. Please provide a valid path to a checkpoint file or a directory" f" containing checkpoints ending in {extension}." ) - exit(1) + sys.exit(1) if path.is_file(): files = [path] @@ -46,7 +47,7 @@ def _upgrade(args: Namespace) -> None: f"No checkpoint files with extension {extension} were found in {path}." f" HINT: Try setting the `--extension` option to specify the right file extension to look for." ) - exit(1) + sys.exit(1) _log.info("Creating a backup of the existing checkpoint files before overwriting in the upgrade process.") for file in files: diff --git a/src/lightning_fabric/__about__.py b/src/lightning_fabric/__about__.py index 89cba35b3d9d4..6d222eed108fa 100644 --- a/src/lightning_fabric/__about__.py +++ b/src/lightning_fabric/__about__.py @@ -1,7 +1,7 @@ import time __author__ = "Lightning AI et al." -__author_email__ = "pytorch@lightning.ai" +__author_email__ = "developer@lightning.ai" __license__ = "Apache-2.0" __copyright__ = f"Copyright (c) 2022-{time.strftime('%Y')}, {__author__}." __homepage__ = "https://github.com/Lightning-AI/lightning" diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index e3878c501967f..c0699e7695496 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -370,7 +370,7 @@ The PyTorch Lightning community is maintained by - [10+ core contributors](https://lightning.ai/docs/pytorch/stable/community/governance.html) who are all a mix of professional engineers, Research Scientists, and Ph.D. students from top AI labs. - 680+ active community contributors. -Want to help us build Lightning and reduce boilerplate for thousands of researchers? [Learn how to make your first contribution here](https://devblog.pytorchlightning.ai/quick-contribution-guide-86d977171b3a) +Want to help us build Lightning and reduce boilerplate for thousands of researchers? [Learn how to make your first contribution here](https://medium.com/pytorch-lightning/quick-contribution-guide-86d977171b3a) PyTorch Lightning is also part of the [PyTorch ecosystem](https://pytorch.org/ecosystem/) which requires projects to have solid testing, documentation and support. diff --git a/src/pytorch_lightning/__about__.py b/src/pytorch_lightning/__about__.py index 297d32a70cd6b..aadfc73cfb1c5 100644 --- a/src/pytorch_lightning/__about__.py +++ b/src/pytorch_lightning/__about__.py @@ -14,7 +14,7 @@ import time __author__ = "Lightning AI et al." 
-__author_email__ = "pytorch@lightning.ai" +__author_email__ = "developer@lightning.ai" __license__ = "Apache-2.0" __copyright__ = f"Copyright (c) 2018-{time.strftime('%Y')}, {__author__}." __homepage__ = "https://github.com/Lightning-AI/lightning" diff --git a/src/version.info b/src/version.info index f225a78adf053..aedc15bb0c6e2 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.5.2 +2.5.3 diff --git a/tests/README.md b/tests/README.md index 26e49c4f5751f..e542bfb0db180 100644 --- a/tests/README.md +++ b/tests/README.md @@ -26,7 +26,7 @@ Additionally, for testing backward compatibility with older versions of PyTorch bash .actions/pull_legacy_checkpoints.sh ``` -Note: These checkpoints are generated to set baselines for maintaining backward compatibility with legacy versions of PyTorch Lightning. Details of checkpoints for back-compatibility can be found [here](https://github.com/Lightning-AI/pytorch-lightning/blob/master/tests/legacy/README.md). +Note: These checkpoints are generated to set baselines for maintaining backward compatibility with legacy versions of PyTorch Lightning. Details of checkpoints for back-compatibility can be found [here](https://github.com/Lightning-AI/pytorch-lightning/tree/master/tests/legacy/README.md). You can run the full test suite in your terminal via this make script: diff --git a/tests/legacy/back-compatible-versions.txt b/tests/legacy/back-compatible-versions.txt index 996f1340747dc..091d993cdd725 100644 --- a/tests/legacy/back-compatible-versions.txt +++ b/tests/legacy/back-compatible-versions.txt @@ -105,3 +105,4 @@ 2.3.2 2.3.3 2.5.1 +2.5.2 diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 8036a6f45b8a0..b88ecf1db1e57 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -16,6 +16,38 @@ import torch from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator +from lightning.fabric.accelerators.registry import _AcceleratorRegistry + + +class TestAccelerator(Accelerator): + """Helper accelerator class for testing.""" + + def __init__(self, param1=None, param2=None): + self.param1 = param1 + self.param2 = param2 + super().__init__() + + def setup_device(self, device: torch.device) -> None: + pass + + def teardown(self) -> None: + pass + + @staticmethod + def parse_devices(devices): + return devices + + @staticmethod + def get_parallel_devices(devices): + return ["foo"] * devices + + @staticmethod + def auto_device_count(): + return 3 + + @staticmethod + def is_available(): + return True def test_accelerator_registry_with_new_accelerator(): @@ -71,3 +103,75 @@ def is_available(): def test_available_accelerators_in_registry(): assert ACCELERATOR_REGISTRY.available_accelerators() == {"cpu", "cuda", "mps", "tpu"} + + +def test_registry_as_decorator(): + """Test that the registry can be used as a decorator.""" + test_registry = _AcceleratorRegistry() + + # Test decorator usage + @test_registry.register("test_decorator", description="Test decorator accelerator", param1="value1", param2=42) + class DecoratorAccelerator(TestAccelerator): + pass + + # Verify registration worked + assert "test_decorator" in test_registry + assert test_registry["test_decorator"]["description"] == "Test decorator accelerator" + assert test_registry["test_decorator"]["init_params"] == {"param1": "value1", "param2": 42} + assert test_registry["test_decorator"]["accelerator"] == DecoratorAccelerator + assert 
test_registry["test_decorator"]["accelerator_name"] == "test_decorator" + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_decorator") + assert isinstance(instance, DecoratorAccelerator) + assert instance.param1 == "value1" + assert instance.param2 == 42 + + +def test_registry_as_static_method(): + """Test that the registry can be used as a static method call.""" + test_registry = _AcceleratorRegistry() + + class StaticMethodAccelerator(TestAccelerator): + pass + + # Test static method usage + result = test_registry.register( + "test_static", + StaticMethodAccelerator, + description="Test static method accelerator", + param1="static_value", + param2=100, + ) + + # Verify registration worked + assert "test_static" in test_registry + assert test_registry["test_static"]["description"] == "Test static method accelerator" + assert test_registry["test_static"]["init_params"] == {"param1": "static_value", "param2": 100} + assert test_registry["test_static"]["accelerator"] == StaticMethodAccelerator + assert test_registry["test_static"]["accelerator_name"] == "test_static" + assert result == StaticMethodAccelerator # Should return the accelerator class + + # Test that we can instantiate the accelerator + instance = test_registry.get("test_static") + assert isinstance(instance, StaticMethodAccelerator) + assert instance.param1 == "static_value" + assert instance.param2 == 100 + + +def test_registry_without_parameters(): + """Test registration without init parameters.""" + test_registry = _AcceleratorRegistry() + + class SimpleAccelerator(TestAccelerator): + def __init__(self): + super().__init__() + + test_registry.register("simple", SimpleAccelerator, description="Simple accelerator") + + assert "simple" in test_registry + assert test_registry["simple"]["description"] == "Simple accelerator" + assert test_registry["simple"]["init_params"] == {} + + instance = test_registry.get("simple") + assert isinstance(instance, SimpleAccelerator) diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index 68f3f2cc38191..9d4a0b9462f2e 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -212,7 +212,7 @@ def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.C options = { "standalone": "PL_RUN_STANDALONE_TESTS", - "min_cuda_gpus": "PL_RUN_CUDA_TESTS", + "min_cuda_gpus": "RUN_ONLY_CUDA_TESTS", "tpu": "PL_RUN_TPU_TESTS", } if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1": diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7e33d5db87dd4..76fd18012ee94 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -97,3 +97,64 @@ def test_detect(monkeypatch): monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "is_available", lambda: True) assert XLAEnvironment.detect() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +def test_world_size_from_xla_runtime_greater_2_1(xla_available): + """Test that world_size uses torch_xla.runtime when XLA >= 2.1.""" + env = XLAEnvironment() + + with mock.patch("torch_xla.runtime.world_size", return_value=4) as mock_world_size: + env.world_size.cache_clear() + assert env.world_size() == 4 + 
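+    # world_size() is wrapped in functools.lru_cache, which is why cache_clear() runs before the assertion above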
mock_world_size.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +def test_global_rank_from_xla_runtime_greater_2_1(xla_available): + """Test that global_rank uses torch_xla.runtime when XLA >= 2.1.""" + env = XLAEnvironment() + + with mock.patch("torch_xla.runtime.global_ordinal", return_value=2) as mock_global_ordinal: + env.global_rank.cache_clear() + assert env.global_rank() == 2 + mock_global_ordinal.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +def test_local_rank_from_xla_runtime_greater_2_1(xla_available): + """Test that local_rank uses torch_xla.runtime when XLA >= 2.1.""" + env = XLAEnvironment() + + with mock.patch("torch_xla.runtime.local_ordinal", return_value=1) as mock_local_ordinal: + env.local_rank.cache_clear() + assert env.local_rank() == 1 + mock_local_ordinal.assert_called_once() + + +@mock.patch.dict(os.environ, {}, clear=True) +@mock.patch("lightning.fabric.accelerators.xla._XLA_GREATER_EQUAL_2_1", True) +@mock.patch("lightning.fabric.plugins.environments.xla._XLA_GREATER_EQUAL_2_1", True) +def test_setters_readonly_when_xla_runtime_greater_2_1(xla_available): + """Test that set_world_size and set_global_rank don't affect values when using XLA runtime >= 2.1.""" + env = XLAEnvironment() + + with ( + mock.patch("torch_xla.runtime.world_size", return_value=4), + mock.patch("torch_xla.runtime.global_ordinal", return_value=2), + ): + env.world_size.cache_clear() + env.global_rank.cache_clear() + + # Values should come from XLA runtime and not be affected by setters + env.set_world_size(100) + assert env.world_size() == 4 + + env.set_global_rank(100) + assert env.global_rank() == 2 diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index bcbd9435d47ac..a01a597811a63 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -42,8 +42,8 @@ def forward(self, x): @pytest.mark.parametrize( ("accelerator", "precision", "expected_dtype"), [ - ("cpu", "16-mixed", torch.bfloat16), - ("cpu", "bf16-mixed", torch.bfloat16), + pytest.param("cpu", "16-mixed", torch.bfloat16, marks=RunIf(skip_windows=True)), + pytest.param("cpu", "bf16-mixed", torch.bfloat16, marks=RunIf(skip_windows=True)), pytest.param("cuda", "16-mixed", torch.float16, marks=RunIf(min_cuda_gpus=2)), pytest.param("cuda", "bf16-mixed", torch.bfloat16, marks=RunIf(min_cuda_gpus=2, bf16_cuda=True)), ], diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index 6ae96b9bcafc6..2abfe73c92dec 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -29,7 +29,8 @@ def __init__(self): self.register_buffer("buffer", torch.ones(3)) -@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))]) +@RunIf(skip_windows=True) +@pytest.mark.parametrize("strategy", ["ddp_spawn", "ddp_fork"]) def 
test_memory_sharing_disabled(strategy):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
     conditions on model updates."""
diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py
index 9d43724228cd2..241d624483b1e 100644
--- a/tests/tests_fabric/strategies/test_ddp_integration.py
+++ b/tests/tests_fabric/strategies/test_ddp_integration.py
@@ -36,7 +36,7 @@
 @pytest.mark.parametrize(
     "accelerator",
     [
-        "cpu",
+        pytest.param("cpu", marks=RunIf(skip_windows=True)),
         pytest.param("cuda", marks=RunIf(min_cuda_gpus=2)),
     ],
 )
diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py
index bddfadd9a2c54..4c11fb0edcd78 100644
--- a/tests/tests_fabric/strategies/test_model_parallel_integration.py
+++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py
@@ -132,11 +132,15 @@ def fn(model, device_mesh):
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=2)
-@pytest.mark.parametrize(
-    "compile",
-    [True, False],
+@pytest.mark.parametrize("compile", [True, False])
+# NOTE: xfail's `condition` is evaluated as a boolean at collection time and is never called
+# with the raised exception, so the expected error is expressed via `raises=` only.
+@pytest.mark.xfail(
+    raises=AssertionError,
+    reason="Test left zombie thread",
+    strict=False,
 )
-def test_tensor_parallel(distributed, compile):
+def test_tensor_parallel(distributed, compile: bool):
     from torch.distributed._tensor import DTensor

     parallelize = _parallelize_feed_forward_tp
@@ -185,10 +189,7 @@ def test_tensor_parallel(distributed, compile):
 @RunIf(min_torch="2.4", standalone=True, min_cuda_gpus=4)
-@pytest.mark.parametrize(
-    "compile",
-    [True, False],
-)
+@pytest.mark.parametrize("compile", [True, False])
 def test_fsdp2_tensor_parallel(distributed, compile):
     from torch.distributed._tensor import DTensor

diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py
index 944584114184b..e71c42bb46e13 100644
--- a/tests/tests_fabric/test_cli.py
+++ b/tests/tests_fabric/test_cli.py
@@ -46,7 +46,7 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
     assert "LT_PRECISION" not in os.environ

-@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", pytest.param("mps", marks=RunIf(mps=True))])
+@pytest.mark.parametrize("accelerator", ["cpu", "gpu", "cuda", "auto", pytest.param("mps", marks=RunIf(mps=True))])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
 def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
@@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
     assert f"Invalid value for '--strategy': '{strategy}'" in ioerr.getvalue()

-@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1"])
+@pytest.mark.parametrize("devices", ["1", "2", "0,", "1,0", "-1", "auto"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 @mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
 def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
@@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):

 @RunIf(mps=True)
-@pytest.mark.parametrize("accelerator", ["mps", "gpu"])
+@pytest.mark.parametrize("accelerator", ["mps", "gpu", "auto"])
 @mock.patch.dict(os.environ, os.environ.copy(), clear=True)
 def
test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script): monkeypatch.setitem(sys.modules, "torch.distributed.run", Mock()) diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index fa9cc0ed40e93..d65eaa810ff4d 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -105,8 +105,8 @@ def _test_all_reduce(strategy): assert result is tensor # inplace -# flaky with "process 0 terminated with signal SIGABRT" (GLOO) -@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") +# flaky with "torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT" (GLOO) +@pytest.mark.flaky(reruns=3) @RunIf(skip_windows=True) @pytest.mark.parametrize( "process", @@ -128,9 +128,10 @@ def test_collective_operations(devices, process): @pytest.mark.skipif( - RequirementCache("torch<2.4") and RequirementCache("numpy>=2.0"), + RequirementCache("numpy>=2.0"), reason="torch.distributed not compatible with numpy>=2.0", ) +@RunIf(min_torch="2.4", skip_windows=True) @pytest.mark.flaky(reruns=3) # flaky with "process 0 terminated with signal SIGABRT" (GLOO) def test_is_shared_filesystem(tmp_path, monkeypatch): # In the non-distributed case, every location is interpreted as 'shared' diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index 4a948a5f98736..2700213747f9a 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -47,19 +47,29 @@ def test_correct_seed_with_environment_variable(): @mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True) def test_invalid_seed(): - """Ensure that we still fix the seed even if an invalid seed is given.""" - with pytest.warns(UserWarning, match="Invalid seed found"): - seed = seed_everything() - assert seed == 0 + """Ensure that a ValueError is raised if an invalid seed is given.""" + with pytest.raises(ValueError, match="Invalid seed specified"): + seed_everything() @mock.patch.dict(os.environ, {}, clear=True) @pytest.mark.parametrize("seed", [10e9, -10e9]) def test_out_of_bounds_seed(seed): - """Ensure that we still fix the seed even if an out-of-bounds seed is given.""" - with pytest.warns(UserWarning, match="is not in bounds"): - actual = seed_everything(seed) - assert actual == 0 + """Ensure that a ValueError is raised if an out-of-bounds seed is given.""" + with pytest.raises(ValueError, match="is not in bounds"): + seed_everything(seed) + + +def test_seed_everything_accepts_valid_seed_argument(): + """Ensure that seed_everything returns the provided valid seed.""" + seed_value = 45 + assert seed_everything(seed_value) == seed_value + + +@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "17"}, clear=True) +def test_seed_everything_accepts_valid_seed_from_env(): + """Ensure that seed_everything uses the valid seed from the PL_GLOBAL_SEED environment variable.""" + assert seed_everything() == 17 def test_reset_seed_no_op(): diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index 6054bf224d3df..e96a5f77df384 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -1,11 +1,12 @@ import contextlib -import sys import pytest import torch from lightning.fabric import Fabric -from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, 
TrainingSpikeException +from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 +from lightning.fabric.utilities.spike import SpikeDetection, TrainingSpikeException +from tests_fabric.helpers.runif import RunIf def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): @@ -32,6 +33,8 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), + # NOTE FOR ALL FOLLOWING TESTS: + # adding run on linux only because multiprocessing on other platforms takes forever [ pytest.param(0, 1, None, True), pytest.param(0, 1, None, False), @@ -41,150 +44,22 @@ def spike_detection_test(fabric, global_rank_spike, spike_value, should_raise): pytest.param(0, 1, float("-inf"), False), pytest.param(0, 1, float("NaN"), True), pytest.param(0, 1, float("NaN"), False), - pytest.param( - 0, - 2, - None, - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - None, - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - None, - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - None, - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("-inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("-inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("-inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("-inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("NaN"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("NaN"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("NaN"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - 
float("NaN"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), + pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, None, False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("inf"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("inf"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("inf"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("inf"), False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("-inf"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("-inf"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("-inf"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("-inf"), False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("NaN"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("NaN"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("NaN"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("NaN"), False, marks=RunIf(linux_only=True)), ], ) @pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0") diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 430fb9842cddc..639414a797aa0 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -246,6 +246,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch with mock.patch.object( trainer.progress_bar_callback.progress, "update", wraps=trainer.progress_bar_callback.progress.update ) as progress_update: + metrics_update = mock.MagicMock() + trainer.progress_bar_callback._update_metrics = metrics_update + trainer.fit(model) assert progress_update.call_count == expected_call_count @@ -260,6 +263,9 @@ def test_rich_progress_bar_with_refresh_rate(tmp_path, refresh_rate, train_batch assert fit_val_bar.total == val_batches assert not fit_val_bar.visible + # one call for each train batch + one at the end of training epoch + one for validation end + assert metrics_update.call_count == train_batches + (1 if train_batches > 0 else 0) + (1 if val_batches > 0 else 0) + @RunIf(rich=True) @pytest.mark.parametrize("limit_val_batches", [1, 5]) diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 07343c1ecc12a..b2fecedd342ea 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -109,8 +109,8 @@ def configure_optimizers(self): model.validation_step = None callback = TestBackboneFinetuningWarningCallback(unfreeze_backbone_at_epoch=3, verbose=False) + trainer = Trainer(limit_train_batches=1, default_root_dir=tmp_path, callbacks=[callback, chk], max_epochs=2) with pytest.warns(UserWarning, match="Did you init your optimizer in"): - trainer = Trainer(limit_train_batches=1, default_root_dir=tmp_path, callbacks=[callback, chk], max_epochs=2) trainer.fit(model) assert model.backbone.has_been_used diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index 
692a28dcc38c4..f61a6c59ca9db 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -1,12 +1,13 @@ import contextlib -import sys import pytest import torch -from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException +from lightning.fabric.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_1_0_0 +from lightning.fabric.utilities.spike import TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection +from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows class IdentityModule(LightningModule): @@ -50,159 +51,33 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx): @pytest.mark.flaky(max_runs=3) @pytest.mark.parametrize( ("global_rank_spike", "num_devices", "spike_value", "finite_only"), + # NOTE FOR ALL FOLLOWING TESTS: + # adding run on linux only because multiprocessing on other platforms takes forever [ - pytest.param(0, 1, None, True), - pytest.param(0, 1, None, False), - pytest.param(0, 1, float("inf"), True), - pytest.param(0, 1, float("inf"), False), - pytest.param(0, 1, float("-inf"), True), - pytest.param(0, 1, float("-inf"), False), - pytest.param(0, 1, float("NaN"), True), - pytest.param(0, 1, float("NaN"), False), - pytest.param( - 0, - 2, - None, - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - None, - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - None, - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - None, - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("-inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("-inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("-inf"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("-inf"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("NaN"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other 
platforms takes forever" - ), - ), - pytest.param( - 0, - 2, - float("NaN"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("NaN"), - True, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), - pytest.param( - 1, - 2, - float("NaN"), - False, - marks=pytest.mark.skipif( - sys.platform != "linux", reason="multiprocessing on other platforms takes forever" - ), - ), + pytest.param(0, 1, None, True, marks=_xfail_gloo_windows), + pytest.param(0, 1, None, False, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("inf"), True, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("inf"), False, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("-inf"), True, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("-inf"), False, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("NaN"), True, marks=_xfail_gloo_windows), + pytest.param(0, 1, float("NaN"), False, marks=_xfail_gloo_windows), + pytest.param(0, 2, None, True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, None, False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, None, True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, None, False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("inf"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("inf"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("inf"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("inf"), False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("-inf"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("-inf"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("-inf"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("-inf"), False, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("NaN"), True, marks=RunIf(linux_only=True)), + pytest.param(0, 2, float("NaN"), False, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("NaN"), True, marks=RunIf(linux_only=True)), + pytest.param(1, 2, float("NaN"), False, marks=RunIf(linux_only=True)), ], ) @pytest.mark.skipif(not _TORCHMETRICS_GREATER_EQUAL_1_0_0, reason="requires torchmetrics>=1.0.0") diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index e9e11b6dbb466..abcd302149fcf 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -354,6 +354,8 @@ def test_swa_resume_training_from_checkpoint_custom_scheduler(tmp_path, crash_on @RunIf(skip_windows=True) +# flaky with "torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT" (GLOO) +@pytest.mark.flaky(reruns=3) def test_swa_resume_training_from_checkpoint_ddp(tmp_path): model = SwaTestModel(crash_on_epoch=3) resume_model = SwaTestModel() diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 7b17498865889..d2bbea7ecdafe 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -35,7 +35,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.callbacks 
import ModelCheckpoint
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
@@ -453,6 +453,12 @@ def test_model_checkpoint_format_checkpoint_name(tmp_path, monkeypatch):
     ckpt_name = ckpt.format_checkpoint_name({}, ver=3)
     assert ckpt_name == str(tmp_path / "name-v3.ckpt")

+    # with prefix
+    ckpt_name = ModelCheckpoint(monitor="early_stop_on", dirpath=tmp_path, filename="name").format_checkpoint_name(
+        {}, prefix="test"
+    )
+    assert ckpt_name == str(tmp_path / "test-name.ckpt")
+
     # using slashes
     ckpt = ModelCheckpoint(monitor="early_stop_on", dirpath=None, filename="{epoch}_{val/loss:.5f}")
     ckpt_name = ckpt.format_checkpoint_name({"epoch": 4, "val/loss": 0.03})
@@ -764,6 +770,431 @@ def test_ckpt_every_n_train_steps(tmp_path):
     assert set(os.listdir(tmp_path)) == set(expected)

+def test_model_checkpoint_on_exception_run_condition_on_validation_start(tmp_path):
+    """Test that no checkpoint is saved when an exception is raised during the sanity check."""
+
+    # Don't save checkpoint if sanity check fails
+    class TroubledModelSanityCheck(BoringModel):
+        def on_validation_start(self) -> None:
+            if self.trainer.sanity_checking:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelSanityCheck()
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="sanity_check", save_on_exception=True)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        num_sanity_val_steps=4,
+        limit_train_batches=2,
+        callbacks=[checkpoint_callback],
+        max_epochs=2,
+        logger=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+    assert not os.path.isfile(tmp_path / "exception-sanity_check.ckpt")
+
+
+def test_model_checkpoint_on_exception_fast_dev_run_on_train_batch_start(tmp_path):
+    """Test that no checkpoint is saved when an exception is raised during a fast dev run."""
+
+    # Don't save checkpoint if fast dev run fails
+    class TroubledModelFastDevRun(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx) -> None:
+            if self.trainer.fast_dev_run and batch_idx == 1:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelFastDevRun()
+    checkpoint_callback = ModelCheckpoint(dirpath=tmp_path, filename="fast_dev_run", save_on_exception=True)
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        fast_dev_run=2,
+        limit_train_batches=2,
+        callbacks=[checkpoint_callback],
+        max_epochs=2,
+        logger=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+    assert not os.path.isfile(tmp_path / "exception-fast_dev_run.ckpt")
+
+
+def test_model_checkpoint_on_exception_run_condition_on_train_batch_start(tmp_path):
+    """Test that no exception checkpoint is saved when a checkpoint was already saved at the current training
+    step."""
+
+    # Don't save checkpoint if already saved a checkpoint
+    class TroubledModelAlreadySavedCheckpoint(BoringModel):
+        def on_train_batch_start(self, batch, batch_idx) -> None:
+            if self.trainer.global_step == 1:
+                raise RuntimeError("Trouble!")
+
+    model = TroubledModelAlreadySavedCheckpoint()
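+    # every_n_train_steps=1 saves a regular checkpoint at step 1, so the exception raised at that
+    # same step must not produce an additional exception checkpoint
+    checkpoint_callback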
= ModelCheckpoint( + dirpath=tmp_path, filename="already_saved", save_on_exception=True, every_n_train_steps=1 + ) + trainer = Trainer( + default_root_dir=tmp_path, limit_train_batches=2, callbacks=[checkpoint_callback], max_epochs=2, logger=False + ) + + with pytest.raises(RuntimeError, match="Trouble!"): + trainer.fit(model) + + assert not os.path.isfile(tmp_path / "exception-already_saved.ckpt") + assert os.path.isfile(tmp_path / "already_saved.ckpt") + + +class TroubledModelInTrainingStep(BoringModel): + def training_step(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelInValidationStep(BoringModel): + def validation_step(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelBackward(BoringModel): + def backward(self, loss): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnBeforeBackward(BoringModel): + def on_before_backward(self, loss): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnAfterBackward(BoringModel): + def on_after_backward(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnBeforeZeroGrad(BoringModel): + def on_before_zero_grad(self, optimizer): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnFitEnd(BoringModel): + def on_fit_end(self): + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainEnd(BoringModel): + def on_train_end(self): + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationStart(BoringModel): + def on_validation_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationEnd(BoringModel): + def on_validation_end(self): + if not self.trainer.sanity_checking: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainBatchStart(BoringModel): + def on_train_batch_start(self, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainBatchEnd(BoringModel): + def on_train_batch_end(self, outputs, batch, batch_idx): + if batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainEpochStart(BoringModel): + def on_train_epoch_start(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnTrainEpochEnd(BoringModel): + def on_train_epoch_end(self): + if self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationBatchStart(BoringModel): + def on_validation_batch_start(self, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationBatchEnd(BoringModel): + def on_validation_batch_end(self, outputs, batch, batch_idx): + if not self.trainer.sanity_checking and batch_idx == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationEpochStart(BoringModel): + def on_validation_epoch_start(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationEpochEnd(BoringModel): + def on_validation_epoch_end(self): + if not self.trainer.sanity_checking and self.current_epoch == 1: + raise RuntimeError("Trouble!") + + +class TroubledModelOnValidationModelEval(BoringModel): + def on_validation_model_eval(self): + if not self.trainer.sanity_checking 
and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnValidationModelTrain(BoringModel):
+    def on_validation_model_train(self):
+        if not self.trainer.sanity_checking and self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOnBeforeOptimizerStep(BoringModel):
+    def on_before_optimizer_step(self, optimizer):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelConfigureGradientClipping(BoringModel):
+    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOptimizerStep(BoringModel):
+    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure=None):
+        optimizer.step(closure=optimizer_closure)
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledModelOptimizerZeroGrad(BoringModel):
+    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
+        if self.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+@pytest.mark.parametrize(
+    "TroubledModel",
+    [
+        TroubledModelInTrainingStep,
+        TroubledModelInValidationStep,
+        TroubledModelBackward,
+        TroubledModelOnBeforeBackward,
+        TroubledModelOnAfterBackward,
+        TroubledModelOnBeforeZeroGrad,
+        TroubledModelOnFitEnd,
+        TroubledModelOnTrainEnd,
+        TroubledModelOnValidationStart,
+        TroubledModelOnValidationEnd,
+        TroubledModelOnTrainBatchStart,
+        TroubledModelOnTrainBatchEnd,
+        TroubledModelOnTrainEpochStart,
+        TroubledModelOnTrainEpochEnd,
+        TroubledModelOnValidationBatchStart,
+        TroubledModelOnValidationBatchEnd,
+        TroubledModelOnValidationEpochStart,
+        TroubledModelOnValidationEpochEnd,
+        TroubledModelOnValidationModelEval,
+        TroubledModelOnValidationModelTrain,
+        TroubledModelOnBeforeOptimizerStep,
+        TroubledModelConfigureGradientClipping,
+        TroubledModelOptimizerStep,
+        TroubledModelOptimizerZeroGrad,
+    ],
+)
+def test_model_checkpoint_on_exception_parametrized(tmp_path, TroubledModel):
+    """Test that a checkpoint is saved when an exception is raised in a LightningModule hook."""
+    model = TroubledModel()
+
+    # every_n_epochs=7 never fires within max_epochs=4, so only the exception handler can save
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7
+    )
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[checkpoint_callback],
+        limit_train_batches=2,
+        max_epochs=4,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+
+    checkpoint_path = tmp_path / "exception.ckpt"
+
+    assert os.path.isfile(checkpoint_path)
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+    assert checkpoint["state_dict"] is not None
+    assert checkpoint["state_dict"] != {}
+
+
+class TroubledCallbackOnFitEnd(Callback):
+    def on_fit_end(self, trainer, pl_module):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainBatchStart(Callback):
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainBatchEnd(Callback):
+    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEpochStart(Callback):
+    def on_train_epoch_start(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEpochEnd(Callback):
+    def
on_train_epoch_end(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEpochStart(Callback):
+    def on_validation_epoch_start(self, trainer, pl_module):
+        if not trainer.sanity_checking and trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEpochEnd(Callback):
+    def on_validation_epoch_end(self, trainer, pl_module):
+        if not trainer.sanity_checking and trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationBatchStart(Callback):
+    def on_validation_batch_start(self, trainer, pl_module, batch, batch_idx):
+        if not trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationBatchEnd(Callback):
+    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
+        if not trainer.sanity_checking and batch_idx == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnTrainEnd(Callback):
+    def on_train_end(self, trainer, pl_module):
+        raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationStart(Callback):
+    def on_validation_start(self, trainer, pl_module):
+        if not trainer.sanity_checking:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnValidationEnd(Callback):
+    def on_validation_end(self, trainer, pl_module):
+        if not trainer.sanity_checking:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeBackward(Callback):
+    def on_before_backward(self, trainer, pl_module, loss):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnAfterBackward(Callback):
+    def on_after_backward(self, trainer, pl_module):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeOptimizerStep(Callback):
+    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+class TroubledCallbackOnBeforeZeroGrad(Callback):
+    def on_before_zero_grad(self, trainer, pl_module, optimizer):
+        if trainer.current_epoch == 1:
+            raise RuntimeError("Trouble!")
+
+
+@pytest.mark.parametrize(
+    "TroubledCallback",
+    [
+        TroubledCallbackOnFitEnd,
+        TroubledCallbackOnTrainBatchStart,
+        TroubledCallbackOnTrainBatchEnd,
+        TroubledCallbackOnTrainEpochStart,
+        TroubledCallbackOnTrainEpochEnd,
+        TroubledCallbackOnValidationEpochStart,
+        TroubledCallbackOnValidationEpochEnd,
+        TroubledCallbackOnValidationBatchStart,
+        TroubledCallbackOnValidationBatchEnd,
+        TroubledCallbackOnTrainEnd,
+        TroubledCallbackOnValidationStart,
+        TroubledCallbackOnValidationEnd,
+        TroubledCallbackOnBeforeBackward,
+        TroubledCallbackOnAfterBackward,
+        TroubledCallbackOnBeforeOptimizerStep,
+        TroubledCallbackOnBeforeZeroGrad,
+    ],
+)
+def test_model_checkpoint_on_exception_in_other_callbacks(tmp_path, TroubledCallback):
+    """Test that a checkpoint is saved when an exception is raised in another callback."""
+
+    model = BoringModel()
+    troubled_callback = TroubledCallback()
+
+    checkpoint_callback = ModelCheckpoint(
+        dirpath=tmp_path, filename="exception", save_on_exception=True, every_n_epochs=7
+    )
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        callbacks=[checkpoint_callback, troubled_callback],
+        max_epochs=4,
+        limit_train_batches=2,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    with pytest.raises(RuntimeError, match="Trouble!"):
+        trainer.fit(model)
+
+    checkpoint_path = tmp_path / "exception.ckpt"
+
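+    # the exception-path checkpoint must exist and contain a non-empty state_dict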
+ assert os.path.isfile(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False) + assert checkpoint["state_dict"] is not None + assert checkpoint["state_dict"] != {} + + @mock.patch("lightning.pytorch.callbacks.model_checkpoint.time") def test_model_checkpoint_train_time_interval(mock_datetime, tmp_path) -> None: """Tests that the checkpoints are saved at the specified time interval.""" @@ -1666,3 +2097,30 @@ def val_dataloader(self) -> DataLoader: trainer_kwargs["max_epochs"] = 4 trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) trainer.fit(model, ckpt_path=checkpoint_path) + + +def test_save_last_without_save_on_train_epoch_and_without_val(tmp_path): + """Test that save_last=True works correctly when save_on_train_epoch_end=False in a model without validation.""" + + # Remove validation methods to test the edge case + model = BoringModel() + model.validation_step = None + model.val_dataloader = None + + checkpoint_callback = ModelCheckpoint( + dirpath=tmp_path, + save_last=True, + save_on_train_epoch_end=False, + ) + + trainer = Trainer( + max_epochs=2, + callbacks=[checkpoint_callback], + logger=False, + enable_progress_bar=False, + ) + + trainer.fit(model) + + # save_last=True should always save last.ckpt + assert (tmp_path / "last.ckpt").exists() diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index b02d9d089a354..878298c6bfd94 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -95,6 +95,12 @@ def restore_env_variables(): "TF_GRPC_DEFAULT_OPTIONS", "XLA_FLAGS", "TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile + # TensorFlow and TPU related variables + "TF2_BEHAVIOR", + "TPU_ML_PLATFORM", + "TPU_ML_PLATFORM_VERSION", + "LD_LIBRARY_PATH", + "ENABLE_RUNTIME_UPTIME_TELEMETRY", } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -333,7 +339,7 @@ def pytest_collection_modifyitems(items: list[pytest.Function], config: pytest.C options = { "standalone": "PL_RUN_STANDALONE_TESTS", - "min_cuda_gpus": "PL_RUN_CUDA_TESTS", + "min_cuda_gpus": "RUN_ONLY_CUDA_TESTS", "tpu": "PL_RUN_TPU_TESTS", } if os.getenv(options["standalone"], "0") == "1" and os.getenv(options["min_cuda_gpus"], "0") == "1": diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index 93982086a6b0a..c1d50e8458da6 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -49,8 +49,8 @@ def result_reduce_ddp_fn(strategy): assert actual.item() == dist.get_world_size() -# flaky with "process 0 terminated with signal SIGABRT" -@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") +# flaky with "torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with signal SIGABRT" +@pytest.mark.flaky(reruns=3) @RunIf(skip_windows=True) def test_result_reduce_ddp(): spawn_launch(result_reduce_ddp_fn, [torch.device("cpu")] * 2) diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 25fadd524adf8..372f493a1fb67 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -13,9 +13,20 @@ # limitations under the License. 
import pytest

+from lightning.fabric.utilities.imports import _IS_WINDOWS
+from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_8
 from lightning.pytorch.utilities.testing import _runif_reasons


 def RunIf(**kwargs):
     reasons, marker_kwargs = _runif_reasons(**kwargs)
     return pytest.mark.skipif(condition=len(reasons) > 0, reason=f"Requires: [{' + '.join(reasons)}]", **marker_kwargs)
+
+
+# todo: RuntimeError: makeDeviceForHostname(): unsupported gloo device
+# NOTE: the expected exception must go through `raises=`; a positional exception class would be
+# interpreted as the (always-truthy) `condition` argument and silently ignored by pytest.
+_xfail_gloo_windows = pytest.mark.xfail(
+    raises=RuntimeError,
+    strict=True,
+    condition=(_IS_WINDOWS and _TORCH_EQUAL_2_8),
+    reason="makeDeviceForHostname(): unsupported gloo device",
+)
diff --git a/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
new file mode 100644
index 0000000000000..83405e00b1541
--- /dev/null
+++ b/tests/tests_pytorch/loops/test_double_iter_in_iterable_dataset.py
@@ -0,0 +1,76 @@
+# Copyright The Lightning AI team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This test covers resuming training from a checkpoint file with an IterableDataset,
+# based on the code reported in issue #19427.
+# Ref: https://github.com/Lightning-AI/pytorch-lightning/issues/19427
+import multiprocessing as mp
+import os
+import sys
+from collections.abc import Iterator
+from pathlib import Path
+from queue import Queue
+
+import numpy as np
+import pytest
+from torch.utils.data import DataLoader, IterableDataset
+
+from lightning.pytorch import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
+
+
+class QueueDataset(IterableDataset):
+    def __init__(self, queue: Queue) -> None:
+        super().__init__()
+        self.queue = queue
+
+    def __iter__(self) -> Iterator:
+        for _ in range(5):
+            tensor, _ = self.queue.get(timeout=5)
+            yield tensor
+
+
+def train_model(queue: Queue, max_epochs: int, ckpt_path: Path) -> None:
+    dataloader = DataLoader(QueueDataset(queue), num_workers=1, batch_size=None)
+    trainer = Trainer(
+        max_epochs=max_epochs,
+        enable_progress_bar=False,
+        enable_checkpointing=False,
+        devices=1,
+        logger=False,
+    )
+    if ckpt_path.exists():
+        trainer.fit(BoringModel(), dataloader, ckpt_path=str(ckpt_path))
+    else:
+        trainer.fit(BoringModel(), dataloader)
+    trainer.save_checkpoint(str(ckpt_path))
+
+
+@pytest.mark.skipif(sys.platform == "darwin", reason="Skip on macOS due to multiprocessing issues")
+def test_resume_training_with_iterable_dataset(tmp_path):
+    """Test resuming training from a checkpoint file using an IterableDataset."""
+    q = mp.Queue()
+    arr = np.random.random([1, 32]).astype(np.float32)
+    for idx in range(20):
+        q.put((arr, idx))
+
+    max_epoch = 2
+    ckpt_path = tmp_path / "model.ckpt"
+    train_model(q, max_epoch, ckpt_path)
+
+    assert os.path.exists(ckpt_path), f"Checkpoint file '{ckpt_path}' wasn't created"
+    ckpt_size = os.path.getsize(ckpt_path)
+    assert ckpt_size > 0, f"Checkpoint file is empty (size: {ckpt_size} bytes)"
+
+    train_model(q, max_epoch + 2, ckpt_path)
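A minimal usage sketch, outside this diff, of how the `_xfail_gloo_windows` marker defined above composes with `RunIf`: it is an ordinary pytest mark, so it can decorate a test directly or be attached to individual cases via `marks=` in `pytest.param`. The test name and body below are hypothetical placeholders.

```python
import pytest

from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


@_xfail_gloo_windows  # expected RuntimeError on Windows with torch 2.8, runs normally elsewhere
@pytest.mark.parametrize("strategy", ["ddp_spawn", pytest.param("ddp_fork", marks=RunIf(skip_windows=True))])
def test_spawn_strategy_smoke(strategy):
    assert strategy in ("ddp_spawn", "ddp_fork")
```

diff --git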
a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index 470cbcdc195f5..2ca05e243df8f 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -19,6 +19,7 @@ from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper +from tests_pytorch.helpers.runif import _xfail_gloo_windows def test_prediction_loop_stores_predictions(tmp_path): @@ -51,6 +52,7 @@ def predict_step(self, batch, batch_idx): assert trainer.predict_loop.predictions == [] +@_xfail_gloo_windows @pytest.mark.parametrize("use_distributed_sampler", [False, True]) def test_prediction_loop_batch_sampler_set_epoch_called(tmp_path, use_distributed_sampler): """Tests that set_epoch is called on the dataloader's batch sampler (if any) during prediction.""" diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index 24323f5c1d691..3262365fff2af 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -22,7 +22,7 @@ from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from tests_pytorch.helpers.runif import RunIf +from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows class AMPTestModel(BoringModel): @@ -53,7 +53,7 @@ def _assert_autocast_enabled(self): [ ("single_device", "16-mixed", 1), ("single_device", "bf16-mixed", 1), - ("ddp_spawn", "16-mixed", 2), + pytest.param("ddp_spawn", "16-mixed", 2, marks=_xfail_gloo_windows), pytest.param("ddp_spawn", "bf16-mixed", 2, marks=RunIf(skip_windows=True)), ], ) diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index f14d62b6befb4..575bcadadc404 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -17,7 +17,7 @@ import pickle import sys from argparse import Namespace -from dataclasses import dataclass +from dataclasses import dataclass, field from enum import Enum from unittest import mock @@ -881,6 +881,31 @@ def test_dataclass_lightning_module(tmp_path): assert model.hparams == {"mandatory": 33, "optional": "cocofruit"} +def test_dataclass_with_init_false_fields(): + """Test that save_hyperparameters() filters out fields with init=False and issues a warning.""" + + @dataclass + class DataClassWithInitFalseFieldsModel(BoringModel): + mandatory: int + optional: str = "optional" + non_init_field: int = field(default=999, init=False) + another_non_init: str = field(default="not_in_init", init=False) + + def __post_init__(self): + super().__init__() + self.save_hyperparameters() + + model = DataClassWithInitFalseFieldsModel(33, optional="cocofruit") + + expected_hparams = {"mandatory": 33, "optional": "cocofruit"} + assert model.hparams == expected_hparams + + assert model.non_init_field == 999 + assert model.another_non_init == "not_in_init" + assert "non_init_field" not in model.hparams + assert "another_non_init" not in model.hparams + + class NoHparamsModel(BoringModel): """Tests a model without hparams.""" diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 81fd5631a3400..9f51332fbbdfa 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ 
b/tests/tests_pytorch/models/test_onnx.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import operator
 import os
+import re
 from io import BytesIO
 from pathlib import Path
 from unittest.mock import patch
@@ -25,7 +26,9 @@
 import tests_pytorch.helpers.pipelines as tpipes
 from lightning.pytorch import Trainer
+from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE
 from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.utilities.test_model_summary import UnorderedModel
@@ -139,8 +142,16 @@ def test_error_if_no_input(tmp_path):
         model.to_onnx(file_path)

+@pytest.mark.parametrize(
+    "dynamo",
+    [
+        None,
+        pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
+        pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
+    ],
+)
 @RunIf(onnx=True)
-def test_if_inference_output_is_valid(tmp_path):
+def test_if_inference_output_is_valid(tmp_path, dynamo):
     """Test that the output inferred from ONNX model is same as from PyTorch."""
     model = BoringModel()
     model.example_input_array = torch.randn(5, 32)
@@ -153,7 +164,12 @@ def test_if_inference_output_is_valid(tmp_path):
         torch_out = model(model.example_input_array)

     file_path = os.path.join(tmp_path, "model.onnx")
-    model.to_onnx(file_path, model.example_input_array, export_params=True)
+    kwargs = {
+        "export_params": True,
+    }
+    if dynamo is not None:
+        kwargs["dynamo"] = dynamo
+    model.to_onnx(file_path, model.example_input_array, **kwargs)

     ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {}
     ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs)
@@ -167,3 +183,53 @@ def to_numpy(tensor):

     # compare ONNX Runtime and PyTorch results
     assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
+
+
+@RunIf(min_torch="2.5.0", dynamo=True)
+@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.")
+def test_model_onnx_export_missing_onnxscript():
+    """Test that an error is raised if onnxscript is not available."""
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
+    with pytest.raises(
+        ModuleNotFoundError,
+        match=re.escape(
+            f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
+        ),
+    ):
+        model.to_onnx(dynamo=True)
+
+
+@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True)
+def test_model_return_type():
+    """Test that `to_onnx(dynamo=True)` returns an ONNXProgram whose outputs match the eager model."""
+    if _TORCH_GREATER_EQUAL_2_6:
+        from torch.onnx import ONNXProgram
+    else:
+        from torch.onnx._internal.exporter import ONNXProgram
+
+    model = BoringModel()
+    model.example_input_array = torch.randn((1, 32))
+    model.eval()
+
+    onnx_pg = model.to_onnx(dynamo=True)
+    assert isinstance(onnx_pg, ONNXProgram)
+
+    model_ret = model(model.example_input_array)
+    inf_ret = onnx_pg(model.example_input_array)
+    assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05)
+
+
+@RunIf(max_torch="2.5.0")
+def test_model_onnx_export_wrong_torch_version():
+    """Test that an error is raised when the installed torch is too old to support `dynamo=True` export."""
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
+    with pytest.raises(
+        ModuleNotFoundError,
+        match=re.escape(
+            f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
+        ),
+    ):
+        model.to_onnx(dynamo=True)
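A condensed sketch of the dynamo export path these tests exercise, assuming torch>=2.5 with `onnxscript` installed: when no file path is given, `to_onnx(dynamo=True)` returns an `ONNXProgram` that can be saved or executed directly through ONNX Runtime.

```python
import torch

from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(5, 32)
model.eval()

program = model.to_onnx(dynamo=True)  # torch.onnx.ONNXProgram
program.save("model.onnx")  # optionally persist the exported graph
ort_out = program(model.example_input_array)  # run through ONNX Runtime
assert torch.allclose(model(model.example_input_array), ort_out[0], rtol=1e-03, atol=1e-05)
```

diff --git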
a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index d0221d12e317f..2059141ad9d63 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -73,9 +73,9 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list) np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2) -def test_simple_profiler_overhead(simple_profiler, n_iter=5): +def test_simple_profiler_overhead(simple_profiler): """Ensure that the profiler doesn't introduce too much overhead during training.""" - for _ in range(n_iter): + for _ in range(5): with simple_profiler.profile("no-op"): pass @@ -284,8 +284,9 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l @pytest.mark.flaky(reruns=3) -def test_advanced_profiler_overhead(advanced_profiler, n_iter=5): +def test_advanced_profiler_overhead(advanced_profiler): """Ensure that the profiler doesn't introduce too much overhead during training.""" + n_iter = 5 for _ in range(n_iter): with advanced_profiler.profile("no-op"): pass @@ -336,6 +337,12 @@ def test_advanced_profiler_deepcopy(advanced_profiler): assert deepcopy(advanced_profiler) +def test_advanced_profiler_nested(advanced_profiler): + """Ensure AdvancedProfiler does not raise ValueError for nested profiling actions (Python 3.12+ compatibility).""" + with advanced_profiler.profile("outer"), advanced_profiler.profile("inner"): + pass # Should not raise ValueError + + @pytest.fixture def pytorch_profiler(tmp_path): return PyTorchProfiler(dirpath=tmp_path, filename="profiler") @@ -614,8 +621,8 @@ def test_pytorch_profiler_raises_warning_for_limited_steps(tmp_path, trainer_con warning_cache.clear() with pytest.warns(UserWarning, match="not enough steps to properly record traces"): getattr(trainer, trainer_fn)(model) - assert trainer.profiler._schedule is None - warning_cache.clear() + assert trainer.profiler._schedule is None + warning_cache.clear() def test_profile_callbacks(tmp_path): diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index ba90949132ba2..c20621c72ff88 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -5,6 +5,7 @@ from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator +from tests_pytorch.helpers.runif import _xfail_gloo_windows class ServableBoringModel(BoringModel, ServableModule): @@ -28,13 +29,14 @@ def configure_response(self): return {"output": [0, 1]} -@pytest.mark.xfail(strict=False, reason="test is too flaky in CI") # todo +@pytest.mark.flaky(reruns=3) def test_servable_module_validator(): model = ServableBoringModel() callback = ServableModuleValidator() callback.on_train_start(Trainer(accelerator="cpu"), model) +@_xfail_gloo_windows @pytest.mark.flaky(reruns=3) def test_servable_module_validator_with_trainer(tmp_path, mps_count_0): callback = ServableModuleValidator() diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index d0b4ab617df66..f729b521dc5d6 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ 
@@ -25,7 +25,7 @@
 from lightning.pytorch.strategies import DDPStrategy
 from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher
 from lightning.pytorch.trainer.states import TrainerFn
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 @mock.patch("lightning.pytorch.strategies.launchers.multiprocessing.mp.get_all_start_methods", return_value=[])
@@ -194,6 +194,8 @@ def on_fit_start(self) -> None:
         assert torch.equal(self.layer.weight.data, self.tied_layer.weight.data)


+@_xfail_gloo_windows
+@pytest.mark.flaky(reruns=3)
 def test_memory_sharing_disabled(tmp_path):
     """Test that the multiprocessing launcher disables memory sharing on model parameters and buffers to avoid race
     conditions on model updates."""
@@ -219,6 +221,7 @@ def test_check_for_missing_main_guard():
         launcher.launch(function=Mock())


+@_xfail_gloo_windows
 def test_fit_twice_raises(mps_count_0):
     model = BoringModel()
     trainer = Trainer(
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index 3877d6c051017..f3d98cf444c36 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -491,13 +491,15 @@ def test_strategy_choice_ddp_torchelastic(_, __, mps_count_0, cuda_count_2):
         "LOCAL_RANK": "1",
     },
 )
-@mock.patch("lightning.fabric.accelerators.cuda.num_cuda_devices", return_value=2)
-@mock.patch("lightning.fabric.accelerators.mps.MPSAccelerator.is_available", return_value=False)
-def test_torchelastic_priority_over_slurm(*_):
+def test_torchelastic_priority_over_slurm(monkeypatch):
     """Test that the TorchElastic cluster environment is chosen over SLURM when both are detected."""
+    with monkeypatch.context():
+        mock_cuda_count(monkeypatch, 2)
+        mock_mps_count(monkeypatch, 0)
+        mock_hpu_count(monkeypatch, 0)
+        connector = _AcceleratorConnector(strategy="ddp")
     assert TorchElasticEnvironment.detect()
     assert SLURMEnvironment.detect()
-    connector = _AcceleratorConnector(strategy="ddp")
     assert isinstance(connector.strategy.cluster_environment, TorchElasticEnvironment)
@@ -580,6 +582,11 @@ class AcceleratorSubclass(CPUAccelerator):
     Trainer(accelerator=AcceleratorSubclass(), strategy=FSDPStrategySubclass())


+@RunIf(min_cuda_gpus=1)
+def test_check_fsdp_strategy_and_fallback_with_cudaaccelerator():
+    Trainer(strategy="fsdp", accelerator=CUDAAccelerator())
+
+
 @mock.patch.dict(os.environ, {}, clear=True)
 def test_unsupported_tpu_choice(xla_available, tpu_available):
     # if user didn't set strategy, _Connector will choose the SingleDeviceXLAStrategy or XLAStrategy
@@ -1003,6 +1010,7 @@ def _mock_tpu_available(value):
     with monkeypatch.context():
         mock_cuda_count(monkeypatch, 2)
         mock_mps_count(monkeypatch, 0)
+        mock_hpu_count(monkeypatch, 0)
         _mock_tpu_available(True)
         connector = _AcceleratorConnector()
         assert isinstance(connector.accelerator, XLAAccelerator)
diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
index 1bb0d1478e7d3..367c2340ce542 100644
--- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py
@@ -37,7 +37,7 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import _update_dataloader
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 @RunIf(skip_windows=True)
@@ -123,6 +123,7 @@ def on_train_end(self):
         self.ctx.__exit__(None, None, None)


+@_xfail_gloo_windows
 @pytest.mark.parametrize("num_workers", [0, 1, 2])
 def test_dataloader_persistent_workers_performance_warning(num_workers, tmp_path):
     """Test that when the multiprocessing start-method is 'spawn', we recommend setting `persistent_workers=True`."""
diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
index 050818287ba45..6322698ef3b73 100644
--- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -170,3 +170,44 @@ def test_distributed_sampler_with_overfit_batches():
     train_sampler = trainer.train_dataloader.sampler
     assert isinstance(train_sampler, DistributedSampler)
     assert train_sampler.shuffle is False
+
+
+def test_overfit_batches_same_batch_for_train_and_val(tmp_path):
+    """Test that when overfit_batches=1, the same batch is used for both training and validation."""
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.train_batches = []
+            self.val_batches = []
+
+        def training_step(self, batch, batch_idx):
+            self.train_batches.append(batch)
+            return super().training_step(batch, batch_idx)
+
+        def validation_step(self, batch, batch_idx):
+            self.val_batches.append(batch)
+            return super().validation_step(batch, batch_idx)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=2,
+        overfit_batches=1,
+        check_val_every_n_epoch=1,
+        enable_model_summary=False,
+    )
+    trainer.fit(model)
+
+    # Verify that the same batch was used for both training and validation
+    assert len(model.train_batches) > 0
+    assert len(model.val_batches) > 0
+
+    # Compare the actual batch contents
+    train_batch = model.train_batches[0]
+    val_batch = model.val_batches[0]
+
+    # Check if the batches are identical
+    assert torch.equal(train_batch, val_batch), (
+        "Training and validation batches should be identical when overfit_batches=1"
+    )
diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
index be6de37ddff3a..98385f2d5681a 100644
--- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py
@@ -234,9 +234,9 @@ def on_test_epoch_end(self):


 @pytest.mark.parametrize("suffix", [False, True])
-def test_multi_dataloaders_add_suffix_properly(tmp_path, suffix):
+def test_multi_dataloaders_add_suffix_properly(suffix, tmp_path):
     class TestModel(BoringModel):
-        def test_step(self, batch, batch_idx, dataloader_idx=0):
+        def test_step(self, batch, batch_idx, dataloader_idx=0):  # noqa: PT028
             out = super().test_step(batch, batch_idx)
             self.log("test_loss", out["y"], on_step=True, on_epoch=True)
             return out
@@ -441,7 +441,7 @@ def on_test_epoch_end(self, _, pl_module):
     class TestModel(BoringModel):
         seen_losses = {i: [] for i in range(num_dataloaders)}

-        def test_step(self, batch, batch_idx, dataloader_idx=0):
+        def test_step(self, batch, batch_idx, dataloader_idx=0):  # noqa: PT028
             loss = super().test_step(batch, batch_idx)["y"]
             self.log("test_loss", loss)
             self.seen_losses[dataloader_idx].append(loss)
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index be99489cfdf89..6916eae68e9c0 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -35,7 +35,7 @@
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
-from tests_pytorch.helpers.runif import RunIf
+from tests_pytorch.helpers.runif import RunIf, _xfail_gloo_windows


 def test__training_step__log(tmp_path):
@@ -346,7 +346,7 @@ def validation_step(self, batch, batch_idx):
     ("devices", "accelerator"),
     [
         (1, "cpu"),
-        (2, "cpu"),
+        pytest.param(2, "cpu", marks=_xfail_gloo_windows),
         pytest.param(2, "gpu", marks=RunIf(min_cuda_gpus=2)),
     ],
 )
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index 3f89e1459298d..dd8042ecf2058 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -304,8 +304,36 @@ def on_train_epoch_end(self, *_, **__):
     trainer.fit(model)


+class CustomMapping(collections.abc.Mapping):
+    """A custom implementation of Mapping for testing purposes."""
+
+    def __init__(self, *args, **kwargs):
+        self._store = dict(*args, **kwargs)
+
+    def __getitem__(self, key):
+        return self._store[key]
+
+    def __iter__(self):
+        return iter(self._store)
+
+    def __len__(self):
+        return len(self._store)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}({self._store})"
+
+    def __copy__(self):
+        cls = self.__class__
+        new_obj = cls(self._store.copy())
+        return new_obj
+
+    def copy(self):
+        return self.__copy__()
+
+
 @RunIf(min_cuda_gpus=1)
-def test_multiple_optimizers_step(tmp_path):
+@pytest.mark.parametrize("dicttype", [dict, CustomMapping])
+def test_multiple_optimizers_step(tmp_path, dicttype):
     """Tests that `step` works with several optimizers."""

     class TestModel(ManualOptModel):
@@ -335,7 +363,7 @@ def training_step(self, batch, batch_idx):
             opt_b.step()
             opt_b.zero_grad()

-            return {"loss1": loss_1.detach(), "loss2": loss_2.detach()}
+            return dicttype(loss1=loss_1.detach(), loss2=loss_2.detach())

         # sister test: tests/plugins/test_amp_plugins.py::test_amp_gradient_unscale
         def on_after_backward(self) -> None:
diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py
index cfca98e04c8c8..fedf913dc9839 100644
--- a/tests/tests_pytorch/trainer/test_config_validator.py
+++ b/tests/tests_pytorch/trainer/test_config_validator.py
@@ -16,7 +16,6 @@
 import pytest
 import torch

-from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
 from lightning.pytorch.trainer.configuration_validator import (
@@ -46,20 +45,19 @@ def test_wrong_configure_optimizers(tmp_path):
         trainer.fit(model)


-def test_fit_val_loop_config(tmp_path):
+@pytest.mark.parametrize("model_attrib", ["validation_step", "val_dataloader"])
+def test_fit_val_loop_config(model_attrib, tmp_path):
     """When either val loop or val data are missing raise warning."""
warning.""" trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) - # no val data has val loop - with pytest.warns(UserWarning, match=r"You passed in a `val_dataloader` but have no `validation_step`"): - model = BoringModel() - model.validation_step = None - trainer.fit(model) - - # has val loop but no val data - with pytest.warns(PossibleUserWarning, match=r"You defined a `validation_step` but have no `val_dataloader`"): - model = BoringModel() - model.val_dataloader = None + model = BoringModel() + setattr(model, model_attrib, None) + match_msg = ( + r"You passed in a `val_dataloader` but have no `validation_step`" + if model_attrib == "validation_step" + else "You defined a `validation_step` but have no `val_dataloader`" + ) + with pytest.warns(UserWarning, match=match_msg): trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py index 7fbe55030770e..a69176b00d74f 100644 --- a/tests/tests_pytorch/trainer/test_dataloaders.py +++ b/tests/tests_pytorch/trainer/test_dataloaders.py @@ -545,13 +545,14 @@ def test_warning_with_few_workers(_, tmp_path, ckpt_path, stage): trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) - with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): - if stage == "test": - if ckpt_path in ("specific", "best"): - trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) - ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path + if stage == "test": + if ckpt_path in ("specific", "best"): + trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path + with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): trainer.test(model, dataloaders=train_dl, ckpt_path=ckpt_path) - else: + else: + with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): trainer.fit(model, train_dataloaders=train_dl, val_dataloaders=val_dl) @@ -579,16 +580,15 @@ def training_step(self, batch, batch_idx): trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, limit_val_batches=0.1, limit_train_batches=0.2) - with pytest.warns( - UserWarning, - match=f"The '{stage}_dataloader' does not have many workers", - ): - if stage == "test": - if ckpt_path in ("specific", "best"): - trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl) - ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path + if stage == "test": + if ckpt_path in ("specific", "best"): + trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl) + ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "specific" else ckpt_path + + with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): trainer.test(model, dataloaders=test_multi_dl, ckpt_path=ckpt_path) - else: + else: + with pytest.warns(UserWarning, match=f"The '{stage}_dataloader' does not have many workers"): trainer.fit(model, train_dataloaders=train_multi_dl, val_dataloaders=val_multi_dl) @@ -669,28 +669,35 @@ def test_auto_add_worker_init_fn_distributed(tmp_path, monkeypatch): trainer.fit(model, train_dataloaders=dataloader) -def test_warning_with_small_dataloader_and_logging_interval(tmp_path): 
+@pytest.mark.parametrize("log_interval", [2, 11]) +def test_warning_with_small_dataloader_and_logging_interval(log_interval, tmp_path): """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval.""" model = BoringModel() dataloader = DataLoader(RandomDataset(32, length=10)) model.train_dataloader = lambda: dataloader - with pytest.warns(UserWarning, match=r"The number of training batches \(10\) is smaller than the logging interval"): - trainer = Trainer(default_root_dir=tmp_path, max_epochs=1, log_every_n_steps=11, logger=CSVLogger(tmp_path)) + trainer = Trainer( + default_root_dir=tmp_path, + max_epochs=1, + log_every_n_steps=log_interval, + limit_train_batches=1 if log_interval < 10 else None, + logger=CSVLogger(tmp_path), + ) + with pytest.warns( + UserWarning, + match=rf"The number of training batches \({log_interval - 1}\) is smaller than the logging interval", + ): trainer.fit(model) - with pytest.warns(UserWarning, match=r"The number of training batches \(1\) is smaller than the logging interval"): - trainer = Trainer( - default_root_dir=tmp_path, - max_epochs=1, - log_every_n_steps=2, - limit_train_batches=1, - logger=CSVLogger(tmp_path), - ) - trainer.fit(model) +def test_warning_with_small_dataloader_and_fast_dev_run(tmp_path): + """Test that a warning message is shown if the dataloader length is too short for the chosen logging interval.""" + model = BoringModel() + dataloader = DataLoader(RandomDataset(32, length=10)) + model.train_dataloader = lambda: dataloader + + trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, log_every_n_steps=2) with no_warning_call(UserWarning, match="The number of training batches"): - trainer = Trainer(default_root_dir=tmp_path, fast_dev_run=True, log_every_n_steps=2) trainer.fit(model) diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py index 18ae7ce77bdfc..da79e2fdc411b 100644 --- a/tests/tests_pytorch/trainer/test_trainer.py +++ b/tests/tests_pytorch/trainer/test_trainer.py @@ -55,7 +55,7 @@ from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher from lightning.pytorch.trainer.states import RunningStage, TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE +from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_EQUAL_2_8 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -1729,6 +1729,8 @@ def test_exception_when_lightning_module_is_not_set_on_trainer(fn): @RunIf(min_cuda_gpus=1) +# FixMe: the memory raises to 1024 from expected 512 +@pytest.mark.xfail(AssertionError, strict=True, condition=_TORCH_EQUAL_2_8, reason="temporarily disabled for torch 2.8") def test_multiple_trainer_constant_memory_allocated(tmp_path): """This tests ensures calling the trainer several times reset the memory back to 0.""" @@ -1750,8 +1752,6 @@ def current_memory(): gc.collect() return torch.cuda.memory_allocated(0) - initial = current_memory() - model = TestModel() trainer_kwargs = { "default_root_dir": tmp_path, @@ -1763,6 +1763,7 @@ def current_memory(): "callbacks": Check(), } trainer = Trainer(**trainer_kwargs) + initial = current_memory() trainer.fit(model) assert trainer.strategy.model is model diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py 
index ec894688ccb6c..e2d1b6bd4ee84 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -23,6 +23,7 @@
 from lightning_utilities.test.warning import no_warning_call

 from lightning.pytorch import Trainer, seed_everything
+from lightning.pytorch.callbacks import EarlyStopping
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.tuner.lr_finder import _LRFinder
@@ -538,3 +539,142 @@ def training_step(self, batch: Any, batch_idx: int) -> STEP_OUTPUT:
     suggested_lr = lr_finder.suggestion()
     assert math.isfinite(suggested_lr)
     assert math.isclose(model.lr, suggested_lr)
+
+
+def test_lr_finder_with_early_stopping(tmp_path):
+    class ModelWithValidation(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.learning_rate = 0.1
+
+        def validation_step(self, batch, batch_idx):
+            output = self.step(batch)
+            # Log validation loss that EarlyStopping will monitor
+            self.log("val_loss", output, on_epoch=True)
+            return output
+
+        def configure_optimizers(self):
+            optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
+
+            # Add ReduceLROnPlateau scheduler that monitors val_loss (issue #20355)
+            plateau_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
+                optimizer, mode="min", factor=0.5, patience=2
+            )
+            scheduler_config = {"scheduler": plateau_scheduler, "interval": "epoch", "monitor": "val_loss"}
+
+            return {"optimizer": optimizer, "lr_scheduler": scheduler_config}
+
+    model = ModelWithValidation()
+
+    # Both callbacks that previously caused issues
+    callbacks = [
+        LearningRateFinder(num_training_steps=100, update_attr=False),
+        EarlyStopping(monitor="val_loss", patience=3),
+    ]
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        max_epochs=10,
+        callbacks=callbacks,
+        limit_train_batches=5,
+        limit_val_batches=3,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model)
+    assert trainer.state.finished
+
+    # Verify that both callbacks were active
+    lr_finder_callback = None
+    early_stopping_callback = None
+    for callback in trainer.callbacks:
+        if isinstance(callback, LearningRateFinder):
+            lr_finder_callback = callback
+        elif isinstance(callback, EarlyStopping):
+            early_stopping_callback = callback
+
+    assert lr_finder_callback is not None, "LearningRateFinder callback should be present"
+    assert early_stopping_callback is not None, "EarlyStopping callback should be present"
+
+    # Verify learning rate finder ran and has results
+    assert lr_finder_callback.optimal_lr is not None, "Learning rate finder should have results"
+    suggestion = lr_finder_callback.optimal_lr.suggestion()
+    if suggestion is not None:
+        assert suggestion > 0, "Learning rate suggestion should be positive"
+
+
+def test_gradient_correctness():
+    """Test that torch.gradient uses the correct spacing parameter."""
+    lr_finder = _LRFinder(mode="exponential", lr_min=1e-6, lr_max=1e-1, num_training=20)
+
+    # Synthetic example
+    lrs = torch.linspace(0, 2 * math.pi, steps=1000)
+    losses = torch.sin(lrs)
+    lr_finder.results = {"lr": lrs.tolist(), "loss": losses.tolist()}
+
+    # Test the suggestion method
+    suggestion = lr_finder.suggestion(skip_begin=2, skip_end=2)
+    assert suggestion is not None
+    assert abs(suggestion - math.pi) < 1e-2, "Suggestion should be close to pi for this synthetic example"
+
+
+def test_exponential_vs_linear_mode_gradient_difference(tmp_path):
+    """Test that exponential and linear modes produce different but valid suggestions.
+
+    This verifies that the spacing fix works for both modes and that they behave differently as expected due to their
+    different lr progressions.
+
+    """
+
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            self.lr = 1e-3
+
+    seed_everything(42)
+
+    # Test both modes with identical parameters
+    model_linear = TestModel()
+    model_exp = TestModel()
+
+    trainer_linear = Trainer(default_root_dir=tmp_path, max_epochs=1)
+    trainer_exp = Trainer(default_root_dir=tmp_path, max_epochs=1)
+
+    tuner_linear = Tuner(trainer_linear)
+    tuner_exp = Tuner(trainer_exp)
+
+    lr_finder_linear = tuner_linear.lr_find(model_linear, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="linear")
+    lr_finder_exp = tuner_exp.lr_find(model_exp, min_lr=1e-6, max_lr=1e-1, num_training=50, mode="exponential")
+
+    # Both should produce valid suggestions
+    suggestion_linear = lr_finder_linear.suggestion()
+    suggestion_exp = lr_finder_exp.suggestion()
+
+    assert suggestion_linear is not None
+    assert suggestion_exp is not None
+    assert suggestion_linear > 0
+    assert suggestion_exp > 0
+
+    # Verify that gradient computation uses the correct spacing for both modes
+    for lr_finder, mode in [(lr_finder_linear, "linear"), (lr_finder_exp, "exponential")]:
+        losses = torch.tensor(lr_finder.results["loss"][10:-10])
+        lrs = torch.tensor(lr_finder.results["lr"][10:-10])
+        is_finite = torch.isfinite(losses)
+        losses_filtered = losses[is_finite]
+        lrs_filtered = lrs[is_finite]
+
+        if len(losses_filtered) >= 2:
+            # Test that gradient computation works and produces finite results
+            gradients = torch.gradient(losses_filtered, spacing=[lrs_filtered])[0]
+            assert torch.isfinite(gradients).all(), f"Non-finite gradients in {mode} mode"
+            assert len(gradients) == len(losses_filtered)
+
+            # Verify gradients with spacing differ from gradients without spacing
+            gradients_no_spacing = torch.gradient(losses_filtered)[0]
+
+            # These should differ significantly in exponential mode; in linear mode they may be similar
+            if mode == "exponential":
+                assert not torch.allclose(gradients, gradients_no_spacing, rtol=0.1), (
+                    "Gradients should differ significantly in exponential mode when using proper spacing"
+                )
diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py
index 54c5572d01767..a243ed760cabf 100644
--- a/tests/tests_pytorch/utilities/test_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_model_summary.py
@@ -18,6 +18,7 @@
 import pytest
 import torch
 import torch.nn as nn
+from lightning_utilities.test.warning import no_warning_call

 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -323,19 +324,33 @@ def test_empty_model_size(max_depth):
         pytest.param("mps", marks=RunIf(mps=True)),
     ],
 )
-def test_model_size_precision(tmp_path, accelerator):
-    """Test model size for half and full precision."""
-    model = PreCalculatedModel()
+@pytest.mark.parametrize("precision", ["16-true", "32-true", "64-true"])
+def test_model_size_precision(tmp_path, accelerator, precision):
+    """Test model size for different precision types."""
+    model = PreCalculatedModel(precision=int(precision.split("-")[0]))

     # fit model
     trainer = Trainer(
-        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=32
+        default_root_dir=tmp_path, accelerator=accelerator, devices=1, max_steps=1, max_epochs=1, precision=precision
     )
     trainer.fit(model)

     summary = summarize(model)
     assert model.pre_calculated_model_size == summary.model_size


+def test_model_size_warning_on_unsupported_precision(tmp_path):
+    """Test that a warning is raised when the precision is not supported."""
+    model = PreCalculatedModel(precision=32)  # fallback to 32 bits
+
+    # supported precision by lightning but not by the model summary
+    trainer = Trainer(max_epochs=1, precision="16-mixed", default_root_dir=tmp_path)
+    trainer.fit(model)
+
+    with pytest.warns(UserWarning, match="Precision .* is not supported by the model summary.*"):
+        summary = summarize(model)
+        assert model.pre_calculated_model_size == summary.model_size
+
+
 def test_lazy_model_summary():
     """Test that the model summary can work with lazy layers."""
     lazy_model = LazyModel()
@@ -343,6 +358,7 @@ def test_lazy_model_summary():

     with pytest.warns(UserWarning, match="The total number of parameters detected may be inaccurate."):
         assert summary.total_parameters == 0
+    with no_warning_call():
         assert summary.trainable_parameters == 0