diff --git a/.azure/gpu-benchmarks.yml b/.azure/gpu-benchmarks.yml index 82ee475d53d31..b77dbfc4f792a 100644 --- a/.azure/gpu-benchmarks.yml +++ b/.azure/gpu-benchmarks.yml @@ -108,5 +108,6 @@ jobs: condition: and(succeeded(), eq(variables['PACKAGE_NAME'], 'fabric')) env: PL_RUN_CUDA_TESTS: "1" + PL_RUN_STANDALONE_TESTS: "1" displayName: "Testing: fabric standalone tasks" timeoutInMinutes: "10" diff --git a/.azure/gpu-tests-fabric.yml b/.azure/gpu-tests-fabric.yml index 0d970a552cecc..4d738d9110599 100644 --- a/.azure/gpu-tests-fabric.yml +++ b/.azure/gpu-tests-fabric.yml @@ -56,11 +56,14 @@ jobs: options: "--gpus=all --shm-size=2gb -v /var/tmp:/var/tmp" strategy: matrix: + "Fabric | oldest": + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" + PACKAGE_NAME: "fabric" "Fabric | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.6-cuda12.4.1" PACKAGE_NAME: "fabric" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.6-cuda12.4.1" PACKAGE_NAME: "lightning" workspace: clean: all @@ -77,9 +80,8 @@ jobs: displayName: "set env. vars" - bash: | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}" - echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl" condition: endsWith(variables['Agent.JobName'], 'future') - displayName: "set env. vars 4 future" + displayName: "extend env. vars 4 future" - bash: | echo $(DEVICES) @@ -105,8 +107,9 @@ jobs: displayName: "Adjust dependencies" - bash: | + set -e extra=$(python -c "print({'lightning': 'fabric-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}" + pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}" pip install setuptools==75.6.0 jsonargparse==4.35.0 displayName: "Install package & dependencies" @@ -114,6 +117,7 @@ jobs: set -e python requirements/collect_env_details.py python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'" + python requirements/pytorch/check-avail-extras.py python -c "import bitsandbytes" displayName: "Env details" @@ -140,10 +144,12 @@ jobs: displayName: "Testing: fabric standard" timeoutInMinutes: "10" - - bash: bash ./run_standalone_tests.sh "tests_fabric" + - bash: | + wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh + bash ./run_standalone_tests.sh "tests_fabric" workingDirectory: tests/ env: - PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE) + PL_RUN_STANDALONE_TESTS: "1" displayName: "Testing: fabric standalone" timeoutInMinutes: "10" diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index e09ad011908cb..414f98dab3f66 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -49,11 +49,14 @@ jobs: cancelTimeoutInMinutes: "2" strategy: matrix: + "PyTorch | oldest": + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.10-torch2.1-cuda12.1.1" + PACKAGE_NAME: "pytorch" "PyTorch | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.11-torch2.3-cuda12.1.0" + image: 
"pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.6-cuda12.4.1" PACKAGE_NAME: "pytorch" "Lightning | latest": - image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.5-cuda12.1.0" + image: "pytorchlightning/pytorch_lightning:base-cuda-py3.12-torch2.6-cuda12.4.1" PACKAGE_NAME: "lightning" pool: lit-rtx-3090 variables: @@ -81,9 +84,8 @@ jobs: displayName: "set env. vars" - bash: | echo "##vso[task.setvariable variable=TORCH_URL]https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM}" - echo "##vso[task.setvariable variable=TORCHVISION_URL]https://download.pytorch.org/whl/test/cu124/torchvision-0.19.0%2Bcu124-cp${PYTHON_VERSION_MM}-cp${PYTHON_VERSION_MM}-linux_x86_64.whl" condition: endsWith(variables['Agent.JobName'], 'future') - displayName: "set env. vars 4 future" + displayName: "extend env. vars 4 future" - bash: | echo $(DEVICES) @@ -109,8 +111,9 @@ jobs: displayName: "Adjust dependencies" - bash: | + set -e extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))") - pip install -e ".[${extra}dev]" pytest-timeout -U --find-links="${TORCH_URL}" --find-links="${TORCHVISION_URL}" + pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}" pip install setuptools==75.6.0 jsonargparse==4.35.0 displayName: "Install package & dependencies" @@ -161,11 +164,13 @@ jobs: displayName: "Testing: PyTorch standard" timeoutInMinutes: "35" - - bash: bash ./run_standalone_tests.sh "tests_pytorch" + - bash: | + wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh + bash ./run_standalone_tests.sh "tests_pytorch" workingDirectory: tests/ env: PL_USE_MOCKED_MNIST: "1" - PL_STANDALONE_TESTS_SOURCE: $(COVERAGE_SOURCE) + PL_RUN_STANDALONE_TESTS: "1" displayName: "Testing: PyTorch standalone tests" timeoutInMinutes: "35" diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 775dc5dee77dc..91cf94023786c 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -182,14 +182,14 @@ We welcome any useful contribution! For your convenience here's a recommended wo 1. Use tags in PR name for the following cases: - **\[blocked by #\]** if your work is dependent on other PRs. - - **\[wip\]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. + - **[wip]** when you start to re-edit your work, mark it so no one will accidentally merge it in meantime. ### Question & Answer #### How can I help/contribute? All types of contributions are welcome - reporting bugs, fixing documentation, adding test cases, solving issues, and preparing bug fixes. -To get started with code contributions, look for issues marked with the label [good first issue](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) or chose something close to your domain with the label [help wanted](https://github.com/Lightning-AI/lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). Before coding, make sure that the issue description is clear and comment on the issue so that we can assign it to you (or simply self-assign if you can). +To get started with code contributions, look for issues marked with the label [good first issue](https://github.com/Lightning-AI/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) or chose something close to your domain with the label [help wanted](https://github.com/Lightning-AI/pytorch-lightning/issues?q=is%3Aopen+is%3Aissue+label%3A%22help+wanted%22). 
Before coding, make sure that the issue description is clear and comment on the issue so that we can assign it to you (or simply self-assign if you can). #### Is there a recommendation for branch names? diff --git a/.github/actions/pip-wheels/action.yml b/.github/actions/pip-wheels/action.yml index 28d6e346b7aa2..19f2e7bf5e182 100644 --- a/.github/actions/pip-wheels/action.yml +++ b/.github/actions/pip-wheels/action.yml @@ -46,8 +46,8 @@ runs: run: | # cat requirements.dump pip wheel -r requirements.dump --prefer-binary \ - --wheel-dir=.wheels \ - -f ${{ inputs.torch-url }} -f ${{ inputs.wheel-dir }} + --wheel-dir=".wheels" \ + --extra-index-url=${{ inputs.torch-url }} -f ${{ inputs.wheel-dir }} ls -lh .wheels/ shell: bash diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index b1d54bc5e12fc..271284635b638 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -19,30 +19,7 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "pl-cpu (macOS-14, lightning, 3.9, 2.1, oldest)" - - "pl-cpu (macOS-14, lightning, 3.10, 2.1)" - - "pl-cpu (macOS-14, lightning, 3.11, 2.2.2)" - - "pl-cpu (macOS-14, lightning, 3.11, 2.3)" - - "pl-cpu (macOS-14, lightning, 3.12, 2.4.1)" - - "pl-cpu (macOS-14, lightning, 3.12, 2.5.1)" - - "pl-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - - "pl-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - - "pl-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" - - "pl-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" - - "pl-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - - "pl-cpu (windows-2022, lightning, 3.10, 2.1)" - - "pl-cpu (windows-2022, lightning, 3.11, 2.2.2)" - - "pl-cpu (windows-2022, lightning, 3.11, 2.3)" - - "pl-cpu (windows-2022, lightning, 3.12, 2.4.1)" - - "pl-cpu (windows-2022, lightning, 3.12, 2.5.1)" - - "pl-cpu (macOS-14, pytorch, 3.9, 2.1)" - - "pl-cpu (ubuntu-20.04, pytorch, 3.9, 2.1)" - - "pl-cpu (windows-2022, pytorch, 3.9, 2.1)" - - "pl-cpu (macOS-14, pytorch, 3.12, 2.5.1)" - - "pl-cpu (ubuntu-22.04, pytorch, 3.12, 2.5.1)" - - "pl-cpu (windows-2022, pytorch, 3.12, 2.5.1)" + - "pl-cpu-guardian" # aggregated check for all cases - id: "pytorch_lightning: Azure GPU" paths: @@ -172,30 +149,7 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "fabric-cpu (macOS-14, lightning, 3.9, 2.1, oldest)" - - "fabric-cpu (macOS-14, lightning, 3.10, 2.1)" - - "fabric-cpu (macOS-14, lightning, 3.11, 2.2.2)" - - "fabric-cpu (macOS-14, lightning, 3.11, 2.3)" - - "fabric-cpu (macOS-14, lightning, 3.12, 2.4.1)" - - "fabric-cpu (macOS-14, lightning, 3.12, 2.5.1)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.9, 2.1, oldest)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.2.2)" - - "fabric-cpu (ubuntu-20.04, lightning, 3.11, 2.3)" - - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.4.1)" - - "fabric-cpu (ubuntu-22.04, lightning, 3.12, 2.5.1)" - - "fabric-cpu (windows-2022, lightning, 3.9, 2.1, oldest)" - - "fabric-cpu (windows-2022, lightning, 3.10, 2.1)" - - "fabric-cpu (windows-2022, lightning, 3.11, 2.2.2)" - - "fabric-cpu (windows-2022, lightning, 3.11, 2.3)" - - "fabric-cpu (windows-2022, lightning, 3.12, 2.4.1)" - - "fabric-cpu (windows-2022, lightning, 3.12, 2.5.1)" - - "fabric-cpu (macOS-14, fabric, 3.9, 2.1)" - - "fabric-cpu (ubuntu-20.04, fabric, 3.9, 2.1)" - - "fabric-cpu (windows-2022, fabric, 3.9, 2.1)" - - "fabric-cpu (macOS-14, fabric, 3.12, 2.5.1)" - - "fabric-cpu (ubuntu-22.04, fabric, 3.12, 2.5.1)" 
- - "fabric-cpu (windows-2022, fabric, 3.12, 2.5.1)" + - "fabric-cpu-guardian" # aggregated check for all cases - id: "lightning_fabric: Azure GPU" paths: @@ -259,27 +213,4 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "install-pkg (ubuntu-22.04, fabric, 3.9)" - - "install-pkg (ubuntu-22.04, fabric, 3.11)" - - "install-pkg (ubuntu-22.04, pytorch, 3.9)" - - "install-pkg (ubuntu-22.04, pytorch, 3.11)" - - "install-pkg (ubuntu-22.04, lightning, 3.9)" - - "install-pkg (ubuntu-22.04, lightning, 3.11)" - - "install-pkg (ubuntu-22.04, notset, 3.9)" - - "install-pkg (ubuntu-22.04, notset, 3.11)" - - "install-pkg (macOS-14, fabric, 3.9)" - - "install-pkg (macOS-14, fabric, 3.11)" - - "install-pkg (macOS-14, pytorch, 3.9)" - - "install-pkg (macOS-14, pytorch, 3.11)" - - "install-pkg (macOS-14, lightning, 3.9)" - - "install-pkg (macOS-14, lightning, 3.11)" - - "install-pkg (macOS-14, notset, 3.9)" - - "install-pkg (macOS-14, notset, 3.11)" - - "install-pkg (windows-2022, fabric, 3.9)" - - "install-pkg (windows-2022, fabric, 3.11)" - - "install-pkg (windows-2022, pytorch, 3.9)" - - "install-pkg (windows-2022, pytorch, 3.11)" - - "install-pkg (windows-2022, lightning, 3.9)" - - "install-pkg (windows-2022, lightning, 3.11)" - - "install-pkg (windows-2022, notset, 3.9)" - - "install-pkg (windows-2022, notset, 3.11)" + - "install-pkg-guardian" # aggregated check for all cases diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 58f4afe529509..3bdc8f9a0b07f 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -16,7 +16,7 @@ Brief description of all our automation tools used for boosting development perf | .azure-pipelines/gpu-benchmarks.yml | Run speed/memory benchmarks for parity with vanila PyTorch. | GPU | | .github/workflows/ci-flagship-apps.yml | Run end-2-end tests with full applications, including deployment to the production cloud. | CPU | | .github/workflows/ci-tests-pytorch.yml | Run all tests except for accelerator-specific, standalone and slow tests. | CPU | -| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '\[TPU\]' | TPU | +| .github/workflows/tpu-tests.yml | Run only TPU-specific tests. Requires that the PR title contains '[TPU]' | TPU | \* Each standalone test needs to be run in separate processes to avoid unwanted interactions between test cases. diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 0161ab57bca52..4107633424388 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -43,7 +43,7 @@ on: env: LEGACY_FOLDER: "tests/legacy" - TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" + TORCH_URL: "https://download.pytorch.org/whl/cpu/" defaults: run: @@ -67,12 +67,12 @@ jobs: PACKAGE_NAME: pytorch FREEZE_REQUIREMENTS: 1 timeout-minutes: 20 - run: pip install . -f ${TORCH_URL} + run: pip install . 
--extra-index-url="${TORCH_URL}" if: inputs.pl_version == '' - name: Install PL version timeout-minutes: 20 - run: pip install "pytorch-lightning==${{ inputs.pl_version }}" -f ${TORCH_URL} + run: pip install "pytorch-lightning==${{ inputs.pl_version }}" --extra-index-url="${TORCH_URL}" if: inputs.pl_version != '' - name: Adjust tests -> PL diff --git a/.github/workflows/call-clear-cache.yml b/.github/workflows/call-clear-cache.yml index 1dddbe8f72bb0..b736d2a91f55f 100644 --- a/.github/workflows/call-clear-cache.yml +++ b/.github/workflows/call-clear-cache.yml @@ -23,18 +23,18 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.1 with: - scripts-ref: v0.11.8 + scripts-ref: v0.14.1 dry-run: ${{ github.event_name == 'pull_request' }} pattern: "latest|docs" age-days: 7 direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.14.1 with: - scripts-ref: v0.11.8 + scripts-ref: v0.14.1 dry-run: ${{ github.event_name == 'pull_request' }} pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/ci-check-md-links.yml b/.github/workflows/ci-check-md-links.yml index d0dc889230112..b619b756c1349 100644 --- a/.github/workflows/ci-check-md-links.yml +++ b/.github/workflows/ci-check-md-links.yml @@ -14,7 +14,7 @@ on: jobs: check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.14.1 with: config-file: ".github/markdown-links-config.json" base-branch: "master" diff --git a/.github/workflows/ci-pkg-install.yml b/.github/workflows/ci-pkg-install.yml index 61055c9b5ac3d..6e38c26f4174e 100644 --- a/.github/workflows/ci-pkg-install.yml +++ b/.github/workflows/ci-pkg-install.yml @@ -103,3 +103,17 @@ jobs: LIGHTING_TESTING: 1 # path for require wrapper PY_IGNORE_IMPORTMISMATCH: 1 run: python -m pytest src/lit/${PKG_NAME} --ignore-glob="**/cli/*-template/**" --doctest-plus + + install-pkg-guardian: + runs-on: ubuntu-latest + needs: install-pkg + if: always() + steps: + - run: echo "${{ needs.install-pkg.result }}" + - name: failing... + if: needs.install-pkg.result == 'failure' + run: exit 1 + - name: cancelled or skipped... 
+ if: contains(fromJSON('["cancelled", "skipped"]'), needs.install-pkg.result) + timeout-minutes: 1 + run: sleep 90 diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 32cd82f12784b..aec5f9b4bc261 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,7 +8,7 @@ on: jobs: check: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.9 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.14.1 with: # skip azure due to the wrong schema file by MSFT # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607 diff --git a/.github/workflows/ci-tests-fabric.yml b/.github/workflows/ci-tests-fabric.yml index c6d1dbf5b5ff2..f3061de2010db 100644 --- a/.github/workflows/ci-tests-fabric.yml +++ b/.github/workflows/ci-tests-fabric.yml @@ -56,36 +56,28 @@ jobs: - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use Fabric scope to limit dependency issues - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } + - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12.7", pytorch-version: "2.6" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - - { - os: "ubuntu-20.04", - pkg-name: "lightning", - python-version: "3.9", - pytorch-version: "2.1", - requires: "oldest", - } - - { - os: "windows-2022", - pkg-name: "lightning", - python-version: "3.9", - pytorch-version: "2.1", - requires: "oldest", - } + - { os: "macOS-14", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } + - { os: "ubuntu-20.04", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } + - { os: "windows-2022", pkg-name: "fabric", pytorch-version: "2.1", requires: "oldest" } # "fabric" installs the standalone package - - { os: "macOS-14", pkg-name: "fabric", python-version: "3.9", pytorch-version: "2.1" } - - { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.9", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "fabric", python-version: "3.9", pytorch-version: "2.1" } + - { os: "macOS-14", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } + - { os: "ubuntu-20.04", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } + - { os: "windows-2022", pkg-name: "fabric", python-version: "3.10", pytorch-version: "2.5" } + # adding recently cut Torch 2.7 - FUTURE + # - { os: "macOS-14", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } + # - { os: "ubuntu-22.04", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } + # - { os: "windows-2022", pkg-name: "fabric", python-version: "3.12", pytorch-version: "2.7" } timeout-minutes: 25 # because of building grpcio on Mac env: PACKAGE_NAME: ${{ matrix.pkg-name }} FREEZE_REQUIREMENTS: ${{ ! 
(github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} PYPI_CACHE_DIR: "_pip-wheels" - TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html" - TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch" + TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/" + TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/" # TODO: Remove this - Enable running MPS tests on this platform DISABLE_MPS: ${{ matrix.os == 'macOS-14' && '1' || '0' }} steps: @@ -94,7 +86,7 @@ jobs: - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: ${{ matrix.python-version || '3.9' }} - name: basic setup run: pip install -q -r .actions/requirements.txt @@ -126,8 +118,8 @@ jobs: - name: Env. variables run: | - # Switch PyTorch URL - python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.5' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + # Switch PyTorch URL between stable and test/future + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'lightning_fabric'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage @@ -139,7 +131,7 @@ jobs: timeout-minutes: 20 run: | pip install -e ".[${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ - --find-links="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" + --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list - name: Dump handy wheels if: github.event_name == 'push' && github.ref == 'refs/heads/master' @@ -186,3 +178,17 @@ jobs: flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest,python${{ matrix.python-version }} name: CPU-coverage fail_ci_if_error: false + + fabric-cpu-guardian: + runs-on: ubuntu-latest + needs: fabric-cpu + if: always() + steps: + - run: echo "${{ needs.fabric-cpu.result }}" + - name: failing... + if: needs.fabric-cpu.result == 'failure' + run: exit 1 + - name: cancelled or skipped... 
+ if: contains(fromJSON('["cancelled", "skipped"]'), needs.fabric-cpu.result) + timeout-minutes: 1 + run: sleep 90 diff --git a/.github/workflows/ci-tests-pytorch.yml b/.github/workflows/ci-tests-pytorch.yml index 0112d21d33a5c..7a769d5b52d1a 100644 --- a/.github/workflows/ci-tests-pytorch.yml +++ b/.github/workflows/ci-tests-pytorch.yml @@ -60,35 +60,27 @@ jobs: - { os: "ubuntu-22.04", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } - { os: "windows-2022", pkg-name: "lightning", python-version: "3.12.7", pytorch-version: "2.5.1" } # only run PyTorch latest with Python latest, use PyTorch scope to limit dependency issues - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.5.1" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } + - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12.7", pytorch-version: "2.6" } # "oldest" versions tests, only on minimum Python - - { os: "macOS-14", pkg-name: "lightning", python-version: "3.9", pytorch-version: "2.1", requires: "oldest" } - - { - os: "ubuntu-20.04", - pkg-name: "lightning", - python-version: "3.9", - pytorch-version: "2.1", - requires: "oldest", - } - - { - os: "windows-2022", - pkg-name: "lightning", - python-version: "3.9", - pytorch-version: "2.1", - requires: "oldest", - } + - { os: "macOS-14", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } + - { os: "ubuntu-20.04", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } + - { os: "windows-2022", pkg-name: "pytorch", pytorch-version: "2.1", requires: "oldest" } # "pytorch" installs the standalone package - - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.9", pytorch-version: "2.1" } - - { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.9", pytorch-version: "2.1" } - - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.9", pytorch-version: "2.1" } + - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } + - { os: "ubuntu-20.04", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } + - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.10", pytorch-version: "2.5" } + # adding recently cut Torch 2.7 - FUTURE + # - { os: "macOS-14", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } + # - { os: "ubuntu-22.04", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } + # - { os: "windows-2022", pkg-name: "pytorch", python-version: "3.12", pytorch-version: "2.7" } timeout-minutes: 50 env: PACKAGE_NAME: ${{ matrix.pkg-name }} - TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" - TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/torch_stable.html" - TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/torch" + TORCH_URL: "https://download.pytorch.org/whl/cpu/" + TORCH_URL_STABLE: "https://download.pytorch.org/whl/cpu/" + TORCH_URL_TEST: "https://download.pytorch.org/whl/test/cpu/" FREEZE_REQUIREMENTS: ${{ ! 
(github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/heads/release/')) }} PYPI_CACHE_DIR: "_pip-wheels" # TODO: Remove this - Enable running MPS tests on this platform @@ -99,7 +91,7 @@ jobs: - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: ${{ matrix.python-version || '3.9' }} - name: basic setup run: pip install -q -r .actions/requirements.txt @@ -132,8 +124,8 @@ jobs: - name: Env. variables run: | - # Switch PyTorch URL - python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.5' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV + # Switch PyTorch URL between stable and test/future + python -c "print('TORCH_URL=' + str('${{env.TORCH_URL_TEST}}' if '${{ matrix.pytorch-version }}' == '2.7' else '${{env.TORCH_URL_STABLE}}'))" >> $GITHUB_ENV # Switch coverage scope python -c "print('COVERAGE_SCOPE=' + str('lightning' if '${{matrix.pkg-name}}' == 'lightning' else 'pytorch_lightning'))" >> $GITHUB_ENV # if you install mono-package set dependency only for this subpackage @@ -146,7 +138,7 @@ jobs: run: | pip install ".[${EXTRA_PREFIX}extra,${EXTRA_PREFIX}test,${EXTRA_PREFIX}strategies]" -U --prefer-binary \ -r requirements/_integrations/accelerators.txt \ - --find-links="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" + --extra-index-url="${TORCH_URL}" --find-links="${PYPI_CACHE_DIR}" pip list - name: Drop LAI from extensions if: ${{ matrix.pkg-name != 'lightning' }} @@ -223,3 +215,17 @@ jobs: flags: ${{ env.COVERAGE_SCOPE }},cpu,pytest-full,python${{ matrix.python-version }},pytorch${{ matrix.pytorch-version }} name: CPU-coverage fail_ci_if_error: false + + pl-cpu-guardian: + runs-on: ubuntu-latest + needs: pl-cpu + if: always() + steps: + - run: echo "${{ needs.pl-cpu.result }}" + - name: failing... + if: needs.pl-cpu.result == 'failure' + run: exit 1 + - name: cancelled or skipped... + if: contains(fromJSON('["cancelled", "skipped"]'), needs.pl-cpu.result) + timeout-minutes: 1 + run: sleep 90 diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index 7ab558aa7b07f..b623cdc9337f3 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -100,6 +100,7 @@ jobs: - { python_version: "3.11", pytorch_version: "2.3.1", cuda_version: "12.1.1" } - { python_version: "3.11", pytorch_version: "2.4.1", cuda_version: "12.1.1" } - { python_version: "3.12", pytorch_version: "2.5.1", cuda_version: "12.1.1" } + - { python_version: "3.12", pytorch_version: "2.6.0", cuda_version: "12.4.1" } steps: - uses: actions/checkout@v4 - uses: docker/setup-buildx-action@v3 diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index d84b7bed7a34a..9b2bab5ab98d4 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -46,7 +46,7 @@ defaults: env: FREEZE_REQUIREMENTS: "1" - TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" + TORCH_URL: "https://download.pytorch.org/whl/cpu/" PYPI_CACHE_DIR: "_pip-wheels" PYPI_LOCAL_DIR: "pypi_pkgs/" @@ -106,7 +106,7 @@ jobs: mkdir -p ${PYPI_CACHE_DIR} # in case cache was not hit ls -lh ${PYPI_CACHE_DIR} pip install .[all] -U -r requirements/${{ matrix.pkg-name }}/docs.txt \ - -f ${PYPI_LOCAL_DIR} -f ${PYPI_CACHE_DIR} -f ${TORCH_URL} + -f ${PYPI_LOCAL_DIR} -f ${PYPI_CACHE_DIR} --extra-index-url="${TORCH_URL}" pip list - name: Install req. 
for Notebooks/tutorials if: matrix.pkg-name == 'pytorch' @@ -174,6 +174,21 @@ jobs: with: project_id: ${{ secrets.GCS_PROJECT }} + # Uploading docs as archive to GCS, so they can be as backup + - name: Upload docs as archive to GCS 🪣 + if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' + working-directory: docs/build + run: | + zip ${{ env.VERSION }}.zip -r html/ + gsutil cp ${{ env.VERSION }}.zip ${GCP_TARGET} + + - name: Inject version selector + working-directory: docs/build + run: | + pip install -q wget + python -m wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/inject-selector-script.py + python inject-selector-script.py html ${{ matrix.pkg-name }} + # Uploading docs to GCS, so they can be served on lightning.ai - name: Upload docs/${{ matrix.pkg-name }}/stable to GCS 🪣 if: startsWith(github.ref, 'refs/heads/release/') && github.event_name == 'push' @@ -188,11 +203,3 @@ jobs: - name: Upload docs/${{ matrix.pkg-name }}/release to GCS 🪣 if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' run: gsutil -m rsync -d -R docs/build/html/ ${GCP_TARGET}/${{ env.VERSION }} - - # Uploading docs as archive to GCS, so they can be as backup - - name: Upload docs as archive to GCS 🪣 - if: startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' - working-directory: docs/build - run: | - zip ${{ env.VERSION }}.zip -r html/ - gsutil cp ${{ env.VERSION }}.zip ${GCP_TARGET} diff --git a/.github/workflows/release-pkg.yml b/.github/workflows/release-pkg.yml index c7828d70f7103..9786c2f57b3c7 100644 --- a/.github/workflows/release-pkg.yml +++ b/.github/workflows/release-pkg.yml @@ -23,7 +23,7 @@ defaults: env: FREEZE_REQUIREMENTS: 1 - TORCH_URL: "https://download.pytorch.org/whl/cpu/torch_stable.html" + TORCH_URL: "https://download.pytorch.org/whl/cpu/" PYTHON_VER: "3.9" jobs: @@ -60,7 +60,7 @@ jobs: python-version: ${{ env.PYTHON_VER }} - name: install Package run: | - pip install . -f ${TORCH_URL} + pip install . 
--extra-index-url="${TORCH_URL}" pip list - name: package Version id: lai-package diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5e65de1d7eb7..f2e475f602913 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ ci: repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: end-of-file-fixer - id: trailing-whitespace @@ -65,12 +65,12 @@ repos: args: ["--in-place"] - repo: https://github.com/sphinx-contrib/sphinx-lint - rev: v0.9.1 + rev: v1.0.0 hooks: - id: sphinx-lint - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.5.0 + rev: v0.8.6 hooks: # try to fix what is possible - id: ruff @@ -81,7 +81,7 @@ repos: - id: ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.17 + rev: 0.7.21 hooks: - id: mdformat additional_dependencies: diff --git a/dockers/base-cuda/Dockerfile b/dockers/base-cuda/Dockerfile index 0e56f2fa93bd9..0da0cf9b2de9f 100644 --- a/dockers/base-cuda/Dockerfile +++ b/dockers/base-cuda/Dockerfile @@ -92,9 +92,8 @@ RUN \ -r requirements/pytorch/extra.txt \ -r requirements/pytorch/test.txt \ -r requirements/pytorch/strategies.txt \ - --find-links="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/torch_stable.html" \ - --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/torch" \ - --find-links="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/pytorch-triton" + --extra-index-url="https://download.pytorch.org/whl/cu${CUDA_VERSION_MM//'.'/''}/" \ + --extra-index-url="https://download.pytorch.org/whl/test/cu${CUDA_VERSION_MM//'.'/''}/" RUN \ # Show what we have diff --git a/docs/source-fabric/_templates/theme_variables.jinja b/docs/source-fabric/_templates/theme_variables.jinja index cce7263609621..e2ebb66281716 100644 --- a/docs/source-fabric/_templates/theme_variables.jinja +++ b/docs/source-fabric/_templates/theme_variables.jinja @@ -1,6 +1,6 @@ {%- set external_urls = { 'github': 'https://github.com/Lightning-AI/lightning', - 'github_issues': 'https://github.com/Lightning-AI/lightning/issues', + 'github_issues': 'https://github.com/Lightning-AI/pytorch-lightning/issues', 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/.github/CONTRIBUTING.md', 'governance': 'https://lightning.ai/docs/pytorch/latest/community/governance.html', 'docs': 'https://lightning.ai/docs/fabric/', diff --git a/docs/source-fabric/links.rst b/docs/source-fabric/links.rst index d8e1d12b1ac6d..bd7b431e76642 100644 --- a/docs/source-fabric/links.rst +++ b/docs/source-fabric/links.rst @@ -1,3 +1,3 @@ -.. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ +.. _PyTorchJob: https://www.kubeflow.org/docs/components/trainer/legacy-v1/user-guides/pytorch/ .. _Kubeflow: https://www.kubeflow.org .. 
_Trainer: https://lightning.ai/docs/pytorch/stable/common/trainer.html diff --git a/docs/source-pytorch/_templates/theme_variables.jinja b/docs/source-pytorch/_templates/theme_variables.jinja index 912c1882b9138..3e6a4d19bdeed 100644 --- a/docs/source-pytorch/_templates/theme_variables.jinja +++ b/docs/source-pytorch/_templates/theme_variables.jinja @@ -1,6 +1,6 @@ {%- set external_urls = { 'github': 'https://github.com/Lightning-AI/lightning', - 'github_issues': 'https://github.com/Lightning-AI/lightning/issues', + 'github_issues': 'https://github.com/Lightning-AI/pytorch-lightning/issues', 'contributing': 'https://github.com/Lightning-AI/lightning/blob/master/.github/CONTRIBUTING.md', 'governance': 'https://lightning.ai/docs/pytorch/latest/community/governance.html', 'docs': 'https://lightning.ai/docs/pytorch/latest/', diff --git a/docs/source-pytorch/accelerators/accelerator_prepare.rst b/docs/source-pytorch/accelerators/accelerator_prepare.rst index 4d1c539f23273..356c5d78dff1c 100644 --- a/docs/source-pytorch/accelerators/accelerator_prepare.rst +++ b/docs/source-pytorch/accelerators/accelerator_prepare.rst @@ -123,7 +123,7 @@ It is possible to perform some computation manually and log the reduced result o # When you call `self.log` only on rank 0, don't forget to add # `rank_zero_only=True` to avoid deadlocks on synchronization. - # Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/lightning/issues/15852 + # Caveat: monitoring this is unimplemented, see https://github.com/Lightning-AI/pytorch-lightning/issues/15852 if self.trainer.is_global_zero: self.log("my_reduced_metric", mean, rank_zero_only=True) diff --git a/docs/source-pytorch/accelerators/gpu_intermediate.rst b/docs/source-pytorch/accelerators/gpu_intermediate.rst index 2774a4cf8fc6f..e5dcd151375b2 100644 --- a/docs/source-pytorch/accelerators/gpu_intermediate.rst +++ b/docs/source-pytorch/accelerators/gpu_intermediate.rst @@ -25,10 +25,6 @@ Lightning supports multiple ways of doing distributed training. .. note:: If you request multiple GPUs or nodes without setting a strategy, DDP will be automatically used. -For a deeper understanding of what Lightning is doing, feel free to read this -`guide `_. - - ---- diff --git a/docs/source-pytorch/advanced/ddp_optimizations.rst b/docs/source-pytorch/advanced/ddp_optimizations.rst index d2d14375155a5..34ca5d743a8f5 100644 --- a/docs/source-pytorch/advanced/ddp_optimizations.rst +++ b/docs/source-pytorch/advanced/ddp_optimizations.rst @@ -58,7 +58,7 @@ On a Multi-Node Cluster, Set NCCL Parameters ******************************************** `NCCL `__ is the NVIDIA Collective Communications Library that is used by PyTorch to handle communication across nodes and GPUs. -There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue `__. +There are reported benefits in terms of speedups when adjusting NCCL parameters as seen in this `issue `__. In the issue, we see a 30% speed improvement when training the Transformer XLM-RoBERTa and a 15% improvement in training with Detectron2. NCCL parameters can be adjusted via environment variables. 
diff --git a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst index 9689f8c217eaf..3a3846500ff35 100644 --- a/docs/source-pytorch/advanced/model_parallel/deepspeed.rst +++ b/docs/source-pytorch/advanced/model_parallel/deepspeed.rst @@ -319,7 +319,7 @@ Additionally, DeepSpeed supports offloading to NVMe drives for even larger model ) trainer.fit(model) -When offloading to NVMe you may notice that the speed is slow. There are parameters that need to be tuned based on the drives that you are using. Running the `aio_bench_perf_sweep.py `__ script can help you to find optimum parameters. See the `issue `__ for more information on how to parse the information. +When offloading to NVMe you may notice that the speed is slow. There are parameters that need to be tuned based on the drives that you are using. Running the `aio_bench_perf_sweep.py `__ script can help you to find optimum parameters. See the `issue `__ for more information on how to parse the information. .. _deepspeed-activation-checkpointing: diff --git a/docs/source-pytorch/common/lightning_module.rst b/docs/source-pytorch/common/lightning_module.rst index 15e3af75d7aec..e3a8edccc3a44 100644 --- a/docs/source-pytorch/common/lightning_module.rst +++ b/docs/source-pytorch/common/lightning_module.rst @@ -1094,7 +1094,7 @@ for more information. on_train_epoch_start() - for batch in train_dataloader(): + for batch_idx, batch in enumerate(train_dataloader()): on_train_batch_start() on_before_batch_transfer() diff --git a/docs/source-pytorch/common/tbptt.rst b/docs/source-pytorch/common/tbptt.rst index 063ef8c33d319..04b8ea33b9235 100644 --- a/docs/source-pytorch/common/tbptt.rst +++ b/docs/source-pytorch/common/tbptt.rst @@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split. .. code-block:: python import torch + import torch.nn as nn + import torch.nn.functional as F import torch.optim as optim - import pytorch_lightning as pl - from pytorch_lightning import LightningModule + from torch.utils.data import Dataset, DataLoader - class LitModel(LightningModule): + import lightning as L + + + class AverageDataset(Dataset): + def __init__(self, dataset_len=300, sequence_len=100): + self.dataset_len = dataset_len + self.sequence_len = sequence_len + self.input_seq = torch.randn(dataset_len, sequence_len, 10) + top, bottom = self.input_seq.chunk(2, -1) + self.output_seq = top + bottom.roll(shifts=1, dims=-1) + + def __len__(self): + return self.dataset_len + + def __getitem__(self, item): + return self.input_seq[item], self.output_seq[item] + + + class LitModel(L.LightningModule): def __init__(self): super().__init__() + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + # 1. Switch to manual optimization self.automatic_optimization = False - self.truncated_bptt_steps = 10 - self.my_rnn = ParityModuleRNN() # Define RNN model using ParityModuleRNN + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs # 2. Remove the `hiddens` argument def training_step(self, batch, batch_idx): - # 3. 
Split the batch in chunks along the time dimension - split_batches = split_batch(batch, self.truncated_bptt_steps) - - batch_size = 10 - hidden_dim = 20 - hiddens = torch.zeros(1, batch_size, hidden_dim, device=self.device) - for split_batch in range(split_batches): - # 4. Perform the optimization in a loop - loss, hiddens = self.my_rnn(split_batch, hiddens) - self.backward(loss) - self.optimizer.step() - self.optimizer.zero_grad() + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1) + ] + + hiddens = None + optimizer = self.optimizers() + losses = [] + + # 4. Perform the optimization in a loop + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() # 5. "Truncate" - hiddens = hiddens.detach() + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + avg_loss = sum(losses) / len(losses) + self.log("train_loss", avg_loss, prog_bar=True) # 6. Remove the return of `hiddens` # Returning loss in manual optimization is not needed return None def configure_optimizers(self): - return optim.Adam(self.my_rnn.parameters(), lr=0.001) + return optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) + if __name__ == "__main__": model = LitModel() - trainer = pl.Trainer(max_epochs=5) - trainer.fit(model, train_dataloader) # Define your own dataloader + trainer = L.Trainer(max_epochs=5) + trainer.fit(model) diff --git a/docs/source-pytorch/data/alternatives.rst b/docs/source-pytorch/data/alternatives.rst index 976f6f9de7297..02c751fe3134a 100644 --- a/docs/source-pytorch/data/alternatives.rst +++ b/docs/source-pytorch/data/alternatives.rst @@ -90,7 +90,7 @@ the desired GPU in your pipeline. When moving data to a specific device, you can WebDataset ^^^^^^^^^^ -The `WebDataset `__ makes it easy to write I/O pipelines for large datasets. +The `WebDataset `__ makes it easy to write I/O pipelines for large datasets. Datasets can be stored locally or in the cloud. ``WebDataset`` is just an instance of a standard IterableDataset. The webdataset library contains a small wrapper (``WebLoader``) that adds a fluid interface to the DataLoader (and is otherwise identical). diff --git a/docs/source-pytorch/data/iterables.rst b/docs/source-pytorch/data/iterables.rst index 58b7ff42c26e1..759400714d3de 100644 --- a/docs/source-pytorch/data/iterables.rst +++ b/docs/source-pytorch/data/iterables.rst @@ -50,7 +50,7 @@ To choose a different mode, you can use the :class:`~lightning.pytorch.utilities Currently, the ``trainer.predict`` method only supports the ``"sequential"`` mode, while ``trainer.fit`` method does not support it. -Support for this feature is tracked in this `issue `__. +Support for this feature is tracked in this `issue `__. Note that when using the ``"sequential"`` mode, you need to add an additional argument ``dataloader_idx`` to some specific hooks. Lightning will `raise an error `__ informing you of this requirement. diff --git a/docs/source-pytorch/links.rst b/docs/source-pytorch/links.rst index 64ec918bf8e25..5291f9548d9e4 100644 --- a/docs/source-pytorch/links.rst +++ b/docs/source-pytorch/links.rst @@ -1,2 +1,2 @@ -.. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ +.. _PyTorchJob: https://www.kubeflow.org/docs/components/trainer/legacy-v1/user-guides/pytorch/ .. 
_Kubeflow: https://www.kubeflow.org diff --git a/docs/source-pytorch/starter/converting.rst b/docs/source-pytorch/starter/converting.rst index 1b8991f66a214..e7df3850af241 100644 --- a/docs/source-pytorch/starter/converting.rst +++ b/docs/source-pytorch/starter/converting.rst @@ -192,6 +192,6 @@ The predict loop will not be used until you call :meth:`~lightning.pytorch.train model = LitModel() trainer.predict(model) -.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for testing. +.. note:: ``model.eval()`` and ``torch.no_grad()`` are called automatically for predicting. .. tip:: ``trainer.predict()`` loads the best checkpoint automatically by default if checkpointing is enabled. diff --git a/docs/source-pytorch/tuning/profiler_intermediate.rst b/docs/source-pytorch/tuning/profiler_intermediate.rst index 802bfc5e6db4e..87aed86ac3653 100644 --- a/docs/source-pytorch/tuning/profiler_intermediate.rst +++ b/docs/source-pytorch/tuning/profiler_intermediate.rst @@ -55,7 +55,7 @@ The profiler will generate an output like this: Self CPU time total: 1.681ms .. note:: - When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. + When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the ``SimpleProfiler``. @@ -142,7 +142,7 @@ This profiler will record ``training_step``, ``validation_step``, ``test_step``, The output above shows the profiling for the action ``training_step``. .. note:: - When using the PyTorch Profiler, wall clock time will not not be representative of the true wall clock time. + When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is due to forcing profiled operations to be measured synchronously, when many CUDA ops happen asynchronously. It is recommended to use this Profiler to find bottlenecks/breakdowns, however for end to end wall clock time use the ``SimpleProfiler``. diff --git a/docs/source-pytorch/versioning.rst b/docs/source-pytorch/versioning.rst index d923b01c7edb3..10c6ec2fdf8e5 100644 --- a/docs/source-pytorch/versioning.rst +++ b/docs/source-pytorch/versioning.rst @@ -61,8 +61,8 @@ For API removal, renaming or other forms of backwards-incompatible changes, the #. From that version onward, the deprecation warning gets converted into a helpful error, which will remain until next major release. This policy is not strict. Shorter or longer deprecation cycles may apply to some cases. -For example, in the past DDP2 was removed without a deprecation process because the feature was broken and unusable beyond fixing as discussed in `#12584 `_. -Also, `#10410 `_ is an example that a longer deprecation applied to. We deprecated the accelerator arguments, such as ``Trainer(gpus=...)``, in 1.7, however, because the APIs were so core that they would impact almost all use cases, we decided not to introduce the breaking change until 2.0. +For example, in the past DDP2 was removed without a deprecation process because the feature was broken and unusable beyond fixing as discussed in `#12584 `_. +Also, `#10410 `_ is an example that a longer deprecation applied to. 
We deprecated the accelerator arguments, such as ``Trainer(gpus=...)``, in 1.7, however, because the APIs were so core that they would impact almost all use cases, we decided not to introduce the breaking change until 2.0. Compatibility matrix ******************** diff --git a/docs/source-pytorch/visualize/loggers.rst b/docs/source-pytorch/visualize/loggers.rst index bdf95ec1b675e..f4fd5b23b2311 100644 --- a/docs/source-pytorch/visualize/loggers.rst +++ b/docs/source-pytorch/visualize/loggers.rst @@ -54,3 +54,37 @@ Track and Visualize Experiments + +.. _mlflow_logger: + +MLflow Logger +------------- + +The MLflow logger in PyTorch Lightning now includes a `checkpoint_path_prefix` parameter. This parameter allows you to prefix the checkpoint artifact's path when logging checkpoints as artifacts. + +Example usage: + +.. code-block:: python + + import lightning as L + from lightning.pytorch.loggers import MLFlowLogger + + mlf_logger = MLFlowLogger( + experiment_name="lightning_logs", + tracking_uri="file:./ml-runs", + checkpoint_path_prefix="my_prefix" + ) + trainer = L.Trainer(logger=mlf_logger) + + # Your LightningModule definition + class LitModel(L.LightningModule): + def training_step(self, batch, batch_idx): + # example + self.logger.experiment.whatever_ml_flow_supports(...) + + def any_lightning_module_function_or_hook(self): + self.logger.experiment.whatever_ml_flow_supports(...) + + # Train your model + model = LitModel() + trainer.fit(model) diff --git a/examples/fabric/build_your_own_trainer/run.py b/examples/fabric/build_your_own_trainer/run.py index c0c2ff28ddc41..936b590f5041a 100644 --- a/examples/fabric/build_your_own_trainer/run.py +++ b/examples/fabric/build_your_own_trainer/run.py @@ -1,8 +1,9 @@ -import lightning as L import torch from torchmetrics.functional.classification.accuracy import accuracy from trainer import MyCustomTrainer +import lightning as L + class MNISTModule(L.LightningModule): def __init__(self) -> None: diff --git a/examples/fabric/build_your_own_trainer/trainer.py b/examples/fabric/build_your_own_trainer/trainer.py index f4f31c114f084..d9d081a2aea69 100644 --- a/examples/fabric/build_your_own_trainer/trainer.py +++ b/examples/fabric/build_your_own_trainer/trainer.py @@ -3,15 +3,16 @@ from functools import partial from typing import Any, Literal, Optional, Union, cast -import lightning as L import torch +from lightning_utilities import apply_to_collection +from tqdm import tqdm + +import lightning as L from lightning.fabric.accelerators import Accelerator from lightning.fabric.loggers import Logger from lightning.fabric.strategies import Strategy from lightning.fabric.wrappers import _unwrap_objects from lightning.pytorch.utilities.model_helpers import is_overridden -from lightning_utilities import apply_to_collection -from tqdm import tqdm class MyCustomTrainer: diff --git a/examples/fabric/dcgan/train_fabric.py b/examples/fabric/dcgan/train_fabric.py index f7a18b2b5bc17..66f11e1c6fcfe 100644 --- a/examples/fabric/dcgan/train_fabric.py +++ b/examples/fabric/dcgan/train_fabric.py @@ -16,9 +16,10 @@ import torch.utils.data import torchvision.transforms as transforms import torchvision.utils -from lightning.fabric import Fabric, seed_everything from torchvision.datasets import CelebA +from lightning.fabric import Fabric, seed_everything + # Root directory for dataset dataroot = "data/" # Number of workers for dataloader diff --git a/examples/fabric/fp8_distributed_transformer/train.py b/examples/fabric/fp8_distributed_transformer/train.py index 
ba88603268945..a30e2de2fc5ed 100644 --- a/examples/fabric/fp8_distributed_transformer/train.py +++ b/examples/fabric/fp8_distributed_transformer/train.py @@ -1,15 +1,16 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.fabric.strategies import ModelParallelStrategy -from lightning.pytorch.demos import Transformer, WikiText2 from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.distributed.device_mesh import DeviceMesh from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training from tqdm import tqdm +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy +from lightning.pytorch.demos import Transformer, WikiText2 + def configure_model(model: nn.Module, device_mesh: DeviceMesh) -> nn.Module: float8_config = Float8LinearConfig( diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index 02487a65e3989..d207595e9d2ba 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -36,11 +36,12 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from torch.optim.lr_scheduler import StepLR from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/kfold_cv/train_fabric.py b/examples/fabric/kfold_cv/train_fabric.py index b3aa08e9aae9b..05d9885190dbc 100644 --- a/examples/fabric/kfold_cv/train_fabric.py +++ b/examples/fabric/kfold_cv/train_fabric.py @@ -20,12 +20,13 @@ import torch.nn.functional as F import torch.optim as optim import torchvision.transforms as T -from lightning.fabric import Fabric, seed_everything from sklearn import model_selection from torch.utils.data import DataLoader, SubsetRandomSampler from torchmetrics.classification import Accuracy from torchvision.datasets import MNIST +from lightning.fabric import Fabric, seed_everything + DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "..", "Datasets") diff --git a/examples/fabric/language_model/train.py b/examples/fabric/language_model/train.py index cafe6ceeb18b1..01947893be926 100644 --- a/examples/fabric/language_model/train.py +++ b/examples/fabric/language_model/train.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + def main(): L.seed_everything(42) diff --git a/examples/fabric/meta_learning/train_fabric.py b/examples/fabric/meta_learning/train_fabric.py index 3cf1390477aeb..203155f7b2ada 100644 --- a/examples/fabric/meta_learning/train_fabric.py +++ b/examples/fabric/meta_learning/train_fabric.py @@ -18,6 +18,7 @@ import cherry import learn2learn as l2l import torch + from lightning.fabric import Fabric, seed_everything diff --git a/examples/fabric/reinforcement_learning/rl/agent.py b/examples/fabric/reinforcement_learning/rl/agent.py index 16a4cd6d86c73..b3d024d720d11 100644 --- a/examples/fabric/reinforcement_learning/rl/agent.py +++ b/examples/fabric/reinforcement_learning/rl/agent.py @@ -3,11 +3,11 @@ import 
gymnasium as gym import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from torch import Tensor from torch.distributions import Categorical from torchmetrics import MeanMetric +from lightning.pytorch import LightningModule from rl.loss import entropy_loss, policy_loss, value_loss from rl.utils import layer_init diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py index 4df52d7cd0455..7c9536fbd9532 100644 --- a/examples/fabric/reinforcement_learning/train_fabric.py +++ b/examples/fabric/reinforcement_learning/train_fabric.py @@ -25,13 +25,14 @@ import gymnasium as gym import torch import torchmetrics -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch import Tensor from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger + def train( fabric: Fabric, diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py index 7150ac3a12529..f3b0c74cb9064 100644 --- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py +++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py @@ -25,17 +25,18 @@ import gymnasium as gym import torch -from lightning.fabric import Fabric -from lightning.fabric.loggers import TensorBoardLogger -from lightning.fabric.plugins.collectives import TorchCollective -from lightning.fabric.plugins.collectives.collective import CollectibleGroup -from lightning.fabric.strategies import DDPStrategy from rl.agent import PPOLightningAgent from rl.utils import linear_annealing, make_env, parse_args, test from torch.distributed.algorithms.join import Join from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import MeanMetric +from lightning.fabric import Fabric +from lightning.fabric.loggers import TensorBoardLogger +from lightning.fabric.plugins.collectives import TorchCollective +from lightning.fabric.plugins.collectives.collective import CollectibleGroup +from lightning.fabric.strategies import DDPStrategy + @torch.no_grad() def player(args, world_collective: TorchCollective, player_trainer_collective: TorchCollective): @@ -273,7 +274,7 @@ def trainer( if group_rank == 0: metrics = {} - # Lerning rate annealing + # Learning rate annealing if args.anneal_lr: linear_annealing(optimizer, update, num_updates, args.learning_rate) if group_rank == 0: diff --git a/examples/fabric/tensor_parallel/README.md b/examples/fabric/tensor_parallel/README.md index e66d9acd2848b..1f551109cc5e7 100644 --- a/examples/fabric/tensor_parallel/README.md +++ b/examples/fabric/tensor_parallel/README.md @@ -41,5 +41,5 @@ Training successfully completed! Peak memory usage: 17.95 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). 
diff --git a/examples/fabric/tensor_parallel/train.py b/examples/fabric/tensor_parallel/train.py index 4a98f12cf6168..35ee9074f18a8 100644 --- a/examples/fabric/tensor_parallel/train.py +++ b/examples/fabric/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.fabric.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.fabric.strategies import ModelParallelStrategy + def train(): strategy = ModelParallelStrategy( diff --git a/examples/pytorch/basics/autoencoder.py b/examples/pytorch/basics/autoencoder.py index d6f594b12f57b..332c9a811e3e4 100644 --- a/examples/pytorch/basics/autoencoder.py +++ b/examples/pytorch/basics/autoencoder.py @@ -22,13 +22,14 @@ import torch import torch.nn.functional as F +from torch import nn +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, Trainer, callbacks, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities import rank_zero_only from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch import nn -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: import torchvision diff --git a/examples/pytorch/basics/backbone_image_classifier.py b/examples/pytorch/basics/backbone_image_classifier.py index fceb97dc41cff..965f636d7fc0b 100644 --- a/examples/pytorch/basics/backbone_image_classifier.py +++ b/examples/pytorch/basics/backbone_image_classifier.py @@ -21,12 +21,13 @@ from typing import Optional import torch +from torch.nn import functional as F +from torch.utils.data import DataLoader, random_split + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNIST from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/examples/pytorch/basics/profiler_example.py b/examples/pytorch/basics/profiler_example.py index 5b9ba08fec761..366aaecefe7e4 100644 --- a/examples/pytorch/basics/profiler_example.py +++ b/examples/pytorch/basics/profiler_example.py @@ -28,6 +28,7 @@ import torch import torchvision import torchvision.transforms as T + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.profilers.pytorch import PyTorchProfiler diff --git a/examples/pytorch/basics/transformer.py b/examples/pytorch/basics/transformer.py index 93cb39d829acc..dbd990d7f2759 100644 --- a/examples/pytorch/basics/transformer.py +++ b/examples/pytorch/basics/transformer.py @@ -1,9 +1,10 @@ -import lightning as L import torch import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 from torch.utils.data import DataLoader, random_split +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git 
a/examples/pytorch/bug_report/bug_report_model.py b/examples/pytorch/bug_report/bug_report_model.py index aa3f4cad710fe..551ea21721754 100644 --- a/examples/pytorch/bug_report/bug_report_model.py +++ b/examples/pytorch/bug_report/bug_report_model.py @@ -1,9 +1,10 @@ import os import torch -from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer + class RandomDataset(Dataset): def __init__(self, size, length): diff --git a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py index d03bb0c4edd16..69721214748ee 100644 --- a/examples/pytorch/domain_templates/computer_vision_fine_tuning.py +++ b/examples/pytorch/domain_templates/computer_vision_fine_tuning.py @@ -46,11 +46,6 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo -from lightning.pytorch.callbacks.finetuning import BaseFinetuning -from lightning.pytorch.cli import LightningCLI -from lightning.pytorch.utilities import rank_zero_info -from lightning.pytorch.utilities.model_helpers import get_torchvision_model from torch import nn, optim from torch.optim.lr_scheduler import MultiStepLR from torch.optim.optimizer import Optimizer @@ -60,6 +55,12 @@ from torchvision.datasets import ImageFolder from torchvision.datasets.utils import download_and_extract_archive +from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo +from lightning.pytorch.callbacks.finetuning import BaseFinetuning +from lightning.pytorch.cli import LightningCLI +from lightning.pytorch.utilities import rank_zero_info +from lightning.pytorch.utilities.model_helpers import get_torchvision_model + log = logging.getLogger(__name__) DATA_URL = "https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip" diff --git a/examples/pytorch/domain_templates/generative_adversarial_net.py b/examples/pytorch/domain_templates/generative_adversarial_net.py index 417e167df0d93..7ce7682d82c76 100644 --- a/examples/pytorch/domain_templates/generative_adversarial_net.py +++ b/examples/pytorch/domain_templates/generative_adversarial_net.py @@ -25,6 +25,7 @@ import torch import torch.nn as nn import torch.nn.functional as F + from lightning.pytorch import cli_lightning_logo from lightning.pytorch.core import LightningModule from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule diff --git a/examples/pytorch/domain_templates/imagenet.py b/examples/pytorch/domain_templates/imagenet.py index 49f9b509ce6e7..fd2050e2ed38b 100644 --- a/examples/pytorch/domain_templates/imagenet.py +++ b/examples/pytorch/domain_templates/imagenet.py @@ -43,13 +43,14 @@ import torch.utils.data.distributed import torchvision.datasets as datasets import torchvision.transforms as transforms +from torch.utils.data import Dataset +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar from lightning.pytorch.cli import LightningCLI from lightning.pytorch.strategies import ParallelStrategy from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from torch.utils.data import Dataset -from torchmetrics import Accuracy class ImageNetLightningModel(LightningModule): diff --git a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py 
b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py index b3bfaaea93e7f..193e6495a4182 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_Qnet.py +++ b/examples/pytorch/domain_templates/reinforce_learn_Qnet.py @@ -41,11 +41,12 @@ import torch import torch.nn as nn import torch.optim as optim -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader from torch.utils.data.dataset import IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + class DQN(nn.Module): """Simple MLP network. diff --git a/examples/pytorch/domain_templates/reinforce_learn_ppo.py b/examples/pytorch/domain_templates/reinforce_learn_ppo.py index 1fb083894c284..af503dbb925cd 100644 --- a/examples/pytorch/domain_templates/reinforce_learn_ppo.py +++ b/examples/pytorch/domain_templates/reinforce_learn_ppo.py @@ -35,12 +35,13 @@ import gym import torch -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything from torch import nn from torch.distributions import Categorical, Normal from torch.optim.optimizer import Optimizer from torch.utils.data import DataLoader, IterableDataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo, seed_everything + def create_mlp(input_shape: tuple[int], n_actions: int, hidden_size: int = 128): """Simple Multi-Layer Perceptron network.""" diff --git a/examples/pytorch/domain_templates/semantic_segmentation.py b/examples/pytorch/domain_templates/semantic_segmentation.py index 12ecbeeb5f0a9..0f19349f7a0fc 100644 --- a/examples/pytorch/domain_templates/semantic_segmentation.py +++ b/examples/pytorch/domain_templates/semantic_segmentation.py @@ -19,11 +19,12 @@ import torch import torch.nn.functional as F import torchvision.transforms as transforms -from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo from PIL import Image from torch import nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch import LightningModule, Trainer, cli_lightning_logo + DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) diff --git a/examples/pytorch/fp8_distributed_transformer/train.py b/examples/pytorch/fp8_distributed_transformer/train.py index 6c7be98ee7dbd..78aa6f13be6c2 100644 --- a/examples/pytorch/fp8_distributed_transformer/train.py +++ b/examples/pytorch/fp8_distributed_transformer/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.demos import Transformer, WikiText2 -from lightning.pytorch.strategies import ModelParallelStrategy from torch.distributed._composable.fsdp.fully_shard import fully_shard from torch.utils.data import DataLoader from torchao.float8 import Float8LinearConfig, convert_to_float8_training +import lightning as L +from lightning.pytorch.demos import Transformer, WikiText2 +from lightning.pytorch.strategies import ModelParallelStrategy + class LanguageModel(L.LightningModule): def __init__(self, vocab_size): diff --git a/examples/pytorch/hpu/mnist_sample.py b/examples/pytorch/hpu/mnist_sample.py index 4d2e22c03fe7e..0d04074519c8c 100644 --- a/examples/pytorch/hpu/mnist_sample.py +++ b/examples/pytorch/hpu/mnist_sample.py @@ -13,11 +13,12 @@ # limitations under the License. 
import torch from jsonargparse import lazy_instance +from lightning_habana import HPUPrecisionPlugin +from torch.nn import functional as F + from lightning.pytorch import LightningModule from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.mnist_datamodule import MNISTDataModule -from lightning_habana import HPUPrecisionPlugin -from torch.nn import functional as F class LitClassifier(LightningModule): diff --git a/examples/pytorch/servable_module/production.py b/examples/pytorch/servable_module/production.py index da0c42d12a865..854ff1176b619 100644 --- a/examples/pytorch/servable_module/production.py +++ b/examples/pytorch/servable_module/production.py @@ -8,11 +8,12 @@ import torch import torchvision import torchvision.transforms as T +from PIL import Image as PILImage + from lightning.pytorch import LightningDataModule, LightningModule, cli_lightning_logo from lightning.pytorch.cli import LightningCLI from lightning.pytorch.serve import ServableModule, ServableModuleValidator from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from PIL import Image as PILImage DATASETS_PATH = path.join(path.dirname(__file__), "..", "..", "Datasets") diff --git a/examples/pytorch/tensor_parallel/README.md b/examples/pytorch/tensor_parallel/README.md index d8b81b6de1bff..92e6fcb038268 100644 --- a/examples/pytorch/tensor_parallel/README.md +++ b/examples/pytorch/tensor_parallel/README.md @@ -45,5 +45,5 @@ Training successfully completed! Peak memory usage: 36.73 GB ``` -> \[!NOTE\] +> [!NOTE] > The `ModelParallelStrategy` is experimental and subject to change. Report issues on [GitHub](https://github.com/Lightning-AI/pytorch-lightning/issues). diff --git a/examples/pytorch/tensor_parallel/train.py b/examples/pytorch/tensor_parallel/train.py index 6a91e1242e4af..f9b971fd39f82 100644 --- a/examples/pytorch/tensor_parallel/train.py +++ b/examples/pytorch/tensor_parallel/train.py @@ -1,13 +1,14 @@ -import lightning as L import torch import torch.nn.functional as F from data import RandomTokenDataset -from lightning.pytorch.strategies import ModelParallelStrategy from model import ModelArgs, Transformer from parallelism import parallelize from torch.distributed.tensor.parallel import loss_parallel from torch.utils.data import DataLoader +import lightning as L +from lightning.pytorch.strategies import ModelParallelStrategy + class Llama3(L.LightningModule): def __init__(self): diff --git a/requirements/ci.txt b/requirements/ci.txt index cdebc301790e9..6ea6bbadbdc96 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -1,7 +1,8 @@ setuptools <70.1.1 wheel <0.44.0 awscli >=1.30.0, <1.31.0 -twine ==4.0.1 +twine ==6.0.1 importlib-metadata <8.0.0 wget +pkginfo ==1.12.0 packaging <24.2 diff --git a/requirements/collect_env_details.py b/requirements/collect_env_details.py index 5e6f9ba3dd350..392dd637fa8fd 100644 --- a/requirements/collect_env_details.py +++ b/requirements/collect_env_details.py @@ -70,7 +70,7 @@ def nice_print(details: dict, level: int = 0) -> list: lines += [level * LEVEL_OFFSET + key] lines += [(level + 1) * LEVEL_OFFSET + "- " + v for v in details[k]] else: - template = "{:%is} {}" % KEY_PADDING + template = "{:%is} {}" % KEY_PADDING # noqa: UP031 key_val = template.format(key, details[k]) lines += [(level * LEVEL_OFFSET) + key_val] return lines diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt index 42c055e85ca7d..70cd75c1c0d37 100644 --- a/requirements/fabric/base.txt +++ b/requirements/fabric/base.txt @@ -4,5 +4,5 
@@ torch >=2.1.0, <2.6.0 fsspec[http] >=2022.5.0, <2024.4.0 packaging >=20.0, <=23.1 -typing-extensions >=4.4.0, <4.10.0 +typing-extensions >=4.4.0, <4.11.0 lightning-utilities >=0.10.0, <0.12.0 diff --git a/requirements/fabric/strategies.txt b/requirements/fabric/strategies.txt index 394aceb39cd6b..5b7f170cbd866 100644 --- a/requirements/fabric/strategies.txt +++ b/requirements/fabric/strategies.txt @@ -6,5 +6,4 @@ # note: is a bug around 0.10 with `MPS_Accelerator must implement all abstract methods` # shall be resolved by https://github.com/microsoft/DeepSpeed/issues/4372 deepspeed >=0.8.2, <=0.9.3; platform_system != "Windows" and platform_system != "Darwin" # strict -bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32' -bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin' +bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin" diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt index 94aca759c37e2..cdf3cc03e2985 100644 --- a/requirements/pytorch/base.txt +++ b/requirements/pytorch/base.txt @@ -7,5 +7,5 @@ PyYAML >=5.4, <6.1.0 fsspec[http] >=2022.5.0, <2024.4.0 torchmetrics >=0.7.0, <1.5.0 # needed for using fixed compare_version packaging >=20.0, <=23.1 -typing-extensions >=4.4.0, <4.10.0 +typing-extensions >=4.4.0, <4.11.0 lightning-utilities >=0.10.0, <0.12.0 diff --git a/requirements/pytorch/extra.txt b/requirements/pytorch/extra.txt index 70c6548817b4a..e14cb38297caa 100644 --- a/requirements/pytorch/extra.txt +++ b/requirements/pytorch/extra.txt @@ -8,5 +8,4 @@ hydra-core >=1.2.0, <1.4.0 jsonargparse[signatures] >=4.27.7, <=4.35.0 rich >=12.3.0, <13.6.0 tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute -bitsandbytes >=0.44.0,<0.44.2; sys_platform == 'linux' or sys_platform == 'win32' -bitsandbytes >=0.42.0,<0.43.0 ; sys_platform == 'darwin' +bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin" diff --git a/src/lightning/__setup__.py b/src/lightning/__setup__.py index 2d3bb0e7d1f33..b71410c4f18cc 100644 --- a/src/lightning/__setup__.py +++ b/src/lightning/__setup__.py @@ -98,14 +98,13 @@ def _setup_args() -> dict[str, Any]: "entry_points": { "console_scripts": [ "fabric = lightning.fabric.cli:_main", - "lightning = lightning.fabric.cli:_legacy_main", ], }, "setup_requires": [], "install_requires": install_requires, "extras_require": _prepare_extras(), "project_urls": { - "Bug Tracker": "https://github.com/Lightning-AI/lightning/issues", + "Bug Tracker": "https://github.com/Lightning-AI/pytorch-lightning/issues", "Documentation": "https://lightning.ai/lightning-docs", "Source Code": "https://github.com/Lightning-AI/lightning", }, diff --git a/src/lightning/data/README.md b/src/lightning/data/README.md index 525a7e14f894d..c61e7eacf26f2 100644 --- a/src/lightning/data/README.md +++ b/src/lightning/data/README.md @@ -31,11 +31,11 @@ Find the reproducible [Studio Benchmark](https://lightning.ai/lightning-ai/studi ### Imagenet-1.2M Streaming from AWS S3 -| Framework | Images / sec 1st Epoch (float32) | Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | -| ----------- | --------------------------------- | ---------------------------------- | -------------------------------- | -------------------------------- | -| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | -| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | -| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | +| Framework | Images / sec 1st Epoch (float32) | 
Images / sec 2nd Epoch (float32) | Images / sec 1st Epoch (torch16) | Images / sec 2nd Epoch (torch16) | +| ----------- | -------------------------------- | -------------------------------- | -------------------------------- | -------------------------------- | +| PL Data | **5800.34** | **6589.98** | **6282.17** | **7221.88** | +| Web Dataset | 3134.42 | 3924.95 | 3343.40 | 4424.62 | +| Mosaic ML | 2898.61 | 5099.93 | 2809.69 | 5158.98 | Higher is better. diff --git a/src/lightning/fabric/CHANGELOG.md b/src/lightning/fabric/CHANGELOG.md index 7e401aff67671..0c922f0d31d0e 100644 --- a/src/lightning/fabric/CHANGELOG.md +++ b/src/lightning/fabric/CHANGELOG.md @@ -4,6 +4,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). + +## [2.5.1] - 2025-03-18 + +### Changed + +- Added logging support for list of dicts without collapsing to a single key ([#19957](https://github.com/Lightning-AI/pytorch-lightning/issues/19957)) + +### Removed + +- Removed legacy support for `lightning run model`. Use `fabric run` instead. ([#20588](https://github.com/Lightning-AI/pytorch-lightning/pull/20588)) + + ## [2.5.0] - 2024-12-19 ### Added @@ -331,7 +343,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed - Fixed computing the next version folder in `CSVLogger` ([#17139](https://github.com/Lightning-AI/lightning/pull/17139)) -- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/lightning/issues/17670)) +- Fixed inconsistent settings for FSDP Precision ([#17670](https://github.com/Lightning-AI/pytorch-lightning/issues/17670)) ## [2.0.2] - 2023-04-24 diff --git a/src/lightning/fabric/_graveyard/tpu.py b/src/lightning/fabric/_graveyard/tpu.py index c537ffc032322..138830e4e3b1b 100644 --- a/src/lightning/fabric/_graveyard/tpu.py +++ b/src/lightning/fabric/_graveyard/tpu.py @@ -71,7 +71,7 @@ class TPUPrecision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision`" " instead." + "The `TPUPrecision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="32-true") @@ -85,8 +85,7 @@ class XLABf16Precision(XLAPrecision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `XLABf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `XLABf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." ) super().__init__(precision="bf16-true") @@ -100,8 +99,7 @@ class TPUBf16Precision(XLABf16Precision): def __init__(self, *args: Any, **kwargs: Any) -> None: rank_zero_deprecation( - "The `TPUBf16Precision` class is deprecated. Use" - " `lightning.fabric.plugins.precision.XLAPrecision` instead." + "The `TPUBf16Precision` class is deprecated. Use `lightning.fabric.plugins.precision.XLAPrecision` instead." 
) super().__init__(*args, **kwargs) diff --git a/src/lightning/fabric/cli.py b/src/lightning/fabric/cli.py index 5f18884e83d79..2268614abb97b 100644 --- a/src/lightning/fabric/cli.py +++ b/src/lightning/fabric/cli.py @@ -14,8 +14,6 @@ import logging import os import re -import subprocess -import sys from argparse import Namespace from typing import Any, Optional @@ -50,25 +48,6 @@ def _get_supported_strategies() -> list[str]: if _CLICK_AVAILABLE: import click - def _legacy_main() -> None: - """Legacy CLI handler for fabric. - - Raises deprecation warning and runs through fabric cli if necessary, else runs the entrypoint directly - - """ - hparams = sys.argv[1:] - if len(hparams) >= 2 and hparams[0] == "run" and hparams[1] == "model": - print( - "`lightning run model` is deprecated and will be removed in future versions." - " Please call `fabric run` instead." - ) - _main() - return - - if _LIGHTNING_SDK_AVAILABLE: - subprocess.run([sys.executable, "-m", "lightning_sdk.cli.entrypoint"] + hparams) - return - @click.group() def _main() -> None: pass diff --git a/src/lightning/fabric/connector.py b/src/lightning/fabric/connector.py index 0ade7f69c3629..85d30a07ce207 100644 --- a/src/lightning/fabric/connector.py +++ b/src/lightning/fabric/connector.py @@ -251,8 +251,7 @@ def _check_config_and_set_final_flags( if plugins_flags_types.get(Precision.__name__) and precision_input is not None: raise ValueError( - f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`." - f" Choose one." + f"Received both `precision={precision_input}` and `plugins={self._precision_instance}`. Choose one." ) self._precision_input = "32-true" if precision_input is None else precision_input diff --git a/src/lightning/fabric/plugins/environments/kubeflow.py b/src/lightning/fabric/plugins/environments/kubeflow.py index ce2dd002e57bd..23a1c0d1753af 100644 --- a/src/lightning/fabric/plugins/environments/kubeflow.py +++ b/src/lightning/fabric/plugins/environments/kubeflow.py @@ -28,7 +28,7 @@ class KubeflowEnvironment(ClusterEnvironment): This environment, unlike others, does not get auto-detected and needs to be passed to the Fabric/Trainer constructor manually. - .. _PyTorchJob: https://www.kubeflow.org/docs/components/training/pytorch/ + .. _PyTorchJob: https://www.kubeflow.org/docs/components/trainer/legacy-v1/user-guides/pytorch/ .. _Kubeflow: https://www.kubeflow.org """ diff --git a/src/lightning/fabric/plugins/precision/bitsandbytes.py b/src/lightning/fabric/plugins/precision/bitsandbytes.py index ecb1d8a442655..b78157d1c4074 100644 --- a/src/lightning/fabric/plugins/precision/bitsandbytes.py +++ b/src/lightning/fabric/plugins/precision/bitsandbytes.py @@ -40,7 +40,7 @@ log = logging.getLogger(__name__) -_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes>=0.42.0") +_BITSANDBYTES_AVAILABLE = RequirementCache("bitsandbytes") class BitsandbytesPrecision(Precision): diff --git a/src/lightning/fabric/strategies/deepspeed.py b/src/lightning/fabric/strategies/deepspeed.py index 1e94fa1166f93..41820c1cc433f 100644 --- a/src/lightning/fabric/strategies/deepspeed.py +++ b/src/lightning/fabric/strategies/deepspeed.py @@ -144,7 +144,7 @@ def __init__( nvme_path: Filesystem path for NVMe device for optimizer/parameter state offloading. optimizer_buffer_count: Number of buffers in buffer pool for optimizer state offloading - when ``offload_optimizer_device`` is set to to ``nvme``. + when ``offload_optimizer_device`` is set to ``nvme``. 
This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). @@ -329,8 +329,7 @@ def setup_module_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) self._deepspeed_engine, optimizer = self._initialize_engine(module, optimizers[0]) diff --git a/src/lightning/fabric/strategies/launchers/multiprocessing.py b/src/lightning/fabric/strategies/launchers/multiprocessing.py index d9b96dca5471d..3b3e180e63f41 100644 --- a/src/lightning/fabric/strategies/launchers/multiprocessing.py +++ b/src/lightning/fabric/strategies/launchers/multiprocessing.py @@ -78,7 +78,7 @@ def __init__( def is_interactive_compatible(self) -> bool: # The start method 'spawn' is not supported in interactive environments # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA - # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550 + # initialization. For more context, see https://github.com/Lightning-AI/pytorch-lightning/issues/7550 return self._start_method == "fork" @override diff --git a/src/lightning/fabric/strategies/launchers/subprocess_script.py b/src/lightning/fabric/strategies/launchers/subprocess_script.py index a28fe971c7ac4..8a78eb3c7dfbf 100644 --- a/src/lightning/fabric/strategies/launchers/subprocess_script.py +++ b/src/lightning/fabric/strategies/launchers/subprocess_script.py @@ -156,7 +156,7 @@ def _check_can_spawn_children(self) -> None: def _basic_subprocess_cmd() -> Sequence[str]: - import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + import __main__ # local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218 if __main__.__spec__ is None: # pragma: no-cover return [sys.executable, os.path.abspath(sys.argv[0])] + sys.argv[1:] @@ -167,7 +167,7 @@ def _hydra_subprocess_cmd(local_rank: int) -> tuple[Sequence[str], str]: from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd, to_absolute_path - import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218 + import __main__ # local import to avoid https://github.com/Lightning-AI/pytorch-lightning/issues/15218 # when user is using hydra find the absolute path if __main__.__spec__ is None: # pragma: no-cover diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index ad1fc19074d06..ace23a9c7a2c5 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -292,8 +292,7 @@ def load_checkpoint( if isinstance(state, Optimizer): raise NotImplementedError( - f"Loading a single optimizer object from a checkpoint is not supported yet with" - f" {type(self).__name__}." + f"Loading a single optimizer object from a checkpoint is not supported yet with {type(self).__name__}." 
) return _load_checkpoint(path=path, state=state, strict=strict) diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index a1c5a6f6dcd1b..5a9ec1edc1ca8 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -34,6 +34,7 @@ _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") +_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/fabric/utilities/logger.py b/src/lightning/fabric/utilities/logger.py index dd2b0a3663fc9..04b9069dd0788 100644 --- a/src/lightning/fabric/utilities/logger.py +++ b/src/lightning/fabric/utilities/logger.py @@ -91,6 +91,8 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent {'a/b': 123} >>> _flatten_dict({5: {'a': 123}}) {'5/a': 123} + >>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]}) + {'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]} """ result: dict[str, Any] = {} @@ -103,6 +105,10 @@ def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent if isinstance(v, MutableMapping): result = {**result, **_flatten_dict(v, parent_key=new_key, delimiter=delimiter)} + # Also handle the case where v is a list of dictionaries + elif isinstance(v, list) and all(isinstance(item, MutableMapping) for item in v): + for i, item in enumerate(v): + result = {**result, **_flatten_dict(item, parent_key=f"{new_key}/{i}", delimiter=delimiter)} else: result[new_key] = v return result diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 9f7317c218c30..627e8790cb940 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -4,6 +4,31 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
+ +## [2.5.1] - 2025-03-18 + +### Changed + +- Allow LightningCLI to use a customized argument parser class ([#20596](https://github.com/Lightning-AI/pytorch-lightning/pull/20596)) +- Change `wandb` default x-axis to `tensorboard`'s `global_step` when `sync_tensorboard=True` ([#20611](https://github.com/Lightning-AI/pytorch-lightning/pull/20611)) +- Added a new `checkpoint_path_prefix` parameter to the MLflow logger which can control the path where the MLflow artifacts for the model checkpoints are stored ([#20538](https://github.com/Lightning-AI/pytorch-lightning/pull/20538)) +- CometML logger was updated to support the recent Comet SDK ([#20275](https://github.com/Lightning-AI/pytorch-lightning/pull/20275)) +- bump: testing with latest `torch` 2.6 ([#20509](https://github.com/Lightning-AI/pytorch-lightning/pull/20509)) + +### Fixed + +- Fixed CSVLogger logging hyperparameters at every write, which increased latency ([#20594](https://github.com/Lightning-AI/pytorch-lightning/pull/20594)) +- Fixed OverflowError when resuming from checkpoint with an iterable dataset ([#20565](https://github.com/Lightning-AI/pytorch-lightning/issues/20565)) +- Fixed swapped _R_co and _P to prevent type error ([#20508](https://github.com/Lightning-AI/pytorch-lightning/issues/20508)) +- Always call `WandbLogger.experiment` first in `_call_setup_hook` to ensure `tensorboard` logs can sync to `wandb` ([#20610](https://github.com/Lightning-AI/pytorch-lightning/pull/20610)) +- Fixed TBPTT example ([#20528](https://github.com/Lightning-AI/pytorch-lightning/pull/20528)) +- Fixed test compatibility as AdamW became a subclass of Adam ([#20574](https://github.com/Lightning-AI/pytorch-lightning/pull/20574)) +- Fixed file extension of model checkpoints uploaded by NeptuneLogger ([#20581](https://github.com/Lightning-AI/pytorch-lightning/pull/20581)) +- Reset trainer variable `should_stop` when `fit` is called ([#19177](https://github.com/Lightning-AI/pytorch-lightning/pull/19177)) +- Fixed making `WandbLogger` upload models from all `ModelCheckpoint` callbacks, not just one ([#20191](https://github.com/Lightning-AI/pytorch-lightning/pull/20191)) +- Fixed an error when logging to a deleted MLflow experiment ([#20556](https://github.com/Lightning-AI/pytorch-lightning/pull/20556)) + + ## [2.5.0] - 2024-12-19 ### Added @@ -199,16 +224,16 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed handling checkpoint dirpath suffix in NeptuneLogger ([#18863](https://github.com/Lightning-AI/lightning/pull/18863)) - Fixed an edge case where `ModelCheckpoint` would alternate between versioned and unversioned filename ([#19064](https://github.com/Lightning-AI/lightning/pull/19064)) - Fixed broadcast at initialization in `MPIEnvironment` ([#19074](https://github.com/Lightning-AI/lightning/pull/19074)) -- Fixed the tensor conversion in `self.log` to respect the default dtype ([#19046](https://github.com/Lightning-AI/lightning/issues/19046)) +- Fixed the tensor conversion in `self.log` to respect the default dtype ([#19046](https://github.com/Lightning-AI/pytorch-lightning/issues/19046)) ## [2.1.2] - 2023-11-15 ### Fixed -- Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/lightning/issues/18942)) -- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/lightning/issues/18954)) -- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/lightning/issues/18992)) +- Fixed an issue causing permission errors on Windows when attempting to create a symlink for the "last" checkpoint ([#18942](https://github.com/Lightning-AI/pytorch-lightning/issues/18942)) +- Fixed an issue where Metric instances from `torchmetrics` wouldn't get moved to the device when using FSDP ([#18954](https://github.com/Lightning-AI/pytorch-lightning/issues/18954)) +- Fixed an issue preventing the user to `Trainer.save_checkpoint()` an FSDP model when `Trainer.test/validate/predict()` ran after `Trainer.fit()` ([#18992](https://github.com/Lightning-AI/pytorch-lightning/issues/18992)) ## [2.1.1] - 2023-11-06 @@ -216,10 +241,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
### Fixed - Fixed an issue when replacing an existing `last.ckpt` file with a symlink ([#18793](https://github.com/Lightning-AI/lightning/pull/18793)) -- Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394)) -- Fixed an issue saving the `last.ckpt` file when using `ModelCheckpoint` on a remote filesystem and no logger is used ([#18867](https://github.com/Lightning-AI/lightning/issues/18867)) +- Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/pytorch-lightning/issues/18394)) +- Fixed an issue saving the `last.ckpt` file when using `ModelCheckpoint` on a remote filesystem and no logger is used ([#18867](https://github.com/Lightning-AI/pytorch-lightning/issues/18867)) - Refined the FSDP saving logic and error messaging when path exists ([#18884](https://github.com/Lightning-AI/lightning/pull/18884)) -- Fixed an issue parsing the version from folders that don't include a version number in `TensorBoardLogger` and `CSVLogger` ([#18897](https://github.com/Lightning-AI/lightning/issues/18897)) +- Fixed an issue parsing the version from folders that don't include a version number in `TensorBoardLogger` and `CSVLogger` ([#18897](https://github.com/Lightning-AI/pytorch-lightning/issues/18897)) ## [2.1.0] - 2023-10-11 diff --git a/src/lightning/pytorch/callbacks/early_stopping.py b/src/lightning/pytorch/callbacks/early_stopping.py index 78c4215f9ce23..d108894f614e6 100644 --- a/src/lightning/pytorch/callbacks/early_stopping.py +++ b/src/lightning/pytorch/callbacks/early_stopping.py @@ -145,7 +145,7 @@ def _validate_condition_metric(self, logs: dict[str, Tensor]) -> bool: error_msg = ( f"Early stopping conditioned on metric `{self.monitor}` which is not available." " Pass in or modify your `EarlyStopping` callback to use any of the following:" - f' `{"`, `".join(list(logs.keys()))}`' + f" `{'`, `'.join(list(logs.keys()))}`" ) if monitor_val is None: diff --git a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py index 5643a038e00c1..375bd15f29051 100644 --- a/src/lightning/pytorch/callbacks/stochastic_weight_avg.py +++ b/src/lightning/pytorch/callbacks/stochastic_weight_avg.py @@ -354,7 +354,7 @@ def _clear_schedulers(trainer: "pl.Trainer") -> None: # Note that this relies on the callback state being restored before the scheduler state is # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of # writing that is only True for deepspeed which is already not supported by SWA. - # See https://github.com/Lightning-AI/lightning/issues/11665 for background. + # See https://github.com/Lightning-AI/pytorch-lightning/issues/11665 for background. 
if trainer.lr_scheduler_configs: assert len(trainer.lr_scheduler_configs) == 1 trainer.lr_scheduler_configs.clear() diff --git a/src/lightning/pytorch/cli.py b/src/lightning/pytorch/cli.py index c79f2481c8af4..75a6347c95356 100644 --- a/src/lightning/pytorch/cli.py +++ b/src/lightning/pytorch/cli.py @@ -314,6 +314,7 @@ def __init__( trainer_defaults: Optional[dict[str, Any]] = None, seed_everything_default: Union[bool, int] = True, parser_kwargs: Optional[Union[dict[str, Any], dict[str, dict[str, Any]]]] = None, + parser_class: type[LightningArgumentParser] = LightningArgumentParser, subclass_mode_model: bool = False, subclass_mode_data: bool = False, args: ArgsType = None, @@ -367,6 +368,7 @@ def __init__( self.trainer_defaults = trainer_defaults or {} self.seed_everything_default = seed_everything_default self.parser_kwargs = parser_kwargs or {} + self.parser_class = parser_class self.auto_configure_optimizers = auto_configure_optimizers self.model_class = model_class @@ -404,7 +406,7 @@ def _setup_parser_kwargs(self, parser_kwargs: dict[str, Any]) -> tuple[dict[str, def init_parser(self, **kwargs: Any) -> LightningArgumentParser: """Method that instantiates the argument parser.""" kwargs.setdefault("dump_header", [f"lightning.pytorch=={pl.__version__}"]) - parser = LightningArgumentParser(**kwargs) + parser = self.parser_class(**kwargs) parser.add_argument( "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format." ) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index f1d1da924eac4..b8624daac3fa3 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -979,7 +979,7 @@ def configure_optimizers(self) -> OptimizerLRScheduler: # `scheduler.step()`. 1 corresponds to updating the learning # rate after every epoch/step. 
"frequency": 1, - # Metric to to monitor for schedulers like `ReduceLROnPlateau` + # Metric to monitor for schedulers like `ReduceLROnPlateau` "monitor": "val_loss", # If set to `True`, will enforce that the value specified 'monitor' # is available when the scheduler is updated, thus stopping diff --git a/src/lightning/pytorch/loggers/comet.py b/src/lightning/pytorch/loggers/comet.py index 9c05317655129..b544212e755e2 100644 --- a/src/lightning/pytorch/loggers/comet.py +++ b/src/lightning/pytorch/loggers/comet.py @@ -20,23 +20,26 @@ import os from argparse import Namespace from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Literal, Optional, Union from lightning_utilities.core.imports import RequirementCache from torch import Tensor from torch.nn import Module from typing_extensions import override -from lightning.fabric.utilities.logger import _add_prefix, _convert_params, _flatten_dict +from lightning.fabric.utilities.logger import _convert_params +from lightning.fabric.utilities.rank_zero import _get_rank from lightning.pytorch.loggers.logger import Logger, rank_zero_experiment -from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only if TYPE_CHECKING: from comet_ml import ExistingExperiment, Experiment, OfflineExperiment log = logging.getLogger(__name__) -_COMET_AVAILABLE = RequirementCache("comet-ml>=3.31.0", module="comet_ml") +_COMET_AVAILABLE = RequirementCache("comet-ml>=3.44.4", module="comet_ml") + +FRAMEWORK_NAME = "pytorch-lightning" +comet_experiment = Union["Experiment", "ExistingExperiment", "OfflineExperiment"] class CometLogger(Logger): @@ -61,13 +64,11 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - api_key=os.environ.get("COMET_API_KEY"), + api_key=os.environ.get("COMET_API_KEY"), # Optional workspace=os.environ.get("COMET_WORKSPACE"), # Optional - save_dir=".", # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional + project="default_project", # Optional experiment_key=os.environ.get("COMET_EXPERIMENT_KEY"), # Optional - experiment_name="lightning_logs", # Optional + name="lightning_logs", # Optional ) trainer = Trainer(logger=comet_logger) @@ -79,11 +80,10 @@ class CometLogger(Logger): # arguments made to CometLogger are passed on to the comet_ml.Experiment class comet_logger = CometLogger( - save_dir=".", workspace=os.environ.get("COMET_WORKSPACE"), # Optional - project_name="default_project", # Optional - rest_api_key=os.environ.get("COMET_REST_API_KEY"), # Optional - experiment_name="lightning_logs", # Optional + project="default_project", # Optional + name="lightning_logs", # Optional + online=False ) trainer = Trainer(logger=comet_logger) @@ -107,6 +107,9 @@ def __init__(self, *args, **kwarg): # log multiple parameters logger.log_hyperparams({"batch_size": 16, "learning_rate": 0.001}) + # log nested parameters + logger.log_hyperparams({"specific": {'param': {'subparam': "value"}}}) + **Log Metrics:** .. 
code-block:: python @@ -117,6 +120,9 @@ def __init__(self, *args, **kwarg): # add multiple metrics logger.log_metrics({"train/loss": 0.001, "val/loss": 0.002}) + # add nested metrics + logger.log_metrics({"specific": {'metric': {'submetric': "value"}}}) + **Access the Comet Experiment object:** You can gain access to the underlying Comet @@ -167,100 +173,137 @@ def __init__(self, *args, **kwarg): - `Comet Documentation `__ Args: - api_key: Required in online mode. API key, found on Comet.ml. If not given, this - will be loaded from the environment variable COMET_API_KEY or ~/.comet.config - if either exists. - save_dir: Required in offline mode. The path for the directory to save local - comet logs. If given, this also sets the directory for saving checkpoints. - project_name: Optional. Send your experiment to a specific project. - Otherwise will be sent to Uncategorized Experiments. - If the project name does not already exist, Comet.ml will create a new project. - rest_api_key: Optional. Rest API key found in Comet.ml settings. - This is used to determine version number - experiment_name: Optional. String representing the name for this particular experiment on Comet.ml. - experiment_key: Optional. If set, restores from existing experiment. - offline: If api_key and save_dir are both given, this determines whether - the experiment will be in online or offline mode. This is useful if you use - save_dir to control the checkpoints directory and have a ~/.comet.config - file but still want to run offline experiments. - prefix: A string to put at the beginning of metric keys. - \**kwargs: Additional arguments like `workspace`, `log_code`, etc. used by + api_key: Comet API key. It's recommended to configure the API Key with `comet login`. + workspace: Comet workspace name. If not provided, uses the default workspace. + project: Comet project name. Defaults to `Uncategorized`. + experiment_key: The Experiment identifier to be used for logging. This is used either to append + data to an Existing Experiment or to control the key of new experiments (for example to match another + identifier). Must be an alphanumeric string whose length is between 32 and 50 characters. + mode: Control how the Comet experiment is started. + * ``"get_or_create"``: Starts a fresh experiment if required, or persists logging to an existing one. + * ``"get"``: Continue logging to an existing experiment identified by the ``experiment_key`` value. + * ``"create"``: Always creates a new experiment, useful for HPO sweeps. + online: If True, the data will be logged to the Comet server, otherwise it will be stored + locally in an offline experiment. Default is ``True``. + prefix: The prefix to add to names of the logged metrics. + example: prefix=`exp1`, then metric name will be logged as `exp1_metric_name` + **kwargs: Additional arguments like `name`, `log_code`, `offline_directory` etc. used by :class:`CometExperiment` can be passed as keyword arguments in this logger. Raises: ModuleNotFoundError: If required Comet package is not installed on the device. - MisconfigurationException: - If neither ``api_key`` nor ``save_dir`` are passed as arguments. 
""" - LOGGER_JOIN_CHAR = "-" - def __init__( self, + *, api_key: Optional[str] = None, - save_dir: Optional[str] = None, - project_name: Optional[str] = None, - rest_api_key: Optional[str] = None, - experiment_name: Optional[str] = None, + workspace: Optional[str] = None, + project: Optional[str] = None, experiment_key: Optional[str] = None, - offline: bool = False, - prefix: str = "", + mode: Optional[Literal["get_or_create", "get", "create"]] = None, + online: Optional[bool] = None, + prefix: Optional[str] = None, **kwargs: Any, ): if not _COMET_AVAILABLE: raise ModuleNotFoundError(str(_COMET_AVAILABLE)) + super().__init__() - self._experiment = None - self._save_dir: Optional[str] - self.rest_api_key: Optional[str] + + ################################################## + # HANDLE PASSED OLD TYPE PARAMS + + # handle old "experiment_name" param + if "experiment_name" in kwargs: + log.warning("The parameter `experiment_name` is deprecated, please use `name` instead.") + experiment_name = kwargs.pop("experiment_name") + + if "name" not in kwargs: + kwargs["name"] = experiment_name + else: + log.warning("You specified both `experiment_name` and `name` parameters, please use `name` only") + + # handle old "project_name" param + if "project_name" in kwargs: + log.warning("The parameter `project_name` is deprecated, please use `project` instead.") + if project is None: + project = kwargs.pop("project_name") + else: + log.warning("You specified both `project_name` and `project` parameters, please use `project` only") + + # handle old "offline" experiment flag + if "offline" in kwargs: + log.warning("The parameter `offline` is deprecated, please use `online` instead.") + if online is None: + online = not kwargs.pop("offline") + else: + log.warning("You specified both `offline` and `online` parameters, please use `online` only") + + # handle old "save_dir" param + if "save_dir" in kwargs: + log.warning("The parameter `save_dir` is deprecated, please use `offline_directory` instead.") + if "offline_directory" not in kwargs: + kwargs["offline_directory"] = kwargs.pop("save_dir") + else: + log.warning( + "You specified both `save_dir` and `offline_directory` parameters, " + "please use `offline_directory` only" + ) + ################################################## + + self._api_key: Optional[str] = api_key + self._experiment: Optional[comet_experiment] = None + self._workspace: Optional[str] = workspace + self._mode: Optional[Literal["get_or_create", "get", "create"]] = mode + self._online: Optional[bool] = online + self._project_name: Optional[str] = project + self._experiment_key: Optional[str] = experiment_key + self._prefix: Optional[str] = prefix + self._kwargs: dict[str, Any] = kwargs # needs to be set before the first `comet_ml` import + # because comet_ml is imported after other machine learning libraries (Torch) os.environ["COMET_DISABLE_AUTO_LOGGING"] = "1" import comet_ml - # Determine online or offline mode based on which arguments were passed to CometLogger - api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config()) - - if api_key is not None and save_dir is not None: - self.mode = "offline" if offline else "online" - self.api_key = api_key - self._save_dir = save_dir - elif api_key is not None: - self.mode = "online" - self.api_key = api_key - self._save_dir = None - elif save_dir is not None: - self.mode = "offline" - self._save_dir = save_dir - else: - # If neither api_key nor save_dir are passed as arguments, raise an exception - raise 
MisconfigurationException("CometLogger requires either api_key or save_dir during initialization.") - - log.info(f"CometLogger will be initialized in {self.mode} mode") - - self._project_name: Optional[str] = project_name - self._experiment_key: Optional[str] = experiment_key - self._experiment_name: Optional[str] = experiment_name - self._prefix: str = prefix - self._kwargs: Any = kwargs - self._future_experiment_key: Optional[str] = None + config_kwargs = self._kwargs.copy() + if online is False: + config_kwargs["disabled"] = True + self._comet_config = comet_ml.ExperimentConfig(**config_kwargs) - if rest_api_key is not None: - from comet_ml.api import API + # create real experiment only on main node/process (when strategy=auto/ddp) + if _get_rank() is not None and _get_rank() != 0: + return + + self._create_experiment() + + def _create_experiment(self) -> None: + import comet_ml - # Comet.ml rest API, used to determine version number - self.rest_api_key = rest_api_key - self.comet_api = API(self.rest_api_key) - else: - self.rest_api_key = None - self.comet_api = None + self._experiment = comet_ml.start( + api_key=self._api_key, + workspace=self._workspace, + project=self._project_name, + experiment_key=self._experiment_key, + mode=self._mode, + online=self._online, + experiment_config=self._comet_config, + ) + + if self._experiment is None: + raise comet_ml.exceptions.ExperimentNotFound("Failed to create Comet experiment.") + + self._experiment_key = self._experiment.get_key() + self._project_name = self._experiment.project_name + self._experiment.log_other("Created from", FRAMEWORK_NAME) @property @rank_zero_experiment - def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperiment"]: + def experiment(self) -> comet_experiment: r"""Actual Comet object. To use Comet features in your :class:`~lightning.pytorch.core.LightningModule` do the following. 
@@ -269,38 +312,11 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi self.logger.experiment.some_comet_function() """ - if self._experiment is not None and self._experiment.alive: - return self._experiment - - if self._future_experiment_key is not None: - os.environ["COMET_EXPERIMENT_KEY"] = self._future_experiment_key - - from comet_ml import ExistingExperiment, Experiment, OfflineExperiment - - try: - if self.mode == "online": - if self._experiment_key is None: - self._experiment = Experiment(api_key=self.api_key, project_name=self._project_name, **self._kwargs) - self._experiment_key = self._experiment.get_key() - else: - self._experiment = ExistingExperiment( - api_key=self.api_key, - project_name=self._project_name, - previous_experiment=self._experiment_key, - **self._kwargs, - ) - else: - self._experiment = OfflineExperiment( - offline_directory=self.save_dir, project_name=self._project_name, **self._kwargs - ) - self._experiment.log_other("Created from", "pytorch-lightning") - finally: - if self._future_experiment_key is not None: - os.environ.pop("COMET_EXPERIMENT_KEY") - self._future_experiment_key = None - if self._experiment_name: - self._experiment.set_name(self._experiment_name) + # if by some chance there is no experiment created yet (for example, when strategy=ddp_spawn) + # then we will create a new one + if not self._experiment: + self._create_experiment() return self._experiment @@ -308,43 +324,44 @@ def experiment(self) -> Union["Experiment", "ExistingExperiment", "OfflineExperi @rank_zero_only def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) - params = _flatten_dict(params) - self.experiment.log_parameters(params) + self.experiment.__internal_api__log_parameters__( + parameters=params, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @override @rank_zero_only def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None: assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" - # Comet.ml expects metrics to be a dictionary of detached tensors on CPU + # Comet.com expects metrics to be a dictionary of detached tensors on CPU metrics_without_epoch = metrics.copy() for key, val in metrics_without_epoch.items(): if isinstance(val, Tensor): metrics_without_epoch[key] = val.cpu().detach() epoch = metrics_without_epoch.pop("epoch", None) - metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR) - self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch) - - def reset_experiment(self) -> None: - self._experiment = None + self.experiment.__internal_api__log_metrics__( + metrics_without_epoch, + step=step, + epoch=epoch, + prefix=self._prefix, + framework=FRAMEWORK_NAME, + ) @override @rank_zero_only def finalize(self, status: str) -> None: - r"""When calling ``self.experiment.end()``, that experiment won't log any more data to Comet. That's why, if you - need to log any more data, you need to create an ExistingCometExperiment. For example, to log data when testing - your model after training, because when training is finalized :meth:`CometLogger.finalize` is called. - - This happens automatically in the :meth:`~CometLogger.experiment` property, when - ``self._experiment`` is set to ``None``, i.e. ``self.reset_experiment()``. 
- - """ + """We do not end the experiment (do not call self._experiment.end()) here, so that it can continue to be used + after training is complete; instead of ending it, we upload/save all the data.""" if self._experiment is None: # When using multiprocessing, finalize() should be a no-op on the main process, as no experiment has been # initialized there return - self.experiment.end() - self.reset_experiment() + + # just save the data + self.experiment.flush() @property @override @@ -355,61 +372,31 @@ def save_dir(self) -> Optional[str]: The path to the save directory. """ - return self._save_dir + return self._comet_config.offline_directory @property @override - def name(self) -> str: + def name(self) -> Optional[str]: """Gets the project name. Returns: - The project name if it is specified, else "comet-default". + The project name if it is specified. """ - # Don't create an experiment if we don't have one - if self._experiment is not None and self._experiment.project_name is not None: - return self._experiment.project_name - - if self._project_name is not None: - return self._project_name - - return "comet-default" + return self._project_name @property @override - def version(self) -> str: + def version(self) -> Optional[str]: """Gets the version. Returns: - The first one of the following that is set in the following order - - 1. experiment id. - 2. experiment key. - 3. "COMET_EXPERIMENT_KEY" environment variable. - 4. future experiment key. - - If none are present generates a new guid. + The experiment key, if present. """ # Don't create an experiment if we don't have one if self._experiment is not None: - return self._experiment.id - - if self._experiment_key is not None: - return self._experiment_key - - if "COMET_EXPERIMENT_KEY" in os.environ: - return os.environ["COMET_EXPERIMENT_KEY"] - - if self._future_experiment_key is not None: - return self._future_experiment_key - - import comet_ml - - # Pre-generate an experiment key - self._future_experiment_key = comet_ml.generate_guid() - - return self._future_experiment_key + return self._experiment.get_key() def __getstate__(self) -> dict[str, Any]: state = self.__dict__.copy() @@ -417,7 +404,7 @@ def __getstate__(self) -> dict[str, Any]: # Save the experiment id in case an experiment object already exists, # this way we could create an ExistingExperiment pointing to the same # experiment - state["_experiment_key"] = self._experiment.id if self._experiment is not None else None + state["_experiment_key"] = self._experiment.get_key() if self._experiment is not None else None # Remove the experiment object as it contains hard to pickle objects # (like network connections), the experiment object will be recreated if @@ -428,4 +415,7 @@ def __getstate__(self) -> dict[str, Any]: @override def log_graph(self, model: Module, input_array: Optional[Tensor] = None) -> None: if self._experiment is not None: - self._experiment.set_model_graph(model) + self._experiment.__internal_api__set_model_graph__( + graph=model, + framework=FRAMEWORK_NAME, + ) diff --git a/src/lightning/pytorch/loggers/csv_logs.py b/src/lightning/pytorch/loggers/csv_logs.py index 8606264dc3cdb..5ad7353310af4 100644 --- a/src/lightning/pytorch/loggers/csv_logs.py +++ b/src/lightning/pytorch/loggers/csv_logs.py @@ -55,15 +55,10 @@ def __init__(self, log_dir: str) -> None: self.hparams: dict[str, Any] = {} def log_hparams(self, params: dict[str, Any]) -> None: - """Record hparams.""" + """Record hparams and save into files.""" self.hparams.update(params) - - @override - 
def save(self) -> None: - """Save recorded hparams and metrics into files.""" hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) save_hparams_to_yaml(hparams_file, self.hparams) - return super().save() class CSVLogger(Logger, FabricCSVLogger): @@ -144,7 +139,7 @@ def save_dir(self) -> str: @override @rank_zero_only - def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Optional[Union[dict[str, Any], Namespace]] = None) -> None: params = _convert_params(params) self.experiment.log_hparams(params) diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index e3d99987b7f58..02396d8021633 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -97,7 +97,7 @@ def any_lightning_module_function_or_hook(self): :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1`` which also logs every checkpoint during training. * if ``log_model == False`` (default), no checkpoint is logged. - + checkpoint_path_prefix: A string to prefix the checkpoint artifact's path. prefix: A string to put at the beginning of metric keys. artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate default. @@ -121,6 +121,7 @@ def __init__( tags: Optional[dict[str, Any]] = None, save_dir: Optional[str] = "./mlruns", log_model: Literal[True, False, "all"] = False, + checkpoint_path_prefix: str = "", prefix: str = "", artifact_location: Optional[str] = None, run_id: Optional[str] = None, @@ -147,6 +148,7 @@ def __init__( self._artifact_location = artifact_location self._log_batch_kwargs = {} if synchronous is None else {"synchronous": synchronous} self._initialized = False + self._checkpoint_path_prefix = checkpoint_path_prefix from mlflow.tracking import MlflowClient @@ -178,7 +180,7 @@ def experiment(self) -> "MlflowClient": if self._experiment_id is None: expt = self._mlflow_client.get_experiment_by_name(self._experiment_name) - if expt is not None: + if expt is not None and expt.lifecycle_stage != "deleted": self._experiment_id = expt.experiment_id else: log.warning(f"Experiment with name {self._experiment_name} not found. 
Creating it.") @@ -361,13 +363,13 @@ def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> Non aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"] # Artifact path on mlflow - artifact_path = Path(p).stem + artifact_path = Path(self._checkpoint_path_prefix) / Path(p).stem # Log the checkpoint self.experiment.log_artifact(self._run_id, p, artifact_path) # Create a temporary directory to log on mlflow - with tempfile.TemporaryDirectory(prefix="test", suffix="test", dir=os.getcwd()) as tmp_dir: + with tempfile.TemporaryDirectory() as tmp_dir: # Log the metadata with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata: yaml.dump(metadata, tmp_file_metadata, default_flow_style=False) diff --git a/src/lightning/pytorch/loggers/neptune.py b/src/lightning/pytorch/loggers/neptune.py index a363f589b29b4..bf9669c824784 100644 --- a/src/lightning/pytorch/loggers/neptune.py +++ b/src/lightning/pytorch/loggers/neptune.py @@ -508,8 +508,6 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: if not self._log_model_checkpoints: return - from neptune.types import File - file_names = set() checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints") @@ -517,8 +515,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path: model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback) file_names.add(model_last_name) - with open(checkpoint_callback.last_model_path, "rb") as fp: - self.run[f"{checkpoints_namespace}/{model_last_name}"] = File.from_stream(fp) + self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path) # save best k models if hasattr(checkpoint_callback, "best_k_models"): @@ -533,8 +530,7 @@ def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None: model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback) file_names.add(model_name) - with open(checkpoint_callback.best_model_path, "rb") as fp: - self.run[f"{checkpoints_namespace}/{model_name}"] = File.from_stream(fp) + self.run[f"{checkpoints_namespace}/{model_name}"].upload(checkpoint_callback.best_model_path) # remove old models logged to experiment if they are not part of best k models at this point if self.run.exists(checkpoints_namespace): diff --git a/src/lightning/pytorch/loggers/utilities.py b/src/lightning/pytorch/loggers/utilities.py index e1752c67d9183..ced8a6f1f2bd3 100644 --- a/src/lightning/pytorch/loggers/utilities.py +++ b/src/lightning/pytorch/loggers/utilities.py @@ -74,7 +74,7 @@ def _log_hyperparams(trainer: "pl.Trainer") -> None: continue lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key] if ( - type(lm_val) != type(dm_val) + type(lm_val) != type(dm_val) # noqa: E721 or (isinstance(lm_val, Tensor) and id(lm_val) != id(dm_val)) or lm_val != dm_val ): diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index 2429748f73179..5bef27192b127 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -322,7 +322,7 @@ def __init__( self._prefix = prefix self._experiment = experiment self._logged_model_time: dict[str, float] = {} - self._checkpoint_callback: Optional[ModelCheckpoint] = None + self._checkpoint_callbacks: dict[int, ModelCheckpoint] = {} # paths are processed as strings if save_dir 
is not None: @@ -410,8 +410,11 @@ def experiment(self) -> Union["Run", "RunDisabled"]: if isinstance(self._experiment, (Run, RunDisabled)) and getattr( self._experiment, "define_metric", None ): - self._experiment.define_metric("trainer/global_step") - self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) + if self._wandb_init.get("sync_tensorboard"): + self._experiment.define_metric("*", step_metric="global_step") + else: + self._experiment.define_metric("trainer/global_step") + self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True) return self._experiment @@ -434,7 +437,7 @@ def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0" metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR) - if step is not None: + if step is not None and not self._wandb_init.get("sync_tensorboard"): self.experiment.log(dict(metrics, **{"trainer/global_step": step})) else: self.experiment.log(metrics) @@ -588,7 +591,7 @@ def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1: self._scan_and_log_checkpoints(checkpoint_callback) elif self._log_model is True: - self._checkpoint_callback = checkpoint_callback + self._checkpoint_callbacks[id(checkpoint_callback)] = checkpoint_callback @staticmethod @rank_zero_only @@ -641,8 +644,9 @@ def finalize(self, status: str) -> None: # Currently, checkpoints only get logged on success return # log checkpoints as artifacts - if self._checkpoint_callback and self._experiment is not None: - self._scan_and_log_checkpoints(self._checkpoint_callback) + if self._experiment is not None: + for checkpoint_callback in self._checkpoint_callbacks.values(): + self._scan_and_log_checkpoints(checkpoint_callback) def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None: import wandb diff --git a/src/lightning/pytorch/loops/evaluation_loop.py b/src/lightning/pytorch/loops/evaluation_loop.py index d007466ee3b1c..7f033dbd8e2c2 100644 --- a/src/lightning/pytorch/loops/evaluation_loop.py +++ b/src/lightning/pytorch/loops/evaluation_loop.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math import os import shutil import sys @@ -268,7 +269,10 @@ def increment_progress_to_evaluation_end(self) -> None: if self.skip: return self.reset() - max_batch = int(max(self.max_batches)) + max_batch = max(self.max_batches) + if isinstance(max_batch, float) and math.isinf(max_batch): + return + max_batch = int(max_batch) if max_batch == -1: return self.batch_progress.increment_by(max_batch, True) diff --git a/src/lightning/pytorch/plugins/precision/xla.py b/src/lightning/pytorch/plugins/precision/xla.py index 7682cdc4502f9..6890cc4c1d825 100644 --- a/src/lightning/pytorch/plugins/precision/xla.py +++ b/src/lightning/pytorch/plugins/precision/xla.py @@ -79,7 +79,7 @@ def optimizer_step( # type: ignore[override] # we lack coverage here so disable this - something to explore if there's demand raise MisconfigurationException( "Skipping backward by returning `None` from your `training_step` is not implemented with XLA." 
- " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`" + " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`" " requesting this feature." ) return closure_result diff --git a/src/lightning/pytorch/strategies/deepspeed.py b/src/lightning/pytorch/strategies/deepspeed.py index e17377d4464b0..dabfde70242b9 100644 --- a/src/lightning/pytorch/strategies/deepspeed.py +++ b/src/lightning/pytorch/strategies/deepspeed.py @@ -166,7 +166,7 @@ def __init__( nvme_path: Filesystem path for NVMe device for optimizer/parameter state offloading. optimizer_buffer_count: Number of buffers in buffer pool for optimizer state offloading - when ``offload_optimizer_device`` is set to to ``nvme``. + when ``offload_optimizer_device`` is set to ``nvme``. This should be at least the number of states maintained per parameter by the optimizer. For example, Adam optimizer has 4 states (parameter, gradient, momentum, and variance). @@ -400,8 +400,7 @@ def _setup_model_and_optimizers( """ if len(optimizers) != 1: raise ValueError( - f"Currently only one optimizer is supported with DeepSpeed." - f" Got {len(optimizers)} optimizers instead." + f"Currently only one optimizer is supported with DeepSpeed. Got {len(optimizers)} optimizers instead." ) # train_micro_batch_size_per_gpu is used for throughput logging purposes diff --git a/src/lightning/pytorch/strategies/launchers/multiprocessing.py b/src/lightning/pytorch/strategies/launchers/multiprocessing.py index aa207a527814e..3589460574c39 100644 --- a/src/lightning/pytorch/strategies/launchers/multiprocessing.py +++ b/src/lightning/pytorch/strategies/launchers/multiprocessing.py @@ -88,7 +88,7 @@ def __init__( def is_interactive_compatible(self) -> bool: # The start method 'spawn' is not supported in interactive environments # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA - # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550 + # initialization. For more context, see https://github.com/Lightning-AI/pytorch-lightning/issues/7550 return self._start_method == "fork" @override @@ -111,7 +111,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] if self._start_method == "spawn": _check_missing_main_guard() if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING: - # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction + # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction raise NotImplementedError( "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not" " supported. 
You can work around this limitation by creating a new Trainer instance and passing the" diff --git a/src/lightning/pytorch/strategies/launchers/xla.py b/src/lightning/pytorch/strategies/launchers/xla.py index 831faeb7bb993..066fecc79f208 100644 --- a/src/lightning/pytorch/strategies/launchers/xla.py +++ b/src/lightning/pytorch/strategies/launchers/xla.py @@ -76,7 +76,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] """ if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING: - # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction + # resolving https://github.com/Lightning-AI/pytorch-lightning/issues/18775 will lift this restriction raise NotImplementedError( "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not" " supported. You can work around this by creating a new Trainer instance and passing the" diff --git a/src/lightning/pytorch/trainer/__init__.py b/src/lightning/pytorch/trainer/__init__.py index cbed5dd4f1f20..f2e1b963306a1 100644 --- a/src/lightning/pytorch/trainer/__init__.py +++ b/src/lightning/pytorch/trainer/__init__.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""""" from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch.trainer.trainer import Trainer diff --git a/src/lightning/pytorch/trainer/call.py b/src/lightning/pytorch/trainer/call.py index 012d1a2152aa3..b5354eb2b08dd 100644 --- a/src/lightning/pytorch/trainer/call.py +++ b/src/lightning/pytorch/trainer/call.py @@ -21,6 +21,7 @@ import lightning.pytorch as pl from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from lightning.pytorch.callbacks import Checkpoint, EarlyStopping +from lightning.pytorch.loggers import WandbLogger from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher from lightning.pytorch.trainer.connectors.signal_connector import _get_sigkill_signal from lightning.pytorch.trainer.states import TrainerStatus @@ -91,8 +92,12 @@ def _call_setup_hook(trainer: "pl.Trainer") -> None: if isinstance(module, _DeviceDtypeModuleMixin): module._device = trainer.strategy.root_device + # wandb.init must be called before any tensorboard writers are created in order to sync tensorboard logs to wandb: + # https://github.com/wandb/wandb/issues/1782#issuecomment-779161203 + loggers = sorted(trainer.loggers, key=lambda logger: not isinstance(logger, WandbLogger)) + # Trigger lazy creation of experiment in loggers so loggers have their metadata available - for logger in trainer.loggers: + for logger in loggers: if hasattr(logger, "experiment"): _ = logger.experiment diff --git a/src/lightning/pytorch/trainer/connectors/callback_connector.py b/src/lightning/pytorch/trainer/connectors/callback_connector.py index a60f907d9361b..3f107bd9a124a 100644 --- a/src/lightning/pytorch/trainer/connectors/callback_connector.py +++ b/src/lightning/pytorch/trainer/connectors/callback_connector.py @@ -11,13 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
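The logger reordering added in `_call_setup_hook` above relies on Python's stable sort: the key maps `WandbLogger` instances to `False` (0) and everything else to `True` (1), so W&B loggers move to the front while the remaining loggers keep their relative order. A minimal sketch of that idea, using hypothetical stand-in classes rather than the real loggers:

```python
# Stand-in classes for illustration only; the real code sorts trainer.loggers
# with key=lambda logger: not isinstance(logger, WandbLogger).
class WandbLike: ...
class TensorBoardLike: ...
class CSVLike: ...

loggers = [TensorBoardLike(), WandbLike(), CSVLike()]

# sorted() is stable and False (0) orders before True (1), so WandbLike comes
# first while TensorBoardLike and CSVLike keep their original relative order.
ordered = sorted(loggers, key=lambda lg: not isinstance(lg, WandbLike))
print([type(lg).__name__ for lg in ordered])  # ['WandbLike', 'TensorBoardLike', 'CSVLike']
```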
-
+import inspect
 import logging
 import os
 from collections.abc import Sequence
 from datetime import timedelta
 from typing import Optional, Union

+from lightning_utilities import module_available
+
 import lightning.pytorch as pl
 from lightning.fabric.utilities.registry import _load_external_callbacks
 from lightning.pytorch.callbacks import (
@@ -91,7 +93,24 @@ def _configure_checkpoint_callbacks(self, enable_checkpointing: bool) -> None:
                 " but found `ModelCheckpoint` in callbacks list."
             )
         elif enable_checkpointing:
-            self.trainer.callbacks.append(ModelCheckpoint())
+            if module_available("litmodels") and self.trainer._model_registry:
+                trainer_source = inspect.getmodule(self.trainer)
+                if trainer_source is None or not isinstance(trainer_source.__package__, str):
+                    raise RuntimeError("Unable to determine the source of the trainer.")
+                # this needs to be imported based on the actual package (lightning or pytorch_lightning)
+                if "pytorch_lightning" in trainer_source.__package__:
+                    from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
+                else:
+                    from litmodels.integrations.checkpoints import LightningModelCheckpoint as LitModelCheckpoint
+
+                model_checkpoint = LitModelCheckpoint(model_name=self.trainer._model_registry)
+            else:
+                rank_zero_info(
+                    "You are using the plain ModelCheckpoint callback."
+                    " Consider using LitModelCheckpoint for seamless uploading to the Model registry."
+                )
+                model_checkpoint = ModelCheckpoint()
+            self.trainer.callbacks.append(model_checkpoint)

     def _configure_model_summary_callback(self, enable_model_summary: bool) -> None:
         if not enable_model_summary:
diff --git a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
index 71cc5a14686be..7f97a2f54bf19 100644
--- a/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/checkpoint_connector.py
@@ -19,6 +19,7 @@
 import torch
 from fsspec.core import url_to_fs
 from fsspec.implementations.local import LocalFileSystem
+from lightning_utilities import module_available
 from torch import Tensor

 import lightning.pytorch as pl
@@ -33,6 +34,10 @@
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning.pytorch.utilities.migration import pl_legacy_patch
 from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint
+from lightning.pytorch.utilities.model_registry import (
+    _is_registry,
+    find_model_local_ckpt_path,
+)
 from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn

 log = logging.getLogger(__name__)
@@ -48,8 +53,7 @@ def __init__(self, trainer: "pl.Trainer") -> None:

     @property
     def _hpc_resume_path(self) -> Optional[str]:
-        dir_path_hpc = self.trainer.default_root_dir
-        dir_path_hpc = str(dir_path_hpc)
+        dir_path_hpc = str(self.trainer.default_root_dir)
         fs, path = url_to_fs(dir_path_hpc)
         if not _is_dir(fs, path):
             return None
@@ -194,10 +198,17 @@ def _parse_ckpt_path(
             if not self._hpc_resume_path:
                 raise ValueError(
                     f'`.{fn}(ckpt_path="hpc")` is set but no HPC checkpoint was found.'
- " Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" + f" Please pass an exact checkpoint path to `.{fn}(ckpt_path=...)`" ) ckpt_path = self._hpc_resume_path + elif _is_registry(ckpt_path) and module_available("litmodels"): + ckpt_path = find_model_local_ckpt_path( + ckpt_path, + default_model_registry=self.trainer._model_registry, + default_root_dir=self.trainer.default_root_dir, + ) + if not ckpt_path: raise ValueError( f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please" diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py index 0dbdc4eaf76e1..c1ee0013bfa19 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/fx_validator.py @@ -154,7 +154,7 @@ def check_logging(cls, fx_name: str) -> None: if fx_name not in cls.functions: raise RuntimeError( f"Logging inside `{fx_name}` is not implemented." - " Please, open an issue in `https://github.com/Lightning-AI/lightning/issues`." + " Please, open an issue in `https://github.com/Lightning-AI/pytorch-lightning/issues`." ) if cls.functions[fx_name] is None: diff --git a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py index fdde19aa80eea..0881ac0b3fa08 100644 --- a/src/lightning/pytorch/trainer/connectors/logger_connector/result.py +++ b/src/lightning/pytorch/trainer/connectors/logger_connector/result.py @@ -157,7 +157,7 @@ def forked(self) -> bool: def forked_name(self, on_step: bool) -> str: if self.forked: - return f'{self.name}_{"step" if on_step else "epoch"}' + return f"{self.name}_{'step' if on_step else 'epoch'}" return self.name @property diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index 0509f28acb07a..8b976cd2f4f46 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -30,6 +30,7 @@ from weakref import proxy import torch +from lightning_utilities import module_available from torch.optim import Optimizer import lightning.pytorch as pl @@ -70,6 +71,7 @@ from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.model_registry import _is_registry, download_model_from_registry from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn from lightning.pytorch.utilities.seed import isolate_rng from lightning.pytorch.utilities.types import ( @@ -128,6 +130,7 @@ def __init__( sync_batchnorm: bool = False, reload_dataloaders_every_n_epochs: int = 0, default_root_dir: Optional[_PATH] = None, + model_registry: Optional[str] = None, ) -> None: r"""Customize every aspect of training via flags. @@ -290,6 +293,8 @@ def __init__( Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' + model_registry: The name of the model being uploaded to Model hub. + Raises: TypeError: If ``gradient_clip_val`` is not an int or float. 
@@ -304,6 +309,9 @@ def __init__(
         if default_root_dir is not None:
             default_root_dir = os.fspath(default_root_dir)

+        # remove version if accidentally passed
+        self._model_registry = model_registry.split(":")[0] if model_registry else None
+
         self.barebones = barebones
         if barebones:
             # opt-outs
@@ -519,7 +527,20 @@ def fit(
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

-            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of two special
-                keywords ``"last"`` and ``"hpc"``. If there is no checkpoint file at the path, an exception is raised.
+            ckpt_path: Path/URL of the checkpoint from which training is resumed. Could also be one of three special
+                keywords ``"last"``, ``"hpc"`` and ``"registry"``.
+                Otherwise, if there is no checkpoint file at the path, an exception is raised.
+
+                - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
+                - registry: the model will be downloaded from the Lightning Model Registry with the following notations:
+
+                    - ``'registry'``: uses the latest/default version of the default model set
+                      with ``Trainer(..., model_registry="my-model")``
+                    - ``'registry:model-name'``: uses the latest/default version of this model `model-name`
+                    - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
+                    - ``'registry:version:v2'``: uses the default model set
+                      with ``Trainer(..., model_registry="my-model")`` and version 'v2'
+

         Raises:
             TypeError:
@@ -536,6 +557,7 @@ def fit(
         self.state.fn = TrainerFn.FITTING
         self.state.status = TrainerStatus.RUNNING
         self.training = True
+        self.should_stop = False
         call._call_and_handle_interrupt(
             self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
         )
@@ -566,6 +588,8 @@ def _fit_impl(
         )

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn,
             ckpt_path,
@@ -595,8 +619,8 @@ def validate(
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to validate.
-                If ``None`` and the model instance was passed, use the current weights.
+            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish
+                to validate. If ``None`` and the model instance was passed, use the current weights.
                 Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
                 if a checkpoint callback is configured.

@@ -674,6 +698,8 @@ def _validate_impl(
         self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

         assert self.state.fn is not None
+        if _is_registry(ckpt_path) and module_available("litmodels"):
+            download_model_from_registry(ckpt_path, self)
         ckpt_path = self._checkpoint_connector._select_ckpt_path(
             self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
         )
@@ -704,8 +730,8 @@ def test(
                 Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
                 the :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

-            ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to test.
-                If ``None`` and the model instance was passed, use the current weights.
+ ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish + to test. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. @@ -783,6 +809,8 @@ def _test_impl( self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) @@ -819,8 +847,8 @@ def predict( return_predictions: Whether to return predictions. ``True`` by default except when an accelerator that spawns processes is used (not supported). - ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"`` or path to the checkpoint you wish to predict. - If ``None`` and the model instance was passed, use the current weights. + ckpt_path: Either ``"best"``, ``"last"``, ``"hpc"``, ``"registry"`` or path to the checkpoint you wish + to predict. If ``None`` and the model instance was passed, use the current weights. Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded if a checkpoint callback is configured. @@ -892,6 +920,8 @@ def _predict_impl( self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule) assert self.state.fn is not None + if _is_registry(ckpt_path) and module_available("litmodels"): + download_model_from_registry(ckpt_path, self) ckpt_path = self._checkpoint_connector._select_ckpt_path( self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None ) diff --git a/src/lightning/pytorch/utilities/model_helpers.py b/src/lightning/pytorch/utilities/model_helpers.py index 44591aa7f4dc1..daf1c400c03df 100644 --- a/src/lightning/pytorch/utilities/model_helpers.py +++ b/src/lightning/pytorch/utilities/model_helpers.py @@ -104,7 +104,7 @@ def _check_mixed_imports(instance: object) -> None: _R_co = TypeVar("_R_co", covariant=True) # return type of the decorated method -class _restricted_classmethod_impl(Generic[_T, _P, _R_co]): +class _restricted_classmethod_impl(Generic[_T, _R_co, _P]): """Drop-in replacement for @classmethod, but raises an exception when the decorated method is called on an instance instead of a class type.""" diff --git a/src/lightning/pytorch/utilities/model_registry.py b/src/lightning/pytorch/utilities/model_registry.py new file mode 100644 index 0000000000000..a9ed495eb37d8 --- /dev/null +++ b/src/lightning/pytorch/utilities/model_registry.py @@ -0,0 +1,178 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
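Putting the pieces above together, the new ``model_registry`` Trainer argument and the ``"registry"`` values for ``ckpt_path`` are intended to be used roughly as follows. This is a hedged usage sketch based on the docstrings above: it assumes ``litmodels`` is installed and that a model named ``"my-model"`` exists in the Lightning Model Registry; ``TinyModel`` and the random dataloader are placeholders.

```python
# Hedged usage sketch; assumes `litmodels` is installed and "my-model" exists
# in the Lightning Model Registry. TinyModel and the dataloader are placeholders.
import torch
from torch.utils.data import DataLoader
from lightning.pytorch import LightningModule, Trainer


class TinyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_loader = DataLoader(torch.randn(64, 32), batch_size=8)

# The default registry entry is configured once on the Trainer ...
trainer = Trainer(max_epochs=1, model_registry="my-model")

# ... and ckpt_path selects which checkpoint to download before training resumes:
#   "registry"                        -> default version of "my-model"
#   "registry:version:v2"             -> version "v2" of "my-model"
#   "registry:other-model"            -> default version of "other-model"
#   "registry:other-model:version:v2" -> version "v2" of "other-model"
trainer.fit(TinyModel(), train_loader, ckpt_path="registry")
```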
+import os +import re +from typing import Optional + +from lightning_utilities import module_available + +import lightning.pytorch as pl +from lightning.fabric.utilities.imports import _IS_WINDOWS +from lightning.fabric.utilities.types import _PATH + +# skip these test on Windows as the path notation differ +if _IS_WINDOWS: + __doctest_skip__ = ["_determine_model_folder"] + + +def _is_registry(text: Optional[_PATH]) -> bool: + """Check if a string equals 'registry' or starts with 'registry:'. + + Args: + text: The string to check + + >>> _is_registry("registry") + True + >>> _is_registry("REGISTRY:model-name") + True + >>> _is_registry("something_registry") + False + >>> _is_registry("") + False + + """ + if not isinstance(text, str): + return False + + # Pattern matches exactly 'registry' or 'registry:' followed by any characters + pattern = r"^registry(:.*|$)" + return bool(re.match(pattern, text.lower())) + + +def _parse_registry_model_version(ckpt_path: Optional[_PATH]) -> tuple[str, str]: + """Parse the model version from a registry path. + + Args: + ckpt_path: The checkpoint path + + Returns: + string name and version of the model + + >>> _parse_registry_model_version("registry:model-name:version:1.0") + ('model-name', '1.0') + >>> _parse_registry_model_version("registry:model-name") + ('model-name', '') + >>> _parse_registry_model_version("registry:version:v2") + ('', 'v2') + + """ + if not ckpt_path or not _is_registry(ckpt_path): + raise ValueError(f"Invalid registry path: {ckpt_path}") + + # Split the path by ':' + parts = str(ckpt_path).lower().split(":") + # Default values + model_name, version = "", "" + + # Extract the model name and version based on the parts + if len(parts) >= 2 and parts[1] != "version": + model_name = parts[1] + if len(parts) == 3 and parts[1] == "version": + version = parts[2] + elif len(parts) == 4 and parts[2] == "version": + version = parts[3] + + return model_name, version + + +def _determine_model_name(ckpt_path: Optional[_PATH], default_model_registry: Optional[str]) -> str: + """Determine the model name from the checkpoint path. + + Args: + ckpt_path: The checkpoint path + default_model_registry: The default model registry + + Returns: + string name of the model with optional version + + >>> _determine_model_name("registry:model-name:version:1.0", "default-model") + 'model-name:1.0' + >>> _determine_model_name("registry:model-name", "default-model") + 'model-name' + >>> _determine_model_name("registry:version:v2", "default-model") + 'default-model:v2' + + """ + # try to find model and version + model_name, model_version = _parse_registry_model_version(ckpt_path) + # omitted model name try to use the model registry from Trainer + if not model_name and default_model_registry: + model_name = default_model_registry + if not model_name: + raise ValueError(f"Invalid model registry: '{ckpt_path}'") + model_registry = model_name + model_registry += f":{model_version}" if model_version else "" + return model_registry + + +def _determine_model_folder(model_name: str, default_root_dir: str) -> str: + """Determine the local model folder based on the model registry. 
+ + Args: + model_name: The model name + default_root_dir: The default root directory + + Returns: + string path to the local model folder + + >>> _determine_model_folder("model-name", "/path/to/root") + '/path/to/root/model-name' + >>> _determine_model_folder("model-name:1.0", "/path/to/root") + '/path/to/root/model-name_1.0' + + """ + if not model_name: + raise ValueError(f"Invalid model registry: '{model_name}'") + # download the latest checkpoint from the model registry + model_name = model_name.replace("/", "_") + model_name = model_name.replace(":", "_") + local_model_dir = os.path.join(default_root_dir, model_name) + return local_model_dir + + +def find_model_local_ckpt_path( + ckpt_path: Optional[_PATH], default_model_registry: Optional[str], default_root_dir: str +) -> str: + """Find the local checkpoint path for a model.""" + model_registry = _determine_model_name(ckpt_path, default_model_registry) + local_model_dir = _determine_model_folder(model_registry, default_root_dir) + + # todo: resolve if there are multiple checkpoints + folder_files = [fn for fn in os.listdir(local_model_dir) if fn.endswith(".ckpt")] + if not folder_files: + raise RuntimeError(f"Parsing files from downloaded model: {model_registry}") + # print(f"local RANK {self.trainer.local_rank}: using model files: {folder_files}") + return os.path.join(local_model_dir, folder_files[0]) + + +def download_model_from_registry(ckpt_path: Optional[_PATH], trainer: "pl.Trainer") -> None: + """Download a model from the Lightning Model Registry.""" + if trainer.local_rank == 0: + if not module_available("litmodels"): + raise ImportError( + "The `litmodels` package is not installed. Please install it with `pip install litmodels`." + ) + + from litmodels import download_model + + model_registry = _determine_model_name(ckpt_path, trainer._model_registry) + local_model_dir = _determine_model_folder(model_registry, trainer.default_root_dir) + + # print(f"Rank {self.trainer.local_rank} downloads model checkpoint '{model_registry}'") + model_files = download_model(model_registry, download_dir=local_model_dir) + # print(f"Model checkpoint '{model_registry}' was downloaded to '{local_model_dir}'") + if not model_files: + raise RuntimeError(f"Download model failed - {model_registry}") + + trainer.strategy.barrier("download_model_from_registry") diff --git a/src/lightning_fabric/README.md b/src/lightning_fabric/README.md index d842c0d19118d..076800caeb71e 100644 --- a/src/lightning_fabric/README.md +++ b/src/lightning_fabric/README.md @@ -215,7 +215,7 @@ Lightning is rigorously tested across multiple CPUs and GPUs and against major P | System / PyTorch ver. 
| 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------: | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Flightning-fabric%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | OSX (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | | Windows (multiple Python versions) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | ![Test Fabric](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-fabric.yml/badge.svg) | diff --git a/src/lightning_fabric/__setup__.py b/src/lightning_fabric/__setup__.py index a55e1f2332f37..36dbae53ef171 100644 --- a/src/lightning_fabric/__setup__.py +++ b/src/lightning_fabric/__setup__.py @@ -85,7 +85,7 @@ def _setup_args() -> dict[str, Any]: }, "extras_require": _prepare_extras(), "project_urls": { - "Bug Tracker": "https://github.com/Lightning-AI/lightning/issues", + "Bug Tracker": "https://github.com/Lightning-AI/pytorch-lightning/issues", "Documentation": "https://pytorch-lightning.rtfd.io/en/latest/", "Source Code": "https://github.com/Lightning-AI/lightning", }, diff --git a/src/pytorch_lightning/README.md b/src/pytorch_lightning/README.md index ae9339dfb2b0d..f3fb8cb2fd2b3 100644 --- a/src/pytorch_lightning/README.md +++ b/src/pytorch_lightning/README.md @@ -79,7 +79,7 @@ Lightning is rigorously tested across multiple CPUs, GPUs and TPUs and against m | System / PyTorch ver. 
| 1.12 | 1.13 | 2.0 | 2.1 | | :--------------------------------: | :---------------------------------------------------------------------------------------------------------: | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------- | -| Linux py3.9 \[GPUs\] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | +| Linux py3.9 [GPUs] | | | | ![Build Status](https://dev.azure.com/Lightning-AI/lightning/_apis/build/status%2Fpytorch-lightning%20%28GPUs%29) | | Linux (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | OSX (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | | Windows (multiple Python versions) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | ![Test PyTorch](https://github.com/Lightning-AI/lightning/actions/workflows/ci-tests-pytorch.yml/badge.svg) | diff --git a/src/pytorch_lightning/__setup__.py b/src/pytorch_lightning/__setup__.py index 6677b469ba1de..97250404230b6 100644 --- a/src/pytorch_lightning/__setup__.py +++ b/src/pytorch_lightning/__setup__.py @@ -89,7 +89,7 @@ def _setup_args() -> dict[str, Any]: ), "extras_require": _prepare_extras(), "project_urls": { - "Bug Tracker": "https://github.com/Lightning-AI/lightning/issues", + "Bug Tracker": "https://github.com/Lightning-AI/pytorch-lightning/issues", "Documentation": "https://pytorch-lightning.rtfd.io/en/latest/", "Source Code": "https://github.com/Lightning-AI/lightning", }, diff --git a/src/version.info b/src/version.info index 797b505d19610..73462a5a13445 100644 --- a/src/version.info +++ b/src/version.info @@ -1 +1 @@ -2.5.0.post0 +2.5.1 diff --git a/tests/README.md b/tests/README.md index 8f015d3386fc3..9265caf4b412e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -39,7 +39,7 @@ Note: if your computer does not have multi-GPU or TPU these tests are skipped. **GitHub Actions:** For convenience, you can also use your own GHActions building which will be triggered with each commit. This is useful if you do not test against all required dependency versions. 
-**Docker:** Another option is to utilize the [pytorch lightning cuda base docker image](https://hub.docker.com/repository/docker/pytorchlightning/pytorch_lightning/tags?page=1&name=cuda). You can then run:
+**Docker:** Another option is to utilize the [pytorch lightning cuda base docker image](https://hub.docker.com/r/pytorchlightning/pytorch_lightning/tags?name=cuda). You can then run:

 ```bash
 python -m pytest src/lightning/pytorch tests/tests_pytorch -v
@@ -64,9 +64,9 @@ You can rely on our CI to make sure all these tests pass.
 There are certain standalone tests, which you can run using:

 ```bash
-./tests/run_standalone_tests.sh tests/tests_pytorch/trainer/
-# or run a specific test
-./tests/run_standalone_tests.sh -k test_multi_gpu_model_ddp
+cd tests/
+wget https://raw.githubusercontent.com/Lightning-AI/utilities/main/scripts/run_standalone_tests.sh
+./run_standalone_tests.sh tests_pytorch/
 ```

 ## Running Coverage
diff --git a/tests/legacy/simple_classif_training.py b/tests/legacy/simple_classif_training.py
index d2cf4cd2166f3..dd767ee9075f2 100644
--- a/tests/legacy/simple_classif_training.py
+++ b/tests/legacy/simple_classif_training.py
@@ -14,13 +14,14 @@
 import os
 import sys

-import lightning.pytorch as pl
 import torch
-from lightning.pytorch import seed_everything
-from lightning.pytorch.callbacks import EarlyStopping
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.simple_models import ClassificationModel

+import lightning.pytorch as pl
+from lightning.pytorch import seed_everything
+from lightning.pytorch.callbacks import EarlyStopping
+
 PATH_LEGACY = os.path.dirname(__file__)
diff --git a/tests/parity_fabric/conftest.py b/tests/parity_fabric/conftest.py
index ceb19e061c774..9fc6f9d908d81 100644
--- a/tests/parity_fabric/conftest.py
+++ b/tests/parity_fabric/conftest.py
@@ -17,7 +17,7 @@
 import torch.distributed


-@pytest.fixture()
+@pytest.fixture
 def reset_deterministic_algorithm():
     """Ensures that torch determinism settings are reset before the next test runs."""
     yield
@@ -25,7 +25,7 @@ def reset_deterministic_algorithm():
     torch.use_deterministic_algorithms(False)


-@pytest.fixture()
+@pytest.fixture
 def reset_cudnn_benchmark():
     """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs."""
     yield
diff --git a/tests/parity_fabric/test_parity_ddp.py b/tests/parity_fabric/test_parity_ddp.py
index 217d401ad6fba..d30d2b6233886 100644
--- a/tests/parity_fabric/test_parity_ddp.py
+++ b/tests/parity_fabric/test_parity_ddp.py
@@ -18,11 +18,11 @@
 import torch
 import torch.distributed
 import torch.nn.functional
-from lightning.fabric.fabric import Fabric
 from torch.nn.parallel.distributed import DistributedDataParallel
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler

+from lightning.fabric.fabric import Fabric
 from parity_fabric.models import ConvNet
 from parity_fabric.utils import (
     cuda_reset,
diff --git a/tests/parity_fabric/test_parity_simple.py b/tests/parity_fabric/test_parity_simple.py
index a97d39dfa1035..54c0de7297ac5 100644
--- a/tests/parity_fabric/test_parity_simple.py
+++ b/tests/parity_fabric/test_parity_simple.py
@@ -19,9 +19,9 @@
 import torch
 import torch.distributed
 import torch.nn.functional
-from lightning.fabric.fabric import Fabric
 from tests_fabric.helpers.runif import RunIf

+from lightning.fabric.fabric import Fabric
 from parity_fabric.models import ConvNet
 from parity_fabric.utils import (
     cuda_reset,
diff --git
a/tests/parity_fabric/utils.py b/tests/parity_fabric/utils.py index 7f0028dc23421..7d7a14732bb0e 100644 --- a/tests/parity_fabric/utils.py +++ b/tests/parity_fabric/utils.py @@ -14,6 +14,7 @@ import os import torch + from lightning.fabric.accelerators.cuda import _clear_cuda_memory diff --git a/tests/parity_pytorch/__init__.py b/tests/parity_pytorch/__init__.py index 148237ab9c718..6d7cadefc20fa 100644 --- a/tests/parity_pytorch/__init__.py +++ b/tests/parity_pytorch/__init__.py @@ -1,4 +1,5 @@ import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/parity_pytorch/models.py b/tests/parity_pytorch/models.py index f55b0d6f1f36e..17cbef6a76faa 100644 --- a/tests/parity_pytorch/models.py +++ b/tests/parity_pytorch/models.py @@ -14,11 +14,12 @@ import torch import torch.nn.functional as F +from tests_pytorch import _PATH_DATASETS +from torch.utils.data import DataLoader + from lightning.pytorch.core.module import LightningModule from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE from lightning.pytorch.utilities.model_helpers import get_torchvision_model -from tests_pytorch import _PATH_DATASETS -from torch.utils.data import DataLoader if _TORCHVISION_AVAILABLE: from torchvision import transforms diff --git a/tests/parity_pytorch/test_basic_parity.py b/tests/parity_pytorch/test_basic_parity.py index d7f15815c831f..6b413dada8659 100644 --- a/tests/parity_pytorch/test_basic_parity.py +++ b/tests/parity_pytorch/test_basic_parity.py @@ -16,9 +16,9 @@ import numpy as np import pytest import torch -from lightning.pytorch import LightningModule, Trainer, seed_everything from tests_pytorch.helpers.advanced_models import ParityModuleMNIST, ParityModuleRNN +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch.measure import measure_loops from parity_pytorch.models import ParityModuleCIFAR diff --git a/tests/parity_pytorch/test_sync_batchnorm_parity.py b/tests/parity_pytorch/test_sync_batchnorm_parity.py index 4d2300cf15670..af22d7470e524 100644 --- a/tests/parity_pytorch/test_sync_batchnorm_parity.py +++ b/tests/parity_pytorch/test_sync_batchnorm_parity.py @@ -14,9 +14,9 @@ import torch import torch.nn as nn -from lightning.pytorch import LightningModule, Trainer, seed_everything from torch.utils.data import DataLoader, DistributedSampler +from lightning.pytorch import LightningModule, Trainer, seed_everything from parity_pytorch import RunIf diff --git a/tests/run_standalone_tests.sh b/tests/run_standalone_tests.sh deleted file mode 100755 index 75a52e16c57dc..0000000000000 --- a/tests/run_standalone_tests.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/bin/bash -# Copyright The Lightning AI team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-set -e -# THIS FILE ASSUMES IT IS RUN INSIDE THE tests/tests_ DIRECTORY - -# Batch size for testing: Determines how many standalone test invocations run in parallel -# It can be set through the env variable PL_STANDALONE_TESTS_BATCH_SIZE and defaults to 6 if not set -test_batch_size="${PL_STANDALONE_TESTS_BATCH_SIZE:-3}" -source="${PL_STANDALONE_TESTS_SOURCE:-"lightning"}" -# this is the directory where the tests are located -test_dir=$1 # parse the first argument -COLLECTED_TESTS_FILE="collected_tests.txt" - -ls -lh . # show the contents of the directory - -# this environment variable allows special tests to run -export PL_RUN_STANDALONE_TESTS=1 -# python arguments -defaults=" -m coverage run --source ${source} --append -m pytest --no-header -v -s --timeout 120 " -echo "Using defaults: ${defaults}" - -# get the list of parametrizations. we need to call them separately. the last two lines are removed. -# note: if there's a syntax error, this will fail with some garbled output -python3 -um pytest $test_dir -q --collect-only --pythonwarnings ignore 2>&1 > $COLLECTED_TESTS_FILE -# early terminate if collection failed (e.g. syntax error) -if [[ $? != 0 ]]; then - cat $COLLECTED_TESTS_FILE - exit 1 -fi - -# removes the last line of the file -sed -i '$d' $COLLECTED_TESTS_FILE - -# Get test list and run each test individually -tests=($(grep -oP '\S+::test_\S+' "$COLLECTED_TESTS_FILE")) -test_count=${#tests[@]} -# present the collected tests -printf "collected $test_count tests:\n-------------------\n" -# replace space with new line -echo "${tests[@]}" | tr ' ' '\n' -printf "\n===================\n" - -# if test count is one print warning -if [[ $test_count -eq 1 ]]; then - printf "WARNING: only one test found!\n" -elif [ $test_count -eq 0 ]; then - printf "ERROR: no tests found!\n" - exit 1 -fi - -# clear all the collected reports -rm -f parallel_test_output-*.txt # in case it exists, remove it - - -status=0 # reset the script status -report="" # final report -pids=() # array of PID for running tests -test_ids=() # array of indexes of running tests -printf "Running $test_count tests in batches of $test_batch_size\n" -for i in "${!tests[@]}"; do - # remove initial "tests/" from the test name - test=${tests[$i]/tests\//} - printf "Running test $((i+1))/$test_count: $test\n" - - # execute the test in the background - # redirect to a log file that buffers test output. since the tests will run in the background, - # we cannot let them output to std{out,err} because the outputs would be garbled together - python3 ${defaults} "$test" 2>&1 > "standalone_test_output-$i.txt" & - test_ids+=($i) # save the test's id in an array with running tests - pids+=($!) # save the PID in an array with running tests - - # if we reached the batch size, wait for all tests to finish - if (( (($i + 1) % $test_batch_size == 0) || $i == $test_count-1 )); then - printf "Waiting for batch to finish: $(IFS=' '; echo "${pids[@]}")\n" - # wait for running tests - for j in "${!test_ids[@]}"; do - i=${test_ids[$j]} # restore the global test's id - pid=${pids[$j]} # restore the particular PID - test=${tests[$i]} # restore the test name - printf "Waiting for $tests >> standalone_test_output-$i.txt (PID: $pid)\n" - wait -n $pid - # get the exit status of the test - test_status=$? 
- # add row to the final report - report+="Ran\t$test\t>> exit:$test_status\n" - if [[ $test_status != 0 ]]; then - # show the output of the failed test - cat "standalone_test_output-$i.txt" - # Process exited with a non-zero exit status - status=$test_status - fi - done - test_ids=() # reset the test's id array - pids=() # reset the PID array - fi -done - -# echo test report -printf '=%.s' {1..80} -printf "\n$report" -printf '=%.s' {1..80} -printf '\n' - -# exit with the worst test result -exit $status diff --git a/tests/tests_fabric/accelerators/test_cpu.py b/tests/tests_fabric/accelerators/test_cpu.py index 5efb5d6afddbc..7c7029f9ec9e3 100644 --- a/tests/tests_fabric/accelerators/test_cpu.py +++ b/tests/tests_fabric/accelerators/test_cpu.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.fabric.accelerators.cpu import CPUAccelerator, _parse_cpu_cores diff --git a/tests/tests_fabric/accelerators/test_cuda.py b/tests/tests_fabric/accelerators/test_cuda.py index 0aed3675d93e1..037eb8d400825 100644 --- a/tests/tests_fabric/accelerators/test_cuda.py +++ b/tests/tests_fabric/accelerators/test_cuda.py @@ -18,15 +18,15 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.fabric.accelerators.cuda import ( CUDAAccelerator, _check_cuda_matmul_precision, find_usable_cuda_devices, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/accelerators/test_mps.py b/tests/tests_fabric/accelerators/test_mps.py index 20dc9b6a93581..612bd8f74a640 100644 --- a/tests/tests_fabric/accelerators/test_mps.py +++ b/tests/tests_fabric/accelerators/test_mps.py @@ -13,9 +13,9 @@ # limitations under the License. import pytest import torch + from lightning.fabric.accelerators.mps import MPSAccelerator from lightning.fabric.utilities.exceptions import MisconfigurationException - from tests_fabric.helpers.runif import RunIf _MAYBE_MPS = "mps" if MPSAccelerator.is_available() else "cpu" diff --git a/tests/tests_fabric/accelerators/test_registry.py b/tests/tests_fabric/accelerators/test_registry.py index 2544df1e01ff8..28bfbb8ffd97c 100644 --- a/tests/tests_fabric/accelerators/test_registry.py +++ b/tests/tests_fabric/accelerators/test_registry.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.fabric.accelerators import ACCELERATOR_REGISTRY, Accelerator diff --git a/tests/tests_fabric/accelerators/test_xla.py b/tests/tests_fabric/accelerators/test_xla.py index 7a906c8ae0c54..95d5cab90a5f2 100644 --- a/tests/tests_fabric/accelerators/test_xla.py +++ b/tests/tests_fabric/accelerators/test_xla.py @@ -13,8 +13,8 @@ # limitations under the License import pytest -from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator +from lightning.fabric.accelerators.xla import _XLA_AVAILABLE, XLAAccelerator from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/conftest.py b/tests/tests_fabric/conftest.py index dd272257b3923..68f3f2cc38191 100644 --- a/tests/tests_fabric/conftest.py +++ b/tests/tests_fabric/conftest.py @@ -18,9 +18,10 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric import pytest import torch.distributed + +import lightning.fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection @@ -29,9 +30,10 @@ 
@pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" - from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric + functions = (rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] yield @@ -126,7 +128,7 @@ def reset_in_fabric_backward(): wrappers._in_fabric_backward = False -@pytest.fixture() +@pytest.fixture def reset_deterministic_algorithm(): """Ensures that torch determinism settings are reset before the next test runs.""" yield @@ -134,7 +136,7 @@ def reset_deterministic_algorithm(): torch.use_deterministic_algorithms(False) -@pytest.fixture() +@pytest.fixture def reset_cudnn_benchmark(): """Ensures that the `torch.backends.cudnn.benchmark` setting gets reset before the next test runs.""" yield @@ -155,7 +157,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.distributed.fsdp.wrap", Mock()) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -166,12 +168,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.accelerators.xla.XLAAccelerator, "auto_device_count", lambda *_: 8) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. @@ -189,12 +191,16 @@ def caplog(caplog): @pytest.fixture(autouse=True) def leave_no_artifacts_behind(): + """Checks that no artifacts are left behind after the test.""" tests_root = Path(__file__).parent.parent + # Ignore the __pycache__ directories files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts} yield files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts} difference = files_after - files_before difference = {str(f.relative_to(tests_root)) for f in difference} + # ignore the .coverage files + difference = {f for f in difference if not f.endswith(".coverage")} assert not difference, f"Test left artifacts behind: {difference}" diff --git a/tests/tests_fabric/helpers/runif.py b/tests/tests_fabric/helpers/runif.py index 813c4f93b1788..23a620295bcbf 100644 --- a/tests/tests_fabric/helpers/runif.py +++ b/tests/tests_fabric/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest + from lightning.fabric.utilities.testing import _runif_reasons diff --git a/tests/tests_fabric/loggers/test_csv.py b/tests/tests_fabric/loggers/test_csv.py index 784bb6fb45aba..08ed3990c2435 100644 --- a/tests/tests_fabric/loggers/test_csv.py +++ b/tests/tests_fabric/loggers/test_csv.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.loggers import CSVLogger from lightning.fabric.loggers.csv_logs import _ExperimentWriter @@ -147,7 +148,7 @@ def test_automatic_step_tracking(tmp_path): @mock.patch( - # Mock the existance check, so we can simulate appending to the metrics file + # Mock the existence check, so we can simulate appending to the metrics file "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists" ) def test_append_metrics_file(_, tmp_path): @@ -188,7 +189,7 @@ def test_append_columns(tmp_path): @mock.patch( - # Mock the existance check, so we can simulate appending to the metrics file + # Mock the existence check, so we can simulate appending to the metrics file "lightning.fabric.loggers.csv_logs._ExperimentWriter._check_log_dir_exists" ) def test_rewrite_with_new_header(_, tmp_path): diff --git a/tests/tests_fabric/loggers/test_tensorboard.py b/tests/tests_fabric/loggers/test_tensorboard.py index fa685241ea1b5..4dcb86f0e7406 100644 --- a/tests/tests_fabric/loggers/test_tensorboard.py +++ b/tests/tests_fabric/loggers/test_tensorboard.py @@ -19,10 +19,10 @@ import numpy as np import pytest import torch + from lightning.fabric.loggers import TensorBoardLogger from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.fabric.wrappers import _FabricModule - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/plugins/collectives/test_single_device.py b/tests/tests_fabric/plugins/collectives/test_single_device.py index e7aefdb6078b1..a5a909da64b8a 100644 --- a/tests/tests_fabric/plugins/collectives/test_single_device.py +++ b/tests/tests_fabric/plugins/collectives/test_single_device.py @@ -1,6 +1,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.collectives import SingleDeviceCollective diff --git a/tests/tests_fabric/plugins/collectives/test_torch_collective.py b/tests/tests_fabric/plugins/collectives/test_torch_collective.py index b4c223e770282..5cef70f3b91ba 100644 --- a/tests/tests_fabric/plugins/collectives/test_torch_collective.py +++ b/tests/tests_fabric/plugins/collectives/test_torch_collective.py @@ -6,12 +6,12 @@ import pytest import torch + from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator from lightning.fabric.plugins.collectives import TorchCollective from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.ddp import DDPStrategy from lightning.fabric.strategies.launchers.multiprocessing import _MultiProcessingLauncher - from tests_fabric.helpers.runif import RunIf if TorchCollective.is_available(): diff --git a/tests/tests_fabric/plugins/environments/test_kubeflow.py b/tests/tests_fabric/plugins/environments/test_kubeflow.py index 3c44273825510..3436adc9ce2aa 100644 --- a/tests/tests_fabric/plugins/environments/test_kubeflow.py +++ b/tests/tests_fabric/plugins/environments/test_kubeflow.py @@ -16,6 +16,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import KubeflowEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lightning.py b/tests/tests_fabric/plugins/environments/test_lightning.py index 
02100329d173c..cc0179af4c3f5 100644 --- a/tests/tests_fabric/plugins/environments/test_lightning.py +++ b/tests/tests_fabric/plugins/environments/test_lightning.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LightningEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_lsf.py b/tests/tests_fabric/plugins/environments/test_lsf.py index 4e60d968dc953..31cc5976cfe09 100644 --- a/tests/tests_fabric/plugins/environments/test_lsf.py +++ b/tests/tests_fabric/plugins/environments/test_lsf.py @@ -15,6 +15,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import LSFEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_mpi.py b/tests/tests_fabric/plugins/environments/test_mpi.py index 649d4dcb1dab2..3df0000cf2766 100644 --- a/tests/tests_fabric/plugins/environments/test_mpi.py +++ b/tests/tests_fabric/plugins/environments/test_mpi.py @@ -16,8 +16,9 @@ from unittest import mock from unittest.mock import MagicMock -import lightning.fabric.plugins.environments.mpi import pytest + +import lightning.fabric.plugins.environments.mpi from lightning.fabric.plugins.environments import MPIEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_slurm.py b/tests/tests_fabric/plugins/environments/test_slurm.py index f237478a533f4..75ca43577d579 100644 --- a/tests/tests_fabric/plugins/environments/test_slurm.py +++ b/tests/tests_fabric/plugins/environments/test_slurm.py @@ -18,10 +18,10 @@ from unittest import mock import pytest -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning_utilities.test.warning import no_warning_call +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.fabric.utilities.warnings import PossibleUserWarning from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/environments/test_torchelastic.py b/tests/tests_fabric/plugins/environments/test_torchelastic.py index 3cf8619a00c22..161d42894df30 100644 --- a/tests/tests_fabric/plugins/environments/test_torchelastic.py +++ b/tests/tests_fabric/plugins/environments/test_torchelastic.py @@ -17,6 +17,7 @@ from unittest import mock import pytest + from lightning.fabric.plugins.environments import TorchElasticEnvironment diff --git a/tests/tests_fabric/plugins/environments/test_xla.py b/tests/tests_fabric/plugins/environments/test_xla.py index 7a3610b65b4bb..7e33d5db87dd4 100644 --- a/tests/tests_fabric/plugins/environments/test_xla.py +++ b/tests/tests_fabric/plugins/environments/test_xla.py @@ -14,11 +14,11 @@ import os from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1 from lightning.fabric.plugins.environments import XLAEnvironment - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_all.py b/tests/tests_fabric/plugins/precision/test_all.py index 5e86a35647489..94e5efaa74eed 100644 --- a/tests/tests_fabric/plugins/precision/test_all.py +++ b/tests/tests_fabric/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins import DeepSpeedPrecision, DoublePrecision, FSDPPrecision, HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_amp.py b/tests/tests_fabric/plugins/precision/test_amp.py index 93d53eb406f71..73507f085936b 100644 
--- a/tests/tests_fabric/plugins/precision/test_amp.py +++ b/tests/tests_fabric/plugins/precision/test_amp.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.fabric.plugins.precision.amp import MixedPrecision from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 diff --git a/tests/tests_fabric/plugins/precision/test_amp_integration.py b/tests/tests_fabric/plugins/precision/test_amp_integration.py index aa6c6cfce4504..bcbd9435d47ac 100644 --- a/tests/tests_fabric/plugins/precision/test_amp_integration.py +++ b/tests/tests_fabric/plugins/precision/test_amp_integration.py @@ -16,9 +16,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric, seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py index b8b9020b201a7..f529b631d2374 100644 --- a/tests/tests_fabric/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_fabric/plugins/precision/test_bitsandbytes.py @@ -11,23 +11,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License +import platform import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE, BitsandbytesPrecision from lightning.fabric.utilities.init import _materialize_meta_tensors from lightning.fabric.utilities.load import _lazy_load - from tests_fabric.helpers.runif import RunIf @pytest.mark.skipif(_BITSANDBYTES_AVAILABLE, reason="bitsandbytes needs to be unavailable") +@pytest.mark.skipif(platform.system() == "Darwin", reason="Bitsandbytes is only supported on CUDA GPUs") # skip on Mac def test_bitsandbytes_plugin(monkeypatch): module = lightning.fabric.plugins.precision.bitsandbytes monkeypatch.setattr(module, "_BITSANDBYTES_AVAILABLE", lambda: True) @@ -95,6 +97,7 @@ def __init__(self): @RunIf(min_cuda_gpus=1, max_torch="2.4") +@pytest.mark.filterwarnings("ignore::FutureWarning") @pytest.mark.skipif(not _BITSANDBYTES_AVAILABLE, reason="bitsandbytes unavailable") @pytest.mark.parametrize( ("args", "expected"), diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed.py b/tests/tests_fabric/plugins/precision/test_deepspeed.py index 248f616646842..170f15afaa2ad 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.deepspeed import DeepSpeedPrecision from lightning.fabric.utilities.types import Steppable - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py index 27e44398fc095..e989534343b8c 100644 --- a/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py +++ b/tests/tests_fabric/plugins/precision/test_deepspeed_integration.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.fabric.connector import _Connector from lightning.fabric.plugins import DeepSpeedPrecision from 
lightning.fabric.strategies import DeepSpeedStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_double.py b/tests/tests_fabric/plugins/precision/test_double.py index 97d4a2303100e..4921e0f4e659b 100644 --- a/tests/tests_fabric/plugins/precision/test_double.py +++ b/tests/tests_fabric/plugins/precision/test_double.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch + from lightning.fabric.plugins.precision.double import DoublePrecision diff --git a/tests/tests_fabric/plugins/precision/test_double_integration.py b/tests/tests_fabric/plugins/precision/test_double_integration.py index 6701bc1a80e59..8f96f75ad1c67 100644 --- a/tests/tests_fabric/plugins/precision/test_double_integration.py +++ b/tests/tests_fabric/plugins/precision/test_double_integration.py @@ -15,8 +15,8 @@ import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_fsdp.py b/tests/tests_fabric/plugins/precision/test_fsdp.py index 6a4968736ea86..3b8d916e20c8f 100644 --- a/tests/tests_fabric/plugins/precision/test_fsdp.py +++ b/tests/tests_fabric/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins import FSDPPrecision from lightning.fabric.plugins.precision.utils import _DtypeContextManager - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/plugins/precision/test_half.py b/tests/tests_fabric/plugins/precision/test_half.py index 4037feebbd178..00d23df4ae5b6 100644 --- a/tests/tests_fabric/plugins/precision/test_half.py +++ b/tests/tests_fabric/plugins/precision/test_half.py @@ -13,6 +13,7 @@ # limitations under the License. 
import pytest import torch + from lightning.fabric.plugins.precision import HalfPrecision diff --git a/tests/tests_fabric/plugins/precision/test_transformer_engine.py b/tests/tests_fabric/plugins/precision/test_transformer_engine.py index c003715dead8a..033484aca9c90 100644 --- a/tests/tests_fabric/plugins/precision/test_transformer_engine.py +++ b/tests/tests_fabric/plugins/precision/test_transformer_engine.py @@ -14,10 +14,11 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.connector import _Connector from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision diff --git a/tests/tests_fabric/plugins/precision/test_utils.py b/tests/tests_fabric/plugins/precision/test_utils.py index 74899c86e9e2d..6e459c3fe6637 100644 --- a/tests/tests_fabric/plugins/precision/test_utils.py +++ b/tests/tests_fabric/plugins/precision/test_utils.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _ClassReplacementContextManager, _DtypeContextManager diff --git a/tests/tests_fabric/plugins/precision/test_xla.py b/tests/tests_fabric/plugins/precision/test_xla.py index 0cdc11b00b99a..cfdc32112a957 100644 --- a/tests/tests_fabric/plugins/precision/test_xla.py +++ b/tests/tests_fabric/plugins/precision/test_xla.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.fabric.plugins import XLAPrecision diff --git a/tests/tests_fabric/plugins/precision/test_xla_integration.py b/tests/tests_fabric/plugins/precision/test_xla_integration.py index 14a5cd1442e4a..75dede49e2fe0 100644 --- a/tests/tests_fabric/plugins/precision/test_xla_integration.py +++ b/tests/tests_fabric/plugins/precision/test_xla_integration.py @@ -17,9 +17,9 @@ import pytest import torch import torch.nn as nn + from lightning.fabric import Fabric from lightning.fabric.plugins import XLAPrecision - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py index 6c595fba7acab..5bb85e070f17d 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing.py @@ -16,8 +16,8 @@ import pytest import torch -from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher +from lightning.fabric.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py index c3bc7d3c2c6cd..6ae96b9bcafc6 100644 --- a/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py +++ b/tests/tests_fabric/strategies/launchers/test_multiprocessing_integration.py @@ -15,8 +15,8 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric +from lightning.fabric import Fabric from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py index a2d04e29bc0d6..70587ca2877a6 100644 --- a/tests/tests_fabric/strategies/launchers/test_subprocess_script.py +++ 
b/tests/tests_fabric/strategies/launchers/test_subprocess_script.py @@ -17,8 +17,9 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.strategies.launchers.subprocess_script import ( _HYDRA_AVAILABLE, _ChildProcessObserver, diff --git a/tests/tests_fabric/strategies/test_ddp.py b/tests/tests_fabric/strategies/test_ddp.py index b98d5f8226dc2..fa5c975228a5e 100644 --- a/tests/tests_fabric/strategies/test_ddp.py +++ b/tests/tests_fabric/strategies/test_ddp.py @@ -19,12 +19,12 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy from lightning.fabric.strategies.ddp import _DDPBackwardSyncControl -from torch.nn.parallel import DistributedDataParallel - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_ddp_integration.py b/tests/tests_fabric/strategies/test_ddp_integration.py index a7ed09b00b09e..3ed76211e5d6d 100644 --- a/tests/tests_fabric/strategies/test_ddp_integration.py +++ b/tests/tests_fabric/strategies/test_ddp_integration.py @@ -18,11 +18,12 @@ import pytest import torch -from lightning.fabric import Fabric from lightning_utilities.core.imports import RequirementCache from torch._dynamo import OptimizedModule from torch.nn.parallel.distributed import DistributedDataParallel +from lightning.fabric import Fabric +from lightning.fabric.utilities.imports import _TORCH_LESS_EQUAL_2_6 from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients from tests_fabric.test_fabric import BoringModel @@ -84,7 +85,9 @@ def test_reapply_compile(): fabric.launch() model = BoringModel() - compile_kwargs = {"mode": "reduce-overhead"} + # currently (PyTorch 2.6) using reduce-overhead here causes a RuntimeError: + # Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. 
+ compile_kwargs = {"mode": "reduce-overhead"} if _TORCH_LESS_EQUAL_2_6 else {} compiled_model = torch.compile(model, **compile_kwargs) torch.compile.reset_mock() diff --git a/tests/tests_fabric/strategies/test_deepspeed.py b/tests/tests_fabric/strategies/test_deepspeed.py index 4ee87b265b086..032ee63cd4721 100644 --- a/tests/tests_fabric/strategies/test_deepspeed.py +++ b/tests/tests_fabric/strategies/test_deepspeed.py @@ -18,14 +18,14 @@ import pytest import torch -from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator -from lightning.fabric.strategies import DeepSpeedStrategy from torch.optim import Optimizer +from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator +from lightning.fabric.strategies import DeepSpeedStrategy from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -36,7 +36,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 4811599ed05ab..5970b673cee5f 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -20,11 +20,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader + from lightning.fabric import Fabric from lightning.fabric.plugins import DeepSpeedPrecision from lightning.fabric.strategies import DeepSpeedStrategy -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/strategies/test_dp.py b/tests/tests_fabric/strategies/test_dp.py index e50abb1882870..ff470c646f4b3 100644 --- a/tests/tests_fabric/strategies/test_dp.py +++ b/tests/tests_fabric/strategies/test_dp.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import DataParallelStrategy - from tests_fabric.helpers.runif import RunIf from tests_fabric.strategies.test_single_device import _run_test_clip_gradients diff --git a/tests/tests_fabric/strategies/test_fsdp.py b/tests/tests_fabric/strategies/test_fsdp.py index cb6542cdb6243..d5f82752a9176 100644 --- a/tests/tests_fabric/strategies/test_fsdp.py +++ b/tests/tests_fabric/strategies/test_fsdp.py @@ -19,6 +19,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.optim import Adam + from lightning.fabric.plugins import HalfPrecision from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import FSDPStrategy @@ -28,9 +32,6 @@ _is_sharded_checkpoint, ) from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy -from torch.optim import Adam def test_custom_mixed_precision(): diff --git 
a/tests/tests_fabric/strategies/test_fsdp_integration.py b/tests/tests_fabric/strategies/test_fsdp_integration.py index 0697c3043d496..576a0df38b966 100644 --- a/tests/tests_fabric/strategies/test_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_fsdp_integration.py @@ -20,17 +20,18 @@ import pytest import torch import torch.nn as nn -from lightning.fabric import Fabric -from lightning.fabric.plugins import FSDPPrecision -from lightning.fabric.strategies import FSDPStrategy -from lightning.fabric.utilities.load import _load_distributed_checkpoint -from lightning.fabric.wrappers import _FabricOptimizer from torch._dynamo import OptimizedModule from torch.distributed.fsdp import FlatParameter, FullyShardedDataParallel, OptimStateKeyType from torch.distributed.fsdp.wrap import always_wrap_policy, wrap from torch.nn import Parameter from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.plugins import FSDPPrecision +from lightning.fabric.strategies import FSDPStrategy +from lightning.fabric.utilities.imports import _TORCH_LESS_EQUAL_2_6 +from lightning.fabric.utilities.load import _load_distributed_checkpoint +from lightning.fabric.wrappers import _FabricOptimizer from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf from tests_fabric.test_fabric import BoringModel @@ -411,7 +412,9 @@ def test_reapply_compile(): fabric.launch() model = BoringModel() - compile_kwargs = {"mode": "reduce-overhead"} + # currently (PyTorch 2.6) using reduce-overhead here causes a RuntimeError: + # Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. + compile_kwargs = {"mode": "reduce-overhead"} if _TORCH_LESS_EQUAL_2_6 else {} compiled_model = torch.compile(model, **compile_kwargs) torch.compile.reset_mock() diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py index 1f8b5b783b73e..78622adf66fa6 100644 --- a/tests/tests_fabric/strategies/test_model_parallel.py +++ b/tests/tests_fabric/strategies/test_model_parallel.py @@ -19,12 +19,12 @@ import pytest import torch import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import ModelParallelStrategy from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.strategies.model_parallel import _ParallelBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index b04a29b691529..bddfadd9a2c54 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -20,16 +20,16 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader, DistributedSampler + from lightning.fabric import Fabric from lightning.fabric.strategies.model_parallel import ModelParallelStrategy, _load_raw_module_state from lightning.fabric.utilities.load import _load_distributed_checkpoint -from torch.utils.data import DataLoader, DistributedSampler - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): 
diff --git a/tests/tests_fabric/strategies/test_single_device.py b/tests/tests_fabric/strategies/test_single_device.py index 95ed9787f40a2..fff7175909222 100644 --- a/tests/tests_fabric/strategies/test_single_device.py +++ b/tests/tests_fabric/strategies/test_single_device.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.strategies import SingleDeviceStrategy - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_strategy.py b/tests/tests_fabric/strategies/test_strategy.py index a7a1dba87cb97..37ccbea5e6c95 100644 --- a/tests/tests_fabric/strategies/test_strategy.py +++ b/tests/tests_fabric/strategies/test_strategy.py @@ -16,10 +16,10 @@ import pytest import torch + from lightning.fabric.plugins import DoublePrecision, HalfPrecision, Precision from lightning.fabric.strategies import SingleDeviceStrategy from lightning.fabric.utilities.types import _Stateful - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla.py b/tests/tests_fabric/strategies/test_xla.py index f711eb3470b45..a260b3f231e1d 100644 --- a/tests/tests_fabric/strategies/test_xla.py +++ b/tests/tests_fabric/strategies/test_xla.py @@ -18,13 +18,13 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.fabric.accelerators.xla import _XLA_GREATER_EQUAL_2_1, XLAAccelerator from lightning.fabric.strategies import XLAStrategy from lightning.fabric.strategies.launchers.xla import _XLALauncher from lightning.fabric.utilities.distributed import ReduceOp from lightning.fabric.utilities.seed import seed_everything -from torch.utils.data import DataLoader - from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp.py b/tests/tests_fabric/strategies/test_xla_fsdp.py index 879a55cf77f34..c2634283ad110 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp.py @@ -18,12 +18,12 @@ import pytest import torch.nn import torch.nn as nn +from torch.optim import Adam + from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.plugins import XLAPrecision from lightning.fabric.strategies import XLAFSDPStrategy from lightning.fabric.strategies.xla_fsdp import _activation_checkpointing_auto_wrapper, _XLAFSDPBackwardSyncControl -from torch.optim import Adam - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py index 20c2ef042272e..b77803744b6c4 100644 --- a/tests/tests_fabric/strategies/test_xla_fsdp_integration.py +++ b/tests/tests_fabric/strategies/test_xla_fsdp_integration.py @@ -18,10 +18,10 @@ import pytest import torch -from lightning.fabric import Fabric -from lightning.fabric.strategies import XLAFSDPStrategy from torch.utils.data import DataLoader +from lightning.fabric import Fabric +from lightning.fabric.strategies import XLAFSDPStrategy from tests_fabric.helpers.datasets import RandomDataset from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_cli.py b/tests/tests_fabric/test_cli.py index a57f413ff6081..944584114184b 100644 --- a/tests/tests_fabric/test_cli.py +++ b/tests/tests_fabric/test_cli.py @@ -20,12 +20,12 @@ from unittest.mock import Mock import pytest -from lightning.fabric.cli import _consolidate, _get_supported_strategies, _run +from 
lightning.fabric.cli import _consolidate, _get_supported_strategies, _run from tests_fabric.helpers.runif import RunIf -@pytest.fixture() +@pytest.fixture def fake_script(tmp_path): script = tmp_path / "script.py" script.touch() @@ -179,19 +179,6 @@ def test_run_through_fabric_entry_point(): assert message in result.stdout or message in result.stderr -@pytest.mark.skipif("lightning.fabric" == "lightning_fabric", reason="standalone package") -def test_run_through_lightning_entry_point(): - result = subprocess.run("lightning run model --help", capture_output=True, text=True, shell=True) - - deprecation_message = ( - "`lightning run model` is deprecated and will be removed in future versions. " - "Please call `fabric run` instead" - ) - message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]" - assert deprecation_message in result.stdout - assert message in result.stdout or message in result.stderr - - @mock.patch("lightning.fabric.cli._process_cli_args") @mock.patch("lightning.fabric.cli._load_distributed_checkpoint") @mock.patch("lightning.fabric.cli.torch.save") diff --git a/tests/tests_fabric/test_connector.py b/tests/tests_fabric/test_connector.py index 8a6e9206b3df5..9bb9fa1d7d145 100644 --- a/tests/tests_fabric/test_connector.py +++ b/tests/tests_fabric/test_connector.py @@ -19,10 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric import Fabric from lightning.fabric.accelerators import XLAAccelerator from lightning.fabric.accelerators.accelerator import Accelerator @@ -63,8 +65,6 @@ from lightning.fabric.strategies.ddp import _DDP_FORK_ALIASES from lightning.fabric.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning.fabric.utilities.imports import _IS_WINDOWS -from lightning_utilities.test.warning import no_warning_call - from tests_fabric.conftest import mock_tpu_available from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 7bb6b29eceaf2..ee002b5d8061c 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -17,11 +17,15 @@ from unittest import mock from unittest.mock import ANY, MagicMock, Mock, PropertyMock, call -import lightning.fabric import pytest import torch import torch.distributed import torch.nn.functional +from lightning_utilities.test.warning import no_warning_call +from torch import nn +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset + +import lightning.fabric from lightning.fabric.fabric import Fabric from lightning.fabric.strategies import ( DataParallelStrategy, @@ -37,10 +41,6 @@ from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer -from lightning_utilities.test.warning import no_warning_call -from torch import nn -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler, TensorDataset - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/test_wrappers.py b/tests/tests_fabric/test_wrappers.py index 363da022d285b..de8ae208e4dc4 100644 --- a/tests/tests_fabric/test_wrappers.py +++ 
b/tests/tests_fabric/test_wrappers.py @@ -16,6 +16,10 @@ import pytest import torch +from torch._dynamo import OptimizedModule +from torch.utils.data import BatchSampler, DistributedSampler +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.fabric import Fabric from lightning.fabric.plugins import Precision from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin @@ -28,10 +32,6 @@ _unwrap_objects, is_wrapped, ) -from torch._dynamo import OptimizedModule -from torch.utils.data import BatchSampler, DistributedSampler -from torch.utils.data.dataloader import DataLoader - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_apply_func.py b/tests/tests_fabric/utilities/test_apply_func.py index 9e137561aa525..055fa89101c96 100644 --- a/tests/tests_fabric/utilities/test_apply_func.py +++ b/tests/tests_fabric/utilities/test_apply_func.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch -from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device from torch import Tensor +from lightning.fabric.utilities.apply_func import convert_tensors_to_scalars, move_data_to_device + @pytest.mark.parametrize("should_return", [False, True]) def test_wrongly_implemented_transferable_data_type(should_return): diff --git a/tests/tests_fabric/utilities/test_cloud_io.py b/tests/tests_fabric/utilities/test_cloud_io.py index d502199da1493..e1333ddff87f3 100644 --- a/tests/tests_fabric/utilities/test_cloud_io.py +++ b/tests/tests_fabric/utilities/test_cloud_io.py @@ -16,6 +16,7 @@ import fsspec from fsspec.implementations.local import LocalFileSystem from fsspec.spec import AbstractFileSystem + from lightning.fabric.utilities.cloud_io import _is_dir, get_filesystem diff --git a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py index 2584aab8bdc2e..78690a9870982 100644 --- a/tests/tests_fabric/utilities/test_consolidate_checkpoint.py +++ b/tests/tests_fabric/utilities/test_consolidate_checkpoint.py @@ -16,8 +16,9 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest + +import lightning.fabric from lightning.fabric.utilities.consolidate_checkpoint import _parse_cli_args, _process_cli_args from lightning.fabric.utilities.load import _METADATA_FILENAME diff --git a/tests/tests_fabric/utilities/test_data.py b/tests/tests_fabric/utilities/test_data.py index 656b9cac3d77e..faff6e182a06f 100644 --- a/tests/tests_fabric/utilities/test_data.py +++ b/tests/tests_fabric/utilities/test_data.py @@ -4,10 +4,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import numpy as np import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, RandomSampler + +import lightning.fabric from lightning.fabric.utilities.data import ( AttributeDict, _get_dataloader_init_args_and_kwargs, @@ -21,10 +25,6 @@ suggested_max_num_workers, ) from lightning.fabric.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, RandomSampler - from tests_fabric.helpers.datasets import RandomDataset, RandomIterableDataset diff --git a/tests/tests_fabric/utilities/test_device_dtype_mixin.py 
b/tests/tests_fabric/utilities/test_device_dtype_mixin.py index 9958e48c624ee..1261ca5e0accb 100644 --- a/tests/tests_fabric/utilities/test_device_dtype_mixin.py +++ b/tests/tests_fabric/utilities/test_device_dtype_mixin.py @@ -1,8 +1,8 @@ import pytest import torch -from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from torch import nn as nn +from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_device_parser.py b/tests/tests_fabric/utilities/test_device_parser.py index 9b5a09e370860..f9f9e49cf6afe 100644 --- a/tests/tests_fabric/utilities/test_device_parser.py +++ b/tests/tests_fabric/utilities/test_device_parser.py @@ -14,6 +14,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities import device_parser from lightning.fabric.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_fabric/utilities/test_distributed.py b/tests/tests_fabric/utilities/test_distributed.py index f5a78a1529a52..9282f00f1ffb6 100644 --- a/tests/tests_fabric/utilities/test_distributed.py +++ b/tests/tests_fabric/utilities/test_distributed.py @@ -5,9 +5,11 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest import torch +from lightning_utilities.core.imports import RequirementCache + +import lightning.fabric from lightning.fabric.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies import DDPStrategy, SingleDeviceStrategy @@ -23,8 +25,6 @@ _sync_ddp, is_shared_filesystem, ) -from lightning_utilities.core.imports import RequirementCache - from tests_fabric.helpers.runif import RunIf @@ -105,6 +105,8 @@ def _test_all_reduce(strategy): assert result is tensor # inplace +# flaky with "process 0 terminated with signal SIGABRT" (GLOO) +@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") @RunIf(skip_windows=True) @pytest.mark.parametrize( "process", diff --git a/tests/tests_fabric/utilities/test_init.py b/tests/tests_fabric/utilities/test_init.py index dd08dec020669..69758c5a3e17e 100644 --- a/tests/tests_fabric/utilities/test_init.py +++ b/tests/tests_fabric/utilities/test_init.py @@ -16,12 +16,12 @@ import pytest import torch.nn + from lightning.fabric.utilities.init import ( _EmptyInit, _has_meta_device_parameters_or_buffers, _materialize_meta_tensors, ) - from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_load.py b/tests/tests_fabric/utilities/test_load.py index 39d257f8b685b..ed38aa2459af7 100644 --- a/tests/tests_fabric/utilities/test_load.py +++ b/tests/tests_fabric/utilities/test_load.py @@ -14,6 +14,7 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.utilities.load import ( _lazy_load, _materialize_tensors, @@ -55,7 +56,7 @@ def test_lazy_load_tensor(tmp_path): for t0, t1 in zip(expected.values(), loaded.values()): assert isinstance(t1, _NotYetLoadedTensor) t1_materialized = _materialize_tensors(t1) - assert type(t0) == type(t1_materialized) + assert type(t0) == type(t1_materialized) # noqa: E721 assert torch.equal(t0, t1_materialized) @@ -91,7 +92,7 @@ def test_materialize_tensors(tmp_path): loaded = _lazy_load(tmp_path / "tensor.pt") materialized = _materialize_tensors(loaded) assert torch.equal(materialized, tensor) - assert type(tensor) == 
type(materialized) + assert type(tensor) == type(materialized) # noqa: E721 # Collection of tensors collection = { diff --git a/tests/tests_fabric/utilities/test_logger.py b/tests/tests_fabric/utilities/test_logger.py index 0f6500cb42be1..4cb55c4cb68d8 100644 --- a/tests/tests_fabric/utilities/test_logger.py +++ b/tests/tests_fabric/utilities/test_logger.py @@ -17,6 +17,7 @@ import numpy as np import torch + from lightning.fabric.utilities.logger import ( _add_prefix, _convert_json_serializable, @@ -63,6 +64,12 @@ def test_flatten_dict(): assert params["c/8"] == "foo" assert params["c/9/10"] == "bar" + # Test list of nested dicts flattening + params = {"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]} + params = _flatten_dict(params) + + assert params == {"dl/0/a": 1, "dl/0/c": 3, "dl/1/b": 2, "dl/1/d": 5, "l": [1, 2, 3, 4]} + # Test flattening of argparse Namespace params = Namespace(a=1, b=2) wrapping_dict = {"params": params} diff --git a/tests/tests_fabric/utilities/test_optimizer.py b/tests/tests_fabric/utilities/test_optimizer.py index 83c7ed44120b9..d96c58049ed3a 100644 --- a/tests/tests_fabric/utilities/test_optimizer.py +++ b/tests/tests_fabric/utilities/test_optimizer.py @@ -2,9 +2,9 @@ import pytest import torch -from lightning.fabric.utilities.optimizer import _optimizer_to_device from torch import Tensor +from lightning.fabric.utilities.optimizer import _optimizer_to_device from tests_fabric.helpers.runif import RunIf diff --git a/tests/tests_fabric/utilities/test_rank_zero.py b/tests/tests_fabric/utilities/test_rank_zero.py index 0c1b39fe9d9b8..b8ea54f90625d 100644 --- a/tests/tests_fabric/utilities/test_rank_zero.py +++ b/tests/tests_fabric/utilities/test_rank_zero.py @@ -3,6 +3,7 @@ from unittest import mock import pytest + from lightning.fabric.utilities.rank_zero import _get_rank diff --git a/tests/tests_fabric/utilities/test_seed.py b/tests/tests_fabric/utilities/test_seed.py index be2ecba3294b1..4a948a5f98736 100644 --- a/tests/tests_fabric/utilities/test_seed.py +++ b/tests/tests_fabric/utilities/test_seed.py @@ -1,11 +1,13 @@ import os import random +import warnings from unittest import mock from unittest.mock import Mock import numpy import pytest import torch + from lightning.fabric.utilities.seed import ( _collect_rng_states, _set_rng_states, @@ -29,9 +31,9 @@ def test_seed_stays_same_with_multiple_seed_everything_calls(): seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") - with pytest.warns(None) as record: + with warnings.catch_warnings(): + warnings.simplefilter("error") seed_everything() - assert not record # does not warn seed = os.environ.get("PL_GLOBAL_SEED") assert initial_seed == seed diff --git a/tests/tests_fabric/utilities/test_spike.py b/tests/tests_fabric/utilities/test_spike.py index 9739540af7f18..6054bf224d3df 100644 --- a/tests/tests_fabric/utilities/test_spike.py +++ b/tests/tests_fabric/utilities/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, SpikeDetection, TrainingSpikeException diff --git a/tests/tests_fabric/utilities/test_throughput.py b/tests/tests_fabric/utilities/test_throughput.py index d410d0766d97b..00dafbb72cb8f 100644 --- a/tests/tests_fabric/utilities/test_throughput.py +++ b/tests/tests_fabric/utilities/test_throughput.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric import Fabric from lightning.fabric.plugins import Precision from 
lightning.fabric.utilities.throughput import ( @@ -12,7 +13,6 @@ get_available_flops, measure_flops, ) - from tests_fabric.test_fabric import BoringModel diff --git a/tests/tests_fabric/utilities/test_warnings.py b/tests/tests_fabric/utilities/test_warnings.py index b7989d85b5932..bfccaaa8481d9 100644 --- a/tests/tests_fabric/utilities/test_warnings.py +++ b/tests/tests_fabric/utilities/test_warnings.py @@ -27,16 +27,17 @@ from pathlib import Path from unittest import mock -import lightning.fabric import pytest +from lightning_utilities.core.rank_zero import WarningCache, _warn +from lightning_utilities.test.warning import no_warning_call + +import lightning.fabric from lightning.fabric.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn from lightning.fabric.utilities.warnings import ( PossibleUserWarning, _is_path_in_lightning, disable_possible_user_warnings, ) -from lightning_utilities.core.rank_zero import WarningCache, _warn -from lightning_utilities.test.warning import no_warning_call def line_number(): diff --git a/tests/tests_pytorch/__init__.py b/tests/tests_pytorch/__init__.py index a43ffae6a83b4..efbfa8bb14c76 100644 --- a/tests/tests_pytorch/__init__.py +++ b/tests/tests_pytorch/__init__.py @@ -25,7 +25,7 @@ # todo: this setting `PYTHONPATH` may not be used by other evns like Conda for import packages if str(_PROJECT_ROOT) not in os.getenv("PYTHONPATH", ""): splitter = ":" if os.environ.get("PYTHONPATH", "") else "" - os.environ["PYTHONPATH"] = f'{_PROJECT_ROOT}{splitter}{os.environ.get("PYTHONPATH", "")}' + os.environ["PYTHONPATH"] = f"{_PROJECT_ROOT}{splitter}{os.environ.get('PYTHONPATH', '')}" # Ignore cleanup warnings from pytest (rarely happens due to a race condition when executing pytest in parallel) warnings.filterwarnings("ignore", category=pytest.PytestWarning, message=r".*\(rm_rf\) error removing.*") diff --git a/tests/tests_pytorch/accelerators/test_common.py b/tests/tests_pytorch/accelerators/test_common.py index 6967bffd9ffa2..42fd66a247d1d 100644 --- a/tests/tests_pytorch/accelerators/test_common.py +++ b/tests/tests_pytorch/accelerators/test_common.py @@ -14,6 +14,7 @@ from typing import Any import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import Accelerator from lightning.pytorch.strategies import DDPStrategy diff --git a/tests/tests_pytorch/accelerators/test_cpu.py b/tests/tests_pytorch/accelerators/test_cpu.py index 844556064621d..ec2aabf559dc7 100644 --- a/tests/tests_pytorch/accelerators/test_cpu.py +++ b/tests/tests_pytorch/accelerators/test_cpu.py @@ -3,16 +3,16 @@ from typing import Any, Union from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.precision import Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_gpu.py b/tests/tests_pytorch/accelerators/test_gpu.py index e175f8aa7647c..5a71887e17eec 100644 --- a/tests/tests_pytorch/accelerators/test_gpu.py +++ b/tests/tests_pytorch/accelerators/test_gpu.py @@ -15,11 +15,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CUDAAccelerator from 
lightning.pytorch.accelerators.cuda import get_nvidia_gpu_stats from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_mps.py b/tests/tests_pytorch/accelerators/test_mps.py index 73d0785f9592d..c0a28840f0ef6 100644 --- a/tests/tests_pytorch/accelerators/test_mps.py +++ b/tests/tests_pytorch/accelerators/test_mps.py @@ -16,11 +16,11 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer from lightning.pytorch.accelerators import MPSAccelerator from lightning.pytorch.demos.boring_classes import BoringModel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/accelerators/test_xla.py b/tests/tests_pytorch/accelerators/test_xla.py index 48b346006786f..83dace719371d 100644 --- a/tests/tests_pytorch/accelerators/test_xla.py +++ b/tests/tests_pytorch/accelerators/test_xla.py @@ -17,9 +17,12 @@ from unittest import mock from unittest.mock import MagicMock, call, patch -import lightning.fabric import pytest import torch +from torch import nn +from torch.utils.data import DataLoader + +import lightning.fabric from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.accelerators import CPUAccelerator, XLAAccelerator @@ -27,9 +30,6 @@ from lightning.pytorch.plugins import Precision, XLACheckpointIO, XLAPrecision from lightning.pytorch.strategies import DDPStrategy, XLAStrategy from lightning.pytorch.utilities import find_shared_parameters -from torch import nn -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.trainer.connectors.test_accelerator_connector import DeviceMock from tests_pytorch.trainer.optimization.test_manual_optimization import assert_emtpy_grad @@ -148,7 +148,7 @@ def on_train_start(self): def on_train_end(self): # this might fail if run in an environment with too many ranks, as the total - # length of the dataloader will be distrbuted among them and then each rank might not do 3 steps + # length of the dataloader will be distributed among them and then each rank might not do 3 steps assert self.called["training_step"] == 3 assert self.called["on_train_batch_start"] == 3 assert self.called["on_train_batch_end"] == 3 diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 89c1effe839a8..430fb9842cddc 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -17,14 +17,15 @@ from unittest.mock import DEFAULT, Mock import pytest +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ProgressBar, RichProgressBar from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data import DataLoader @RunIf(rich=True) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py 
b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index d5187d5a1e325..d93bf1cf60e9c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -22,6 +22,9 @@ import pytest import torch +from tests_pytorch.helpers.runif import RunIf +from torch.utils.data.dataloader import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint, ProgressBar, TQDMProgressBar from lightning.pytorch.callbacks.progress.tqdm_progress import Tqdm @@ -30,8 +33,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.logger import DummyLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from tests_pytorch.helpers.runif import RunIf -from torch.utils.data.dataloader import DataLoader class MockTqdm(Tqdm): diff --git a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py index 8f32c756881da..366a924a5867c 100644 --- a/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py +++ b/tests/tests_pytorch/callbacks/test_callback_hook_outputs.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_callbacks.py b/tests/tests_pytorch/callbacks/test_callbacks.py index 38b9428526505..53ea109b6ddf3 100644 --- a/tests/tests_pytorch/callbacks/test_callbacks.py +++ b/tests/tests_pytorch/callbacks/test_callbacks.py @@ -16,10 +16,11 @@ from unittest.mock import Mock import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call def test_callbacks_configured_in_model(tmp_path): diff --git a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py index f1d999f1df61a..290a0921cb06d 100644 --- a/tests/tests_pytorch/callbacks/test_device_stats_monitor.py +++ b/tests/tests_pytorch/callbacks/test_device_stats_monitor.py @@ -20,6 +20,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators.cpu import _CPU_PERCENT, _CPU_SWAP_PERCENT, _CPU_VM_PERCENT, get_cpu_stats from lightning.pytorch.callbacks import DeviceStatsMonitor @@ -28,7 +29,6 @@ from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index a3d56bb0135c3..9a87b3daaad6e 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -22,11 +22,11 @@ import cloudpickle import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import 
MisconfigurationException - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -61,8 +61,8 @@ def on_train_epoch_end(self, trainer, pl_module): def test_resume_early_stopping_from_checkpoint(tmp_path): """Prevent regressions to bugs: - https://github.com/Lightning-AI/lightning/issues/1464 - https://github.com/Lightning-AI/lightning/issues/1463 + https://github.com/Lightning-AI/pytorch-lightning/issues/1464 + https://github.com/Lightning-AI/pytorch-lightning/issues/1463 """ seed_everything(42) diff --git a/tests/tests_pytorch/callbacks/test_finetuning_callback.py b/tests/tests_pytorch/callbacks/test_finetuning_callback.py index 0c09ae5d5042a..07343c1ecc12a 100644 --- a/tests/tests_pytorch/callbacks/test_finetuning_callback.py +++ b/tests/tests_pytorch/callbacks/test_finetuning_callback.py @@ -15,14 +15,14 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch import nn from torch.optim import SGD, Optimizer from torch.utils.data import DataLoader +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.callbacks import BackboneFinetuning, BaseFinetuning, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py index cc584b2da624c..9ad9759ff248c 100644 --- a/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py +++ b/tests/tests_pytorch/callbacks/test_gradient_accumulation_scheduler.py @@ -15,6 +15,7 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import GradientAccumulationScheduler from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/callbacks/test_lambda_function.py b/tests/tests_pytorch/callbacks/test_lambda_function.py index 40d694bb35ebc..2b5e025653940 100644 --- a/tests/tests_pytorch/callbacks/test_lambda_function.py +++ b/tests/tests_pytorch/callbacks/test_lambda_function.py @@ -14,10 +14,10 @@ from functools import partial import pytest + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, LambdaCallback from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/callbacks/test_lr_monitor.py b/tests/tests_pytorch/callbacks/test_lr_monitor.py index 4aedb4f23fa14..66ce47f0e7ad4 100644 --- a/tests/tests_pytorch/callbacks/test_lr_monitor.py +++ b/tests/tests_pytorch/callbacks/test_lr_monitor.py @@ -13,6 +13,8 @@ # limitations under the License. 
import pytest import torch +from torch import optim + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor from lightning.pytorch.callbacks.callback import Callback @@ -20,8 +22,6 @@ from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import optim - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -44,9 +44,9 @@ def test_lr_monitor_single_lr(tmp_path): assert lr_monitor.lrs, "No learning rates logged" assert all(v is None for v in lr_monitor.last_momentum_values.values()), "Momentum should not be logged by default" - assert all( - v is None for v in lr_monitor.last_weight_decay_values.values() - ), "Weight decay should not be logged by default" + assert all(v is None for v in lr_monitor.last_weight_decay_values.values()), ( + "Weight decay should not be logged by default" + ) assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs) assert list(lr_monitor.lrs) == ["lr-SGD"] @@ -87,9 +87,9 @@ def configure_optimizers(self): assert len(lr_monitor.last_momentum_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-momentum" for k in lr_monitor.last_momentum_values) - assert all( - v is not None for v in lr_monitor.last_weight_decay_values.values() - ), "Expected weight decay to be logged" + assert all(v is not None for v in lr_monitor.last_weight_decay_values.values()), ( + "Expected weight decay to be logged" + ) assert len(lr_monitor.last_weight_decay_values) == len(trainer.lr_scheduler_configs) assert all(k == f"lr-{opt}-weight_decay" for k in lr_monitor.last_weight_decay_values) @@ -548,10 +548,10 @@ def finetune_function(self, pl_module, epoch: int, optimizer): """Called when the epoch begins.""" if epoch == 1 and isinstance(optimizer, torch.optim.SGD): self.unfreeze_and_add_param_group(pl_module.backbone[0], optimizer, lr=0.1) - if epoch == 2 and isinstance(optimizer, torch.optim.Adam): + if epoch == 2 and type(optimizer) is torch.optim.Adam: self.unfreeze_and_add_param_group(pl_module.layer, optimizer, lr=0.1) - if epoch == 3 and isinstance(optimizer, torch.optim.Adam): + if epoch == 3 and type(optimizer) is torch.optim.Adam: assert len(optimizer.param_groups) == 2 self.unfreeze_and_add_param_group(pl_module.backbone[1], optimizer, lr=0.1) assert len(optimizer.param_groups) == 3 diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index 02604f5a195fe..7249d5739686b 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -14,11 +14,12 @@ from unittest.mock import ANY, Mock, call import pytest +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import BasePredictionWriter from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader class DummyPredictionWriter(BasePredictionWriter): diff --git a/tests/tests_pytorch/callbacks/test_pruning.py b/tests/tests_pytorch/callbacks/test_pruning.py index f3ec7e2ccc029..d70ab68b78b32 100644 --- a/tests/tests_pytorch/callbacks/test_pruning.py +++ 
b/tests/tests_pytorch/callbacks/test_pruning.py @@ -19,13 +19,13 @@ import pytest import torch import torch.nn.utils.prune as pytorch_prune +from torch import nn +from torch.nn import Sequential + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint, ModelPruning from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.nn import Sequential - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_rich_model_summary.py b/tests/tests_pytorch/callbacks/test_rich_model_summary.py index 73709fd80a833..7534c23d5679c 100644 --- a/tests/tests_pytorch/callbacks/test_rich_model_summary.py +++ b/tests/tests_pytorch/callbacks/test_rich_model_summary.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import RichModelSummary, RichProgressBar from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.model_summary import summarize - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/callbacks/test_spike.py b/tests/tests_pytorch/callbacks/test_spike.py index 5634feaf221cd..692a28dcc38c4 100644 --- a/tests/tests_pytorch/callbacks/test_spike.py +++ b/tests/tests_pytorch/callbacks/test_spike.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.spike import _TORCHMETRICS_GREATER_EQUAL_1_0_0, TrainingSpikeException from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks.spike import SpikeDetection diff --git a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py index 8d3a1800e1fa2..e9e11b6dbb466 100644 --- a/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py +++ b/tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py @@ -20,17 +20,17 @@ import pytest import torch +from torch import nn +from torch.optim.lr_scheduler import LambdaLR +from torch.optim.swa_utils import SWALR +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import StochasticWeightAveraging from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.strategies import Strategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim.lr_scheduler import LambdaLR -from torch.optim.swa_utils import SWALR -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf @@ -347,7 +347,7 @@ def test_swa_resume_training_from_checkpoint(tmp_path, crash_on_epoch): @pytest.mark.parametrize("crash_on_epoch", [1, 3]) def test_swa_resume_training_from_checkpoint_custom_scheduler(tmp_path, crash_on_epoch): - # Reproduces the bug reported in https://github.com/Lightning-AI/lightning/issues/11665 + # Reproduces the bug reported in https://github.com/Lightning-AI/pytorch-lightning/issues/11665 model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch) resume_model = CustomSchedulerModel() _swa_resume_training_from_checkpoint(tmp_path, model, resume_model) diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 
4867134a85642..9f77e4371e69e 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.fabric.utilities.throughput import measure_flops from lightning.pytorch import Trainer from lightning.pytorch.callbacks.throughput_monitor import ThroughputMonitor diff --git a/tests/tests_pytorch/callbacks/test_timer.py b/tests/tests_pytorch/callbacks/test_timer.py index e6359a2e9a5e1..e91170f2096d3 100644 --- a/tests/tests_pytorch/callbacks/test_timer.py +++ b/tests/tests_pytorch/callbacks/test_timer.py @@ -17,12 +17,12 @@ from unittest.mock import Mock, patch import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py index ff8f3c95e43c5..2e998c42ed2b7 100644 --- a/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py +++ b/tests/tests_pytorch/checkpointing/test_checkpoint_callback_frequency.py @@ -16,9 +16,9 @@ import pytest import torch + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py index be754d3911ade..006a123356c98 100644 --- a/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py +++ b/tests/tests_pytorch/checkpointing/test_legacy_checkpoints.py @@ -16,11 +16,11 @@ import sys from unittest.mock import patch -import lightning.pytorch as pl import pytest import torch -from lightning.pytorch import Callback, Trainer +import lightning.pytorch as pl +from lightning.pytorch import Callback, Trainer from tests_pytorch import _PATH_LEGACY from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -75,6 +75,7 @@ def test_legacy_ckpt_threading(pl_version: str): def load_model(): import torch + from lightning.pytorch.utilities.migration import pl_legacy_patch with pl_legacy_patch(): diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index d43f07179e7bb..1907a5fb35799 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -25,20 +25,21 @@ from unittest.mock import Mock, call, patch import cloudpickle -import lightning.pytorch as pl import pytest import torch import yaml from jsonargparse import ArgumentParser +from torch import optim +from torch.utils.data.dataloader import DataLoader + +import lightning.pytorch as pl from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities.exceptions 
import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from torch import optim - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: @@ -1624,3 +1625,44 @@ def test_save_last_cli(val, expected): parser.add_argument("--a", type=annot) args = parser.parse_args(["--a", val]) assert args.a == expected + + +def test_load_with_inf_data_loader(tmp_path): + """Test loading from a checkpoint with a dataloader that does not have a length.""" + # Test for https://github.com/Lightning-AI/pytorch-lightning/issues/20565 + dataset = RandomIterableDataset(size=32, count=10) + + class ModelWithIterableDataset(BoringModel): + def train_dataloader(self) -> DataLoader: + return DataLoader(dataset) + + def val_dataloader(self) -> DataLoader: + return DataLoader(dataset) + + model = ModelWithIterableDataset() + with pytest.raises(TypeError): + len(model.train_dataloader()) + + trainer_kwargs = { + "default_root_dir": tmp_path, + "max_epochs": 2, + "limit_train_batches": 2, + "limit_val_batches": None, + "check_val_every_n_epoch": 1, + "enable_model_summary": False, + "logger": False, + } + mc_kwargs = { + "save_last": True, + "every_n_train_steps": 1, + } + trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) + trainer.fit(model) + + checkpoint_path = tmp_path / "checkpoints" / "epoch=1-step=4.ckpt" + assert checkpoint_path.name in os.listdir(tmp_path / "checkpoints") + + # Resume from checkpoint and run for more epochs + trainer_kwargs["max_epochs"] = 4 + trainer = Trainer(**trainer_kwargs, callbacks=ModelCheckpoint(**mc_kwargs)) + trainer.fit(model, ckpt_path=checkpoint_path) diff --git a/tests/tests_pytorch/checkpointing/test_torch_saving.py b/tests/tests_pytorch/checkpointing/test_torch_saving.py index 4422a4063c719..e49c45bf54ade 100644 --- a/tests/tests_pytorch/checkpointing/test_torch_saving.py +++ b/tests/tests_pytorch/checkpointing/test_torch_saving.py @@ -13,9 +13,9 @@ # limitations under the License. 
import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py index c07400eaf8446..b0f1528d42d28 100644 --- a/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_trainer_checkpoint.py @@ -15,9 +15,10 @@ from unittest import mock from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.fabric.plugins import TorchCheckpointIO, XLACheckpointIO from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint diff --git a/tests/tests_pytorch/conftest.py b/tests/tests_pytorch/conftest.py index ea5207516cad1..b02d9d089a354 100644 --- a/tests/tests_pytorch/conftest.py +++ b/tests/tests_pytorch/conftest.py @@ -21,18 +21,18 @@ from pathlib import Path from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch.distributed +from tqdm import TMonitor + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments.lightning import find_free_network_port from lightning.fabric.strategies.launchers.subprocess_script import _ChildProcessObserver from lightning.fabric.utilities.distributed import _destroy_dist_connection, _distributed_is_initialized from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector -from tqdm import TMonitor - from tests_pytorch import _PATH_DATASETS @@ -44,9 +44,10 @@ def datadir(): @pytest.fixture(autouse=True) def preserve_global_rank_variable(): """Ensures that the rank_zero_only.rank global variable gets reset in each test.""" + from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities + from lightning.fabric.utilities.rank_zero import rank_zero_only as rank_zero_only_fabric from lightning.pytorch.utilities.rank_zero import rank_zero_only as rank_zero_only_pytorch - from lightning_utilities.core.rank_zero import rank_zero_only as rank_zero_only_utilities functions = (rank_zero_only_pytorch, rank_zero_only_fabric, rank_zero_only_utilities) ranks = [getattr(fn, "rank", None) for fn in functions] @@ -93,6 +94,7 @@ def restore_env_variables(): "TF_CPP_MIN_LOG_LEVEL", "TF_GRPC_DEFAULT_OPTIONS", "XLA_FLAGS", + "TORCHINDUCTOR_CACHE_DIR", # leaked by torch.compile } leaked_vars.difference_update(allowlist) assert not leaked_vars, f"test is leaking environment variable(s): {set(leaked_vars)}" @@ -176,22 +178,22 @@ def mock_cuda_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.pytorch.accelerators.cuda, "num_cuda_devices", lambda: n) -@pytest.fixture() +@pytest.fixture def cuda_count_0(monkeypatch): mock_cuda_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def cuda_count_1(monkeypatch): mock_cuda_count(monkeypatch, 1) -@pytest.fixture() +@pytest.fixture def cuda_count_2(monkeypatch): mock_cuda_count(monkeypatch, 2) -@pytest.fixture() +@pytest.fixture def cuda_count_4(monkeypatch): mock_cuda_count(monkeypatch, 4) @@ -201,12 +203,12 @@ def mock_mps_count(monkeypatch, n: int) -> None: monkeypatch.setattr(lightning.fabric.accelerators.mps.MPSAccelerator, "is_available", lambda *_: n > 0) -@pytest.fixture() 
+@pytest.fixture def mps_count_0(monkeypatch): mock_mps_count(monkeypatch, 0) -@pytest.fixture() +@pytest.fixture def mps_count_1(monkeypatch): mock_mps_count(monkeypatch, 1) @@ -222,7 +224,7 @@ def mock_xla_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setattr(lightning.fabric.strategies.launchers.xla, "_XLA_AVAILABLE", value) -@pytest.fixture() +@pytest.fixture def xla_available(monkeypatch: pytest.MonkeyPatch) -> None: mock_xla_available(monkeypatch) @@ -238,12 +240,12 @@ def mock_tpu_available(monkeypatch: pytest.MonkeyPatch, value: bool = True) -> N monkeypatch.setitem(sys.modules, "torch_xla.experimental", Mock()) -@pytest.fixture() +@pytest.fixture def tpu_available(monkeypatch) -> None: mock_tpu_available(monkeypatch) -@pytest.fixture() +@pytest.fixture def caplog(caplog): """Workaround for https://github.com/pytest-dev/pytest/issues/3697. @@ -271,7 +273,7 @@ def caplog(caplog): logging.getLogger(name).propagate = propagate -@pytest.fixture() +@pytest.fixture def tmpdir_server(tmp_path): Handler = partial(SimpleHTTPRequestHandler, directory=str(tmp_path)) from http.server import ThreadingHTTPServer @@ -285,7 +287,7 @@ def tmpdir_server(tmp_path): server.shutdown() -@pytest.fixture() +@pytest.fixture def single_process_pg(): """Initialize the default process group with only the current process for testing purposes. @@ -311,12 +313,16 @@ def single_process_pg(): @pytest.fixture(autouse=True) def leave_no_artifacts_behind(): + """Checks that no artifacts are left behind after the test.""" tests_root = Path(__file__).parent.parent + # Ignore the __pycache__ directories files_before = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts} yield files_after = {p for p in tests_root.rglob("*") if "__pycache__" not in p.parts} difference = files_after - files_before difference = {str(f.relative_to(tests_root)) for f in difference} + # ignore the .coverage files + difference = {f for f in difference if not f.endswith(".coverage")} assert not difference, f"Test left artifacts behind: {difference}" diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py index b3ccd88aae704..fcdc660a0fffc 100644 --- a/tests/tests_pytorch/core/test_datamodules.py +++ b/tests/tests_pytorch/core/test_datamodules.py @@ -21,6 +21,7 @@ import pytest import torch + from lightning.pytorch import LightningDataModule, Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import ( @@ -34,7 +35,6 @@ from lightning.pytorch.utilities import AttributeDict from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -112,7 +112,7 @@ def prepare_data(self): def test_hooks_no_recursion_error(): # hooks were appended in cascade every time a new data module was instantiated leading to a recursion error.
- # See https://github.com/Lightning-AI/lightning/issues/3652 + # See https://github.com/Lightning-AI/pytorch-lightning/issues/3652 class DummyDM(LightningDataModule): def setup(self, *args, **kwargs): pass diff --git a/tests/tests_pytorch/core/test_lightning_module.py b/tests/tests_pytorch/core/test_lightning_module.py index 5ee91e82689f4..2036014762ebf 100644 --- a/tests/tests_pytorch/core/test_lightning_module.py +++ b/tests/tests_pytorch/core/test_lightning_module.py @@ -17,15 +17,15 @@ import pytest import torch +from torch import nn +from torch.optim import SGD, Adam + from lightning.fabric import Fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.core.module import _TrainerFabricShim from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import nn -from torch.optim import SGD, Adam - from tests_pytorch.helpers.runif import RunIf @@ -336,9 +336,9 @@ def __init__(self, spec): ), "Expect the shards to be different before `m_1` loading `m_0`'s state dict" m_1.load_state_dict(m_0.state_dict(), strict=False) - assert torch.allclose( - m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor - ), "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + assert torch.allclose(m_1.sharded_tensor.local_shards()[0].tensor, m_0.sharded_tensor.local_shards()[0].tensor), ( + "Expect the shards to be same after `m_1` loading `m_0`'s state dict" + ) def test_lightning_module_configure_gradient_clipping(tmp_path): diff --git a/tests/tests_pytorch/core/test_lightning_optimizer.py b/tests/tests_pytorch/core/test_lightning_optimizer.py index b25b7ae648a3a..ed1ca2b4db03f 100644 --- a/tests/tests_pytorch/core/test_lightning_optimizer.py +++ b/tests/tests_pytorch/core/test_lightning_optimizer.py @@ -15,13 +15,13 @@ from unittest.mock import DEFAULT, Mock, patch import torch +from torch.optim import SGD, Adam, Optimizer + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import Closure from lightning.pytorch.tuner.tuning import Tuner -from torch.optim import SGD, Adam, Optimizer - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index dcb3f71c7499c..4a5df7d37cd7a 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -16,9 +16,15 @@ from contextlib import nullcontext, suppress from unittest import mock -import lightning.pytorch as pl import pytest import torch +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor, tensor +from torch.nn import ModuleDict, ModuleList +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import Accuracy + +import lightning.pytorch as pl from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -30,12 +36,6 @@ _Sync, ) from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.test.warning import no_warning_call -from 
torch import Tensor, tensor -from torch.nn import ModuleDict, ModuleList -from torchmetrics import Metric, MetricCollection -from torchmetrics.classification import Accuracy - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/core/test_results.py b/tests/tests_pytorch/core/test_results.py index 006731f3e7c60..93982086a6b0a 100644 --- a/tests/tests_pytorch/core/test_results.py +++ b/tests/tests_pytorch/core/test_results.py @@ -13,14 +13,15 @@ # limitations under the License. from functools import partial +import pytest import torch import torch.distributed as dist + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher from lightning.pytorch.trainer.connectors.logger_connector.result import _Sync - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_tpu import wrap_launch_function @@ -48,6 +49,8 @@ def result_reduce_ddp_fn(strategy): assert actual.item() == dist.get_world_size() +# flaky with "process 0 terminated with signal SIGABRT" +@pytest.mark.flaky(reruns=3, only_rerun="torch.multiprocessing.spawn.ProcessExitedException") @RunIf(skip_windows=True) def test_result_reduce_ddp(): spawn_launch(result_reduce_ddp_fn, [torch.device("cpu")] * 2) diff --git a/tests/tests_pytorch/core/test_saving.py b/tests/tests_pytorch/core/test_saving.py index c7e48239754c5..8e1e9584a7c68 100644 --- a/tests/tests_pytorch/core/test_saving.py +++ b/tests/tests_pytorch/core/test_saving.py @@ -1,11 +1,11 @@ from unittest.mock import ANY, Mock -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/demos/transformer.py b/tests/tests_pytorch/demos/transformer.py index 47ecbb2083273..50de873511053 100644 --- a/tests/tests_pytorch/demos/transformer.py +++ b/tests/tests_pytorch/demos/transformer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import torch + from lightning.pytorch.demos import Transformer diff --git a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py index e6da72c777dbb..0b79b638534fa 100644 --- a/tests/tests_pytorch/deprecated_api/test_no_removal_version.py +++ b/tests/tests_pytorch/deprecated_api/test_no_removal_version.py @@ -1,9 +1,10 @@ import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch.nn + +import lightning.fabric from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision.double import LightningDoublePrecisionModule diff --git a/tests/tests_pytorch/helpers/__init__.py b/tests/tests_pytorch/helpers/__init__.py index 82a6332c56738..1299d7e542955 100644 --- a/tests/tests_pytorch/helpers/__init__.py +++ b/tests/tests_pytorch/helpers/__init__.py @@ -4,5 +4,4 @@ ManualOptimBoringModel, RandomDataset, ) - from tests_pytorch.helpers.datasets import TrialMNIST # noqa: F401 diff --git a/tests/tests_pytorch/helpers/advanced_models.py b/tests/tests_pytorch/helpers/advanced_models.py index 4fecf516018c1..959e6e5968d18 100644 --- a/tests/tests_pytorch/helpers/advanced_models.py +++ b/tests/tests_pytorch/helpers/advanced_models.py @@ -16,9 +16,9 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch.core.module import LightningModule from torch.utils.data import DataLoader +from lightning.pytorch.core.module import LightningModule from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -219,3 +219,54 @@ def configure_optimizers(self): def train_dataloader(self): return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1) + + +class TBPTTModule(LightningModule): + def __init__(self): + super().__init__() + + self.batch_size = 10 + self.in_features = 10 + self.out_features = 5 + self.hidden_dim = 20 + + self.automatic_optimization = False + self.truncated_bptt_steps = 10 + + self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True) + self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features) + + def forward(self, x, hs): + seq, hs = self.rnn(x, hs) + return self.linear_out(seq), hs + + def training_step(self, batch, batch_idx): + x, y = batch + split_x, split_y = [ + x.tensor_split(self.truncated_bptt_steps, dim=1), + y.tensor_split(self.truncated_bptt_steps, dim=1), + ] + + hiddens = None + optimizer = self.optimizers() + losses = [] + + for x, y in zip(split_x, split_y): + y_pred, hiddens = self(x, hiddens) + loss = F.mse_loss(y_pred, y) + + optimizer.zero_grad() + self.manual_backward(loss) + optimizer.step() + + # "Truncate" + hiddens = [h.detach() for h in hiddens] + losses.append(loss.detach()) + + return + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + def train_dataloader(self): + return DataLoader(AverageDataset(), batch_size=self.batch_size) diff --git a/tests/tests_pytorch/helpers/datamodules.py b/tests/tests_pytorch/helpers/datamodules.py index 5a91d8ebb981d..6282acf3be547 100644 --- a/tests/tests_pytorch/helpers/datamodules.py +++ b/tests/tests_pytorch/helpers/datamodules.py @@ -13,10 +13,10 @@ # limitations under the License. 
import torch -from lightning.pytorch.core.datamodule import LightningDataModule from lightning_utilities.core.imports import RequirementCache from torch.utils.data import DataLoader +from lightning.pytorch.core.datamodule import LightningDataModule from tests_pytorch.helpers.datasets import MNIST, SklearnDataset, TrialMNIST _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") diff --git a/tests/tests_pytorch/helpers/deterministic_model.py b/tests/tests_pytorch/helpers/deterministic_model.py index 95975e9ad1654..fbf3488158825 100644 --- a/tests/tests_pytorch/helpers/deterministic_model.py +++ b/tests/tests_pytorch/helpers/deterministic_model.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch -from lightning.pytorch.core.module import LightningModule from torch import Tensor, nn from torch.utils.data import DataLoader, Dataset +from lightning.pytorch.core.module import LightningModule + class DeterministicModel(LightningModule): def __init__(self, weights=None): diff --git a/tests/tests_pytorch/helpers/pipelines.py b/tests/tests_pytorch/helpers/pipelines.py index ab33878010123..b6c63a5702bfc 100644 --- a/tests/tests_pytorch/helpers/pipelines.py +++ b/tests/tests_pytorch/helpers/pipelines.py @@ -14,11 +14,11 @@ from functools import partial import torch +from torchmetrics.functional import accuracy + from lightning.pytorch import LightningDataModule, LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torchmetrics.functional import accuracy - from tests_pytorch.helpers.utils import get_default_logger, load_model_from_checkpoint diff --git a/tests/tests_pytorch/helpers/runif.py b/tests/tests_pytorch/helpers/runif.py index 1c5b059d679a5..25fadd524adf8 100644 --- a/tests/tests_pytorch/helpers/runif.py +++ b/tests/tests_pytorch/helpers/runif.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest + from lightning.pytorch.utilities.testing import _runif_reasons diff --git a/tests/tests_pytorch/helpers/simple_models.py b/tests/tests_pytorch/helpers/simple_models.py index 940adc0ac49a6..a9dc635bba275 100644 --- a/tests/tests_pytorch/helpers/simple_models.py +++ b/tests/tests_pytorch/helpers/simple_models.py @@ -15,11 +15,12 @@ import torch import torch.nn.functional as F -from lightning.pytorch import LightningModule from lightning_utilities.core.imports import compare_version from torch import nn from torchmetrics import Accuracy, MeanSquaredError +from lightning.pytorch import LightningModule + # using new API with task _TM_GE_0_11 = compare_version("torchmetrics", operator.ge, "0.11.0") diff --git a/tests/tests_pytorch/helpers/test_models.py b/tests/tests_pytorch/helpers/test_models.py index 7e44f79413863..721641ae8343a 100644 --- a/tests/tests_pytorch/helpers/test_models.py +++ b/tests/tests_pytorch/helpers/test_models.py @@ -14,10 +14,10 @@ import os import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - -from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN +from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel @@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class): model.to_torchscript() if data_class: model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample) + + +def test_tbptt(tmp_path): + model = TBPTTModule() + + trainer = Trainer(default_root_dir=tmp_path, max_epochs=1) + trainer.fit(model) diff --git a/tests/tests_pytorch/loggers/conftest.py b/tests/tests_pytorch/loggers/conftest.py index 7cc5cc94fe8cc..98819eb080eb8 100644 --- a/tests/tests_pytorch/loggers/conftest.py +++ b/tests/tests_pytorch/loggers/conftest.py @@ -18,7 +18,7 @@ import pytest -@pytest.fixture() +@pytest.fixture def mlflow_mock(monkeypatch): mlflow = ModuleType("mlflow") mlflow.set_tracking_uri = Mock() @@ -43,7 +43,7 @@ def mlflow_mock(monkeypatch): return mlflow -@pytest.fixture() +@pytest.fixture def wandb_mock(monkeypatch): class RunType: # to make isinstance checks pass pass @@ -55,6 +55,7 @@ class RunType: # to make isinstance checks pass watch=Mock(), log_artifact=Mock(), use_artifact=Mock(), + define_metric=Mock(), id="run_id", ) @@ -89,28 +90,30 @@ class RunType: # to make isinstance checks pass return wandb -@pytest.fixture() +@pytest.fixture def comet_mock(monkeypatch): comet = ModuleType("comet_ml") monkeypatch.setitem(sys.modules, "comet_ml", comet) - comet.Experiment = Mock() - comet.ExistingExperiment = Mock() - comet.OfflineExperiment = Mock() - comet.API = Mock() - comet.config = Mock() + # to support dunder methods calling we will create a special mock + comet_experiment = MagicMock(name="CommonExperiment") + setattr(comet_experiment, "__internal_api__set_model_graph__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_metrics__", MagicMock()) + setattr(comet_experiment, "__internal_api__log_parameters__", MagicMock()) - comet_api = ModuleType("api") - comet_api.API = Mock() - monkeypatch.setitem(sys.modules, "comet_ml.api", comet_api) + comet.Experiment = MagicMock(name="Experiment", return_value=comet_experiment) + comet.ExistingExperiment = 
MagicMock(name="ExistingExperiment", return_value=comet_experiment) + comet.OfflineExperiment = MagicMock(name="OfflineExperiment", return_value=comet_experiment) - comet.api = comet_api + comet.ExperimentConfig = Mock() + comet.start = Mock(name="comet_ml.start", return_value=comet.Experiment()) + comet.config = Mock() monkeypatch.setattr("lightning.pytorch.loggers.comet._COMET_AVAILABLE", True) return comet -@pytest.fixture() +@pytest.fixture def neptune_mock(monkeypatch): class RunType: # to make isinstance checks pass def get_root_object(self): diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 1b845c57ec35d..a9cf9af79185b 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -19,6 +19,7 @@ import pytest import torch + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -32,7 +33,6 @@ from lightning.pytorch.loggers.logger import DummyExperiment, Logger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.tuner.tuning import Tuner - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.loggers.test_comet import _patch_comet_atexit from tests_pytorch.loggers.test_mlflow import mock_mlflow_run_creation @@ -58,6 +58,8 @@ def _get_logger_args(logger_class, save_dir): logger_args.update(offline=True) if issubclass(logger_class, NeptuneLogger): logger_args.update(mode="offline") + if issubclass(logger_class, CometLogger): + logger_args.update(online=False) return logger_args @@ -105,7 +107,9 @@ def log_metrics(self, metrics, step): if logger_class == CometLogger: logger.experiment.id = "foo" - logger.experiment.project_name = "bar" + logger._comet_config.offline_directory = None + logger._project_name = "bar" + logger.experiment.get_key.return_value = "SOME_KEY" if logger_class == NeptuneLogger: logger._retrieve_run_data = Mock() @@ -292,7 +296,9 @@ def test_logger_with_prefix_all(mlflow_mock, wandb_mock, comet_mock, neptune_moc _patch_comet_atexit(monkeypatch) logger = _instantiate_logger(CometLogger, save_dir=tmp_path, prefix=prefix) logger.log_metrics({"test": 1.0}, step=0) - logger.experiment.log_metrics.assert_called_once_with({"tmp-test": 1.0}, epoch=None, step=0) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1.0}, epoch=None, step=0, prefix=prefix, framework="pytorch-lightning" + ) # MLflow Metric = mlflow_mock.entities.Metric diff --git a/tests/tests_pytorch/loggers/test_comet.py b/tests/tests_pytorch/loggers/test_comet.py index e467c63543ede..dae8f617b873e 100644 --- a/tests/tests_pytorch/loggers/test_comet.py +++ b/tests/tests_pytorch/loggers/test_comet.py @@ -13,15 +13,14 @@ # limitations under the License. 
import os from unittest import mock -from unittest.mock import DEFAULT, Mock, patch +from unittest.mock import Mock, call -import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loggers import CometLogger -from lightning.pytorch.utilities.exceptions import MisconfigurationException from torch import tensor +from lightning.pytorch.loggers import CometLogger + +FRAMEWORK_NAME = "pytorch-lightning" + def _patch_comet_atexit(monkeypatch): """Prevent comet logger from trying to print at exit, since pytest's stdout/stderr redirection breaks it.""" @@ -33,195 +32,163 @@ def _patch_comet_atexit(monkeypatch): @mock.patch.dict(os.environ, {}) def test_comet_logger_online(comet_mock): """Test comet online with mocks.""" - # Test api_key given - comet_experiment = comet_mock.Experiment - logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test both given - comet_experiment.reset_mock() - logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general") - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key="key", workspace="dummy-test", project_name="general") - - # Test already exists - comet_existing = comet_mock.ExistingExperiment - logger = CometLogger( - experiment_key="test", - experiment_name="experiment", + + comet_start = comet_mock.start + + # Test api_key given with old param "project_name" + _logger = CometLogger(api_key="key", workspace="dummy-test", project_name="general") + comet_start.assert_called_once_with( api_key="key", workspace="dummy-test", - project_name="general", + project="general", + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), ) - _ = logger.experiment - comet_existing.assert_called_once_with( - api_key="key", workspace="dummy-test", project_name="general", previous_experiment="test" + + # Test online given + comet_start.reset_mock() + _logger = CometLogger(save_dir="test", api_key="key", workspace="dummy-test", project_name="general", online=True) + comet_start.assert_called_once_with( + api_key="key", + workspace="dummy-test", + project="general", + experiment_key=None, + mode=None, + online=True, + experiment_config=comet_mock.ExperimentConfig(), ) - comet_existing().set_name.assert_called_once_with("experiment") - # API experiment - api = comet_mock.api.API - CometLogger(api_key="key", workspace="dummy-test", project_name="general", rest_api_key="rest") - api.assert_called_once_with("rest") + # Test experiment_key given + comet_start.reset_mock() + _logger = CometLogger( + experiment_key="test_key", + api_key="key", + project="general", + ) + comet_start.assert_called_once_with( + api_key="key", + workspace=None, + project="general", + experiment_key="test_key", + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) @mock.patch.dict(os.environ, {}) -def test_comet_experiment_resets_if_not_alive(comet_mock): - """Test that the CometLogger creates a new experiment if the old one is not alive anymore.""" +def test_comet_experiment_is_still_alive_after_training_complete(comet_mock): + """Test that the CometLogger will not end an experiment after training is complete.""" + logger = CometLogger() - assert logger._experiment is None - alive_experiment = Mock(alive=True) - logger._experiment 
= alive_experiment - assert logger.experiment is alive_experiment + assert logger.experiment is not None - unalive_experiment = Mock(alive=False) - logger._experiment = unalive_experiment - assert logger.experiment is not unalive_experiment + logger._experiment = Mock() + logger.finalize("ended") + # Assert that data was saved to comet.com + logger._experiment.flush.assert_called_once() -@mock.patch.dict(os.environ, {}) -def test_comet_logger_no_api_key_given(comet_mock): - """Test that CometLogger fails to initialize if both api key and save_dir are missing.""" - with pytest.raises(MisconfigurationException, match="requires either api_key or save_dir"): - comet_mock.config.get_api_key.return_value = None - CometLogger(workspace="dummy-test", project_name="general") + # Assert that was not ended + logger._experiment.end.assert_not_called() @mock.patch.dict(os.environ, {}) def test_comet_logger_experiment_name(comet_mock): """Test that Comet Logger experiment name works correctly.""" - api_key = "key" - experiment_name = "My Name" + api_key = "api_key" + experiment_name = "My Experiment Name" - # Test api_key given - comet_experiment = comet_mock.Experiment + comet_start = comet_mock.start + + # here we use old style arg "experiment_name" (new one is "name") logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None + comet_start.assert_called_once_with( + api_key=api_key, + workspace=None, + project=None, + experiment_key=None, + mode=None, + online=None, + experiment_config=comet_mock.ExperimentConfig(), + ) + # check that we saved "experiment name" in kwargs as new "name" arg + assert logger._kwargs["name"] == experiment_name + assert "experiment_name" not in logger._kwargs - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) - comet_experiment().set_name.assert_called_once_with(experiment_name) + # check that "experiment name" was passed to experiment config correctly + assert call(experiment_name=experiment_name) not in comet_mock.ExperimentConfig.call_args_list + assert call(name=experiment_name) in comet_mock.ExperimentConfig.call_args_list @mock.patch.dict(os.environ, {}) -def test_comet_logger_manual_experiment_key(comet_mock): - """Test that Comet Logger respects manually set COMET_EXPERIMENT_KEY.""" +def test_comet_version(comet_mock): + """Test that CometLogger.version returns an Experiment key.""" api_key = "key" - experiment_key = "96346da91469407a85641afe5766b554" - - instantiation_environ = {} - - def save_os_environ(*args, **kwargs): - nonlocal instantiation_environ - instantiation_environ = os.environ.copy() - - return DEFAULT - - comet_experiment = comet_mock.Experiment - comet_experiment.side_effect = save_os_environ - - # Test api_key given - with patch.dict(os.environ, {"COMET_EXPERIMENT_KEY": experiment_key}): - logger = CometLogger(api_key=api_key) - assert logger.version == experiment_key - assert logger._experiment is None + experiment_name = "My Name" - _ = logger.experiment - comet_experiment.assert_called_once_with(api_key=api_key, project_name=None) + logger = CometLogger(api_key=api_key, name=experiment_name) + assert logger._experiment is not None + _ = logger.version - assert instantiation_environ["COMET_EXPERIMENT_KEY"] == experiment_key + logger._experiment.get_key.assert_called() @mock.patch.dict(os.environ, {}) -def test_comet_logger_dirs_creation(comet_mock, tmp_path, monkeypatch): - """Test that the logger creates the folders and files in the right 
place.""" +def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" _patch_comet_atexit(monkeypatch) - comet_experiment = comet_mock.OfflineExperiment - - comet_mock.config.get_api_key.return_value = None - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "4321" - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - assert not os.listdir(tmp_path) - assert logger.mode == "offline" - assert logger.save_dir == str(tmp_path) - assert logger.name == "test" - assert logger.version == "4321" - - _ = logger.experiment - comet_experiment.assert_called_once_with(offline_directory=str(tmp_path), project_name="test") - - # mock return values of experiment - logger.experiment.id = "1" - logger.experiment.project_name = "test" - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmp_path, logger=logger, max_epochs=1, limit_train_batches=3, limit_val_batches=3 + logger.log_metrics({"test": 1, "epoch": 1}, step=123) + logger.experiment.__internal_api__log_metrics__.assert_called_once_with( + {"test": 1}, + epoch=1, + step=123, + prefix=logger._prefix, + framework="pytorch-lightning", ) - assert trainer.log_dir == logger.save_dir - trainer.fit(model) - - assert trainer.checkpoint_callback.dirpath == str(tmp_path / "test" / "1" / "checkpoints") - assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {"epoch=0-step=3.ckpt"} - assert trainer.log_dir == logger.save_dir @mock.patch.dict(os.environ, {}) -def test_comet_name_default(comet_mock): - """Test that CometLogger.name don't create an Experiment and returns a default value.""" - api_key = "key" - logger = CometLogger(api_key=api_key) - assert logger._experiment is None - assert logger.name == "comet-default" - assert logger._experiment is None - +def test_comet_log_hyperparams(comet_mock, tmp_path, monkeypatch): + """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) -@mock.patch.dict(os.environ, {}) -def test_comet_name_project_name(comet_mock): - """Test that CometLogger.name does not create an Experiment and returns project name if passed.""" - api_key = "key" - project_name = "My Project Name" - logger = CometLogger(api_key=api_key, project_name=project_name) - assert logger._experiment is None - assert logger.name == project_name - assert logger._experiment is None + logger = CometLogger(project_name="test") + hyperparams = { + "batch_size": 256, + "config": { + "SLURM Job ID": "22334455", + "RGB slurm jobID": "12345678", + "autoencoder_model": False, + }, + } + logger.log_hyperparams(hyperparams) + + logger.experiment.__internal_api__log_parameters__.assert_called_once_with( + parameters=hyperparams, + framework=FRAMEWORK_NAME, + flatten_nested=True, + source="manual", + ) @mock.patch.dict(os.environ, {}) -def test_comet_version_without_experiment(comet_mock): - """Test that CometLogger.version does not create an Experiment.""" - api_key = "key" - experiment_name = "My Name" - comet_mock.generate_guid = Mock() - comet_mock.generate_guid.return_value = "1234" - - logger = CometLogger(api_key=api_key, experiment_name=experiment_name) - assert logger._experiment is None - - first_version = logger.version - assert first_version is not None - assert logger.version == first_version - assert logger._experiment is None - - _ = logger.experiment - - logger.reset_experiment() +def test_comet_log_graph(comet_mock, tmp_path, monkeypatch): 
+ """Test that CometLogger.log_hyperparams calls internal API method.""" + _patch_comet_atexit(monkeypatch) - second_version = logger.version == "1234" - assert second_version is not None - assert second_version != first_version + logger = CometLogger(project_name="test") + model = Mock() + logger.log_graph(model=model) -@mock.patch.dict(os.environ, {}) -def test_comet_epoch_logging(comet_mock, tmp_path, monkeypatch): - """Test that CometLogger removes the epoch key from the metrics dict and passes it as argument.""" - _patch_comet_atexit(monkeypatch) - logger = CometLogger(project_name="test", save_dir=str(tmp_path)) - logger.log_metrics({"test": 1, "epoch": 1}, step=123) - logger.experiment.log_metrics.assert_called_once_with({"test": 1}, epoch=1, step=123) + logger.experiment.__internal_api__set_model_graph__.assert_called_once_with( + graph=model, + framework="pytorch-lightning", + ) @mock.patch.dict(os.environ, {}) diff --git a/tests/tests_pytorch/loggers/test_csv.py b/tests/tests_pytorch/loggers/test_csv.py index 27b85bb4ad745..c131d03d38245 100644 --- a/tests/tests_pytorch/loggers/test_csv.py +++ b/tests/tests_pytorch/loggers/test_csv.py @@ -18,11 +18,11 @@ import fsspec import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.saving import load_hparams_from_yaml from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers.csv_logs import ExperimentWriter - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel @@ -75,7 +75,6 @@ def test_named_version(tmp_path): logger = CSVLogger(save_dir=tmp_path, name=exp_name, version=expected_version) logger.log_hyperparams({"a": 1, "b": 2}) - logger.save() assert logger.version == expected_version assert os.listdir(tmp_path / exp_name) == [expected_version] assert os.listdir(tmp_path / exp_name / expected_version) @@ -85,7 +84,7 @@ def test_named_version(tmp_path): def test_no_name(tmp_path, name): """Verify that None or empty name works.""" logger = CSVLogger(save_dir=tmp_path, name=name) - logger.save() + logger.log_hyperparams() assert os.path.normpath(logger.root_dir) == str(tmp_path) # use os.path.normpath to handle trailing / assert os.listdir(tmp_path / "version_0") @@ -116,7 +115,6 @@ def test_log_hyperparams(tmp_path): "layer": torch.nn.BatchNorm1d, } logger.log_hyperparams(hparams) - logger.save() path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE) params = load_hparams_from_yaml(path_yaml) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index dcdd504fd4660..124a9120a9197 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -20,6 +20,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index 14af36680904c..8118349ea6721 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, Mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.mlflow 
import ( @@ -426,3 +427,33 @@ def test_set_tracking_uri(mlflow_mock): mlflow_mock.set_tracking_uri.assert_not_called() _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + + +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +def test_mlflow_log_model_with_checkpoint_path_prefix(mlflow_mock, tmp_path): + """Test that the logger creates the folders and files in the right place with a prefix.""" + client = mlflow_mock.tracking.MlflowClient + + # Get model, logger, trainer and train + model = BoringModel() + logger = MLFlowLogger("test", save_dir=str(tmp_path), log_model="all", checkpoint_path_prefix="my_prefix") + logger = mock_mlflow_run_creation(logger, experiment_id="test-id") + + trainer = Trainer( + default_root_dir=tmp_path, + logger=logger, + max_epochs=2, + limit_train_batches=3, + limit_val_batches=3, + ) + trainer.fit(model) + + # Checkpoint log + assert client.return_value.log_artifact.call_count == 2 + # Metadata and aliases log + assert client.return_value.log_artifacts.call_count == 2 + + # Check that the prefix is used in the artifact path + for call in client.return_value.log_artifact.call_args_list: + args, _ = call + assert str(args[2]).startswith("my_prefix") diff --git a/tests/tests_pytorch/loggers/test_neptune.py b/tests/tests_pytorch/loggers/test_neptune.py index 0a39337ac5c16..6dc3816fac858 100644 --- a/tests/tests_pytorch/loggers/test_neptune.py +++ b/tests/tests_pytorch/loggers/test_neptune.py @@ -17,9 +17,10 @@ from unittest import mock from unittest.mock import MagicMock, call -import lightning.pytorch as pl import pytest import torch + +import lightning.pytorch as pl from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import NeptuneLogger @@ -255,11 +256,10 @@ def test_after_save_checkpoint(neptune_mock): mock_file.side_effect = mock.Mock() logger.after_save_checkpoint(cb_mock) - assert run_instance_mock.__setitem__.call_count == 3 - assert run_instance_mock.__getitem__.call_count == 2 - assert run_attr_mock.upload.call_count == 2 - - assert mock_file.from_stream.call_count == 2 + assert run_instance_mock.__setitem__.call_count == 1 # best_model_path + assert run_instance_mock.__getitem__.call_count == 4 # last_model_path, best_k_models, best_model_path + assert run_attr_mock.upload.call_count == 4 # last_model_path, best_k_models, best_model_path + assert mock_file.from_stream.call_count == 0 run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model1") run_instance_mock.__getitem__.assert_any_call(f"{model_key_prefix}/checkpoints/model2/with/slashes") diff --git a/tests/tests_pytorch/loggers/test_tensorboard.py b/tests/tests_pytorch/loggers/test_tensorboard.py index 82ffff25cac7c..173805f1a6f3c 100644 --- a/tests/tests_pytorch/loggers/test_tensorboard.py +++ b/tests/tests_pytorch/loggers/test_tensorboard.py @@ -20,12 +20,12 @@ import pytest import torch import yaml + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers.tensorboard import _TENSORBOARD_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index a8e70bfb6589d..35c1917983dcf 100644 --- 
a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -18,14 +18,14 @@ import pytest import yaml +from lightning_utilities.test.warning import no_warning_call + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI from lightning.pytorch.demos.boring_classes import BoringModel -from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call - from tests_pytorch.test_cli import _xfail_python_ge_3_11_9 @@ -126,6 +126,24 @@ def test_wandb_logger_init(wandb_mock): assert logger.version == wandb_mock.init().id +def test_wandb_logger_sync_tensorboard(wandb_mock): + logger = WandbLogger(sync_tensorboard=True) + wandb_mock.run = None + logger.experiment + + # test that tensorboard's global_step is set as the default x-axis if sync_tensorboard=True + wandb_mock.init.return_value.define_metric.assert_called_once_with("*", step_metric="global_step") + + +def test_wandb_logger_sync_tensorboard_log_metrics(wandb_mock): + logger = WandbLogger(sync_tensorboard=True) + metrics = {"loss": 1e-3, "accuracy": 0.99} + logger.log_metrics(metrics) + + # test that trainer/global_step is not added to the logged metrics if sync_tensorboard=True + wandb_mock.run.log.assert_called_once_with(metrics) + + def test_wandb_logger_init_before_spawn(wandb_mock): logger = WandbLogger() assert logger._experiment is None @@ -133,6 +151,43 @@ def test_wandb_logger_init_before_spawn(wandb_mock): assert logger._experiment is not None +def test_wandb_logger_experiment_called_first(wandb_mock, tmp_path): + wandb_experiment_called = False + + def tensorboard_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + assert wandb_experiment_called + return mock.MagicMock() + + def wandb_experiment_side_effect() -> mock.MagicMock: + nonlocal wandb_experiment_called + wandb_experiment_called = True + return mock.MagicMock() + + with ( + mock.patch.object( + TensorBoardLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=tensorboard_experiment_side_effect), + ), + mock.patch.object( + WandbLogger, + "experiment", + new_callable=lambda: mock.PropertyMock(side_effect=wandb_experiment_side_effect), + ), + ): + model = BoringModel() + trainer = Trainer( + default_root_dir=tmp_path, + log_every_n_steps=1, + limit_train_batches=0, + limit_val_batches=0, + max_steps=1, + logger=[TensorBoardLogger(tmp_path), WandbLogger(save_dir=tmp_path)], + ) + trainer.fit(model) + + def test_wandb_pickle(wandb_mock, tmp_path): """Verify that pickling trainer with wandb logger works. 
@@ -371,6 +426,44 @@ def test_wandb_log_model(wandb_mock, tmp_path): ) wandb_mock.init().log_artifact.assert_called_with(wandb_mock.Artifact(), aliases=["latest", "best"]) + # Test wandb artifact with two checkpoint_callbacks + wandb_mock.init().log_artifact.reset_mock() + wandb_mock.init.reset_mock() + wandb_mock.Artifact.reset_mock() + logger = WandbLogger(save_dir=tmp_path, log_model=True) + logger.experiment.id = "1" + logger.experiment.name = "run_name" + trainer = Trainer( + default_root_dir=tmp_path, + logger=logger, + max_epochs=3, + limit_train_batches=3, + limit_val_batches=3, + callbacks=[ + ModelCheckpoint(monitor="epoch", save_top_k=2), + ModelCheckpoint(monitor="step", save_top_k=2), + ], + ) + trainer.fit(model) + for name, val, version in [("epoch", 0, 2), ("step", 3, 3)]: + wandb_mock.Artifact.assert_any_call( + name="model-1", + type="model", + metadata={ + "score": val, + "original_filename": f"epoch=0-step=3-v{version}.ckpt", + "ModelCheckpoint": { + "monitor": name, + "mode": "min", + "save_last": None, + "save_top_k": 2, + "save_weights_only": False, + "_every_n_train_steps": 0, + }, + }, + ) + wandb_mock.init().log_artifact.assert_any_call(wandb_mock.Artifact(), aliases=["latest"]) + def test_wandb_log_model_with_score(wandb_mock, tmp_path): """Test to prevent regression on #15543, ensuring the score is logged as a Python number, not a scalar tensor.""" diff --git a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py index 2fb04d0d9d8d1..e20c1789be023 100644 --- a/tests/tests_pytorch/loops/optimization/test_automatic_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_automatic_loop.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.automatic import ClosureResult diff --git a/tests/tests_pytorch/loops/optimization/test_closure.py b/tests/tests_pytorch/loops/optimization/test_closure.py index d7d4e51794aca..7766a385c3057 100644 --- a/tests/tests_pytorch/loops/optimization/test_closure.py +++ b/tests/tests_pytorch/loops/optimization/test_closure.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException diff --git a/tests/tests_pytorch/loops/optimization/test_manual_loop.py b/tests/tests_pytorch/loops/optimization/test_manual_loop.py index 67be30b24e159..cedfefb4791ea 100644 --- a/tests/tests_pytorch/loops/optimization/test_manual_loop.py +++ b/tests/tests_pytorch/loops/optimization/test_manual_loop.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops.optimization.manual import ManualResult diff --git a/tests/tests_pytorch/loops/test_all.py b/tests/tests_pytorch/loops/test_all.py index 1eb67064fb300..51b7bbeedf90b 100644 --- a/tests/tests_pytorch/loops/test_all.py +++ b/tests/tests_pytorch/loops/test_all.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest + from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop.py b/tests/tests_pytorch/loops/test_evaluation_loop.py index 588672e19a05c..6d8c4be70d9c6 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop.py @@ -16,13 +16,13 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader +from torch.utils.data.sampler import BatchSampler, RandomSampler + from lightning.fabric.accelerators.cuda import _clear_cuda_memory from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities import CombinedLoader -from torch.utils.data.dataloader import DataLoader -from torch.utils.data.sampler import BatchSampler, RandomSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py index 09ded9a777679..e698a15cb0fb2 100644 --- a/tests/tests_pytorch/loops/test_evaluation_loop_flow.py +++ b/tests/tests_pytorch/loops/test_evaluation_loop_flow.py @@ -14,11 +14,11 @@ """Tests the evaluation loop.""" import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule from lightning.pytorch.trainer.states import RunningStage -from torch import Tensor - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_fetchers.py b/tests/tests_pytorch/loops/test_fetchers.py index 75b25e3d98fd8..f66e9f9f3b16f 100644 --- a/tests/tests_pytorch/loops/test_fetchers.py +++ b/tests/tests_pytorch/loops/test_fetchers.py @@ -17,6 +17,9 @@ import pytest import torch +from torch import Tensor +from torch.utils.data import DataLoader, Dataset, IterableDataset + from lightning.pytorch import LightningDataModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher, _PrefetchDataFetcher @@ -24,9 +27,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch import Tensor -from torch.utils.data import DataLoader, Dataset, IterableDataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/loops/test_loops.py b/tests/tests_pytorch/loops/test_loops.py index 1820ca3568173..384ae2b47859b 100644 --- a/tests/tests_pytorch/loops/test_loops.py +++ b/tests/tests_pytorch/loops/test_loops.py @@ -20,6 +20,8 @@ import pytest import torch +from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint, OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -27,8 +29,6 @@ from lightning.pytorch.loops.progress import _BaseProgress from lightning.pytorch.utilities import CombinedLoader from lightning.pytorch.utilities.types import STEP_OUTPUT -from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter - from tests_pytorch.helpers.runif import RunIf diff 
--git a/tests/tests_pytorch/loops/test_prediction_loop.py b/tests/tests_pytorch/loops/test_prediction_loop.py index f27413955cae9..470cbcdc195f5 100644 --- a/tests/tests_pytorch/loops/test_prediction_loop.py +++ b/tests/tests_pytorch/loops/test_prediction_loop.py @@ -14,10 +14,11 @@ import itertools import pytest +from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper -from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler def test_prediction_loop_stores_predictions(tmp_path): diff --git a/tests/tests_pytorch/loops/test_progress.py b/tests/tests_pytorch/loops/test_progress.py index e7256d4504402..27184d7b17afb 100644 --- a/tests/tests_pytorch/loops/test_progress.py +++ b/tests/tests_pytorch/loops/test_progress.py @@ -14,6 +14,7 @@ from copy import deepcopy import pytest + from lightning.pytorch.loops.progress import ( _BaseProgress, _OptimizerProgress, diff --git a/tests/tests_pytorch/loops/test_training_epoch_loop.py b/tests/tests_pytorch/loops/test_training_epoch_loop.py index a110a20bfaf84..35a9dacae766d 100644 --- a/tests/tests_pytorch/loops/test_training_epoch_loop.py +++ b/tests/tests_pytorch/loops/test_training_epoch_loop.py @@ -16,11 +16,12 @@ import pytest import torch +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning -from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.trainer import Trainer -from lightning_utilities.test.warning import no_warning_call def test_no_val_on_train_epoch_loop_restart(tmp_path): @@ -91,7 +92,16 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, (min_epochs/steps is satisfied). 
""" - model = BoringModel() + + class NewBoring(BoringModel): + def training_step(self, batch, batch_idx): + self.log("loss", self.step(batch)) + return {"loss": self.step(batch)} + + model = NewBoring() + # create a stopping condition with a high threshold so it triggers immediately + # check the condition before validation so the count is unaffected + stopping = EarlyStopping(monitor="loss", check_on_train_epoch_end=True, stopping_threshold=100) trainer = Trainer( default_root_dir=tmp_path, num_sanity_val_steps=0, @@ -102,8 +112,8 @@ def test_should_stop_triggers_validation_once(min_epochs, min_steps, val_count, min_steps=min_steps, enable_model_summary=False, enable_checkpointing=False, + callbacks=[stopping], ) - trainer.should_stop = True # Request to stop before min_epochs/min_steps are reached trainer.fit_loop.epoch_loop.val_loop.run = Mock() trainer.fit(model) assert trainer.fit_loop.epoch_loop.val_loop.run.call_count == val_count diff --git a/tests/tests_pytorch/loops/test_training_loop.py b/tests/tests_pytorch/loops/test_training_loop.py index 2afeb338fd9fd..29afd1ba1a250 100644 --- a/tests/tests_pytorch/loops/test_training_loop.py +++ b/tests/tests_pytorch/loops/test_training_loop.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _FitLoop diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py index c89913b29dbdb..44a0c85184d9e 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_dict.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_dict.py @@ -14,9 +14,9 @@ """Tests to ensure that the training loop works with a dict (1.0)""" import torch + from lightning.pytorch import Trainer from lightning.pytorch.core.module import LightningModule - from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py index ccdeb55d50ea8..4a4e8cb81c6a2 100644 --- a/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py +++ b/tests/tests_pytorch/loops/test_training_loop_flow_scalar.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest -from lightning.pytorch import Trainer -from lightning.pytorch.core.module import LightningModule -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.loops.optimization.automatic import Closure -from lightning.pytorch.trainer.states import RunningStage from lightning_utilities.test.warning import no_warning_call from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data._utils.collate import default_collate +from lightning.pytorch import Trainer +from lightning.pytorch.core.module import LightningModule +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.loops.optimization.automatic import Closure +from lightning.pytorch.trainer.states import RunningStage from tests_pytorch.helpers.deterministic_model import DeterministicModel diff --git a/tests/tests_pytorch/models/test_amp.py b/tests/tests_pytorch/models/test_amp.py index c28d6300131ae..24323f5c1d691 100644 --- a/tests/tests_pytorch/models/test_amp.py +++ b/tests/tests_pytorch/models/test_amp.py @@ -16,12 +16,12 @@ import pytest import torch -from lightning.fabric.plugins.environments import SLURMEnvironment -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from torch.utils.data import DataLoader import tests_pytorch.helpers.utils as tutils +from lightning.fabric.plugins.environments import SLURMEnvironment +from lightning.pytorch import Trainer +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_cpu.py b/tests/tests_pytorch/models/test_cpu.py index 435b8437760d2..a2d38aca7c56c 100644 --- a/tests/tests_pytorch/models/test_cpu.py +++ b/tests/tests_pytorch/models/test_cpu.py @@ -15,12 +15,12 @@ from unittest import mock import torch -from lightning.pytorch import Trainer, seed_everything -from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint -from lightning.pytorch.demos.boring_classes import BoringModel import tests_pytorch.helpers.pipelines as tpipes import tests_pytorch.helpers.utils as tutils +from lightning.pytorch import Trainer, seed_everything +from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint +from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_ddp_fork_amp.py b/tests/tests_pytorch/models/test_ddp_fork_amp.py index 54d394948eeee..a1d7c6a7f8ac1 100644 --- a/tests/tests_pytorch/models/test_ddp_fork_amp.py +++ b/tests/tests_pytorch/models/test_ddp_fork_amp.py @@ -14,8 +14,8 @@ import multiprocessing import torch -from lightning.pytorch.plugins import MixedPrecision +from lightning.pytorch.plugins import MixedPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_fabric_integration.py b/tests/tests_pytorch/models/test_fabric_integration.py index fd5ecd34e96be..3e2811e8a9413 100644 --- a/tests/tests_pytorch/models/test_fabric_integration.py +++ b/tests/tests_pytorch/models/test_fabric_integration.py @@ -16,6 +16,7 @@ from unittest.mock import Mock import torch + from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import BoringModel, 
ManualOptimBoringModel diff --git a/tests/tests_pytorch/models/test_gpu.py b/tests/tests_pytorch/models/test_gpu.py index b411774c3e164..797120312436f 100644 --- a/tests/tests_pytorch/models/test_gpu.py +++ b/tests/tests_pytorch/models/test_gpu.py @@ -18,14 +18,14 @@ import pytest import torch + +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import TorchElasticEnvironment from lightning.fabric.utilities.device_parser import _parse_gpu_ids from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.accelerators import CPUAccelerator, CUDAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.exceptions import MisconfigurationException - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/models/test_hooks.py b/tests/tests_pytorch/models/test_hooks.py index 1a8aeb4b297a9..e943d0533cab5 100644 --- a/tests/tests_pytorch/models/test_hooks.py +++ b/tests/tests_pytorch/models/test_hooks.py @@ -18,12 +18,12 @@ import pytest import torch -from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ -from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset -from lightning.pytorch.utilities.model_helpers import is_overridden from torch import Tensor from torch.utils.data import DataLoader +from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__ +from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset +from lightning.pytorch.utilities.model_helpers import is_overridden from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 871b1cba673eb..6fd400aab2724 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -25,6 +25,10 @@ import pytest import torch from fsspec.implementations.local import LocalFileSystem +from lightning_utilities.core.imports import RequirementCache +from lightning_utilities.test.warning import no_warning_call +from torch.utils.data import DataLoader + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.core.datamodule import LightningDataModule @@ -34,10 +38,6 @@ from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from lightning.pytorch.utilities import AttributeDict, is_picklable from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE -from lightning_utilities.core.imports import RequirementCache -from lightning_utilities.test.warning import no_warning_call -from torch.utils.data import DataLoader - from tests_pytorch.helpers.runif import RunIf if _OMEGACONF_AVAILABLE: diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index ee670cd66e871..81fd5631a3400 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -21,11 +21,11 @@ import onnxruntime import pytest import torch -from lightning.pytorch import Trainer -from lightning.pytorch.demos.boring_classes import BoringModel from lightning_utilities import compare_version import tests_pytorch.helpers.pipelines as 
tpipes
+from lightning.pytorch import Trainer
+from lightning.pytorch.demos.boring_classes import BoringModel
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.utilities.test_model_summary import UnorderedModel
@@ -111,17 +111,17 @@ def test_model_saves_on_multi_gpu(tmp_path):
     assert os.path.exists(file_path) is True
-@RunIf(onnx=True)
+# todo: investigate where the logging is happening in torch.onnx for PT 2.6+
+@RunIf(onnx=True, max_torch="2.6.0")
 def test_verbose_param(tmp_path, capsys):
     """Test that output is present when verbose parameter is set."""
     model = BoringModel()
     model.example_input_array = torch.randn(5, 32)
     file_path = os.path.join(tmp_path, "model.onnx")
-    with patch("torch.onnx.log", autospec=True) as test:
+    with patch("torch.onnx.log", autospec=True) as mocked:
         model.to_onnx(file_path, verbose=True)
-    args, _ = test.call_args
-    prefix, _ = args
+    (prefix, _), _ = mocked.call_args
     assert prefix == "Exported graph: "
diff --git a/tests/tests_pytorch/models/test_restore.py b/tests/tests_pytorch/models/test_restore.py
index 64f70b176a971..099493890831d 100644
--- a/tests/tests_pytorch/models/test_restore.py
+++ b/tests/tests_pytorch/models/test_restore.py
@@ -22,16 +22,16 @@
 import cloudpickle
 import pytest
 import torch
-from lightning.fabric import seed_everything
-from lightning.pytorch import Callback, Trainer
-from lightning.pytorch.callbacks import ModelCheckpoint
-from lightning.pytorch.demos.boring_classes import BoringModel
-from lightning.pytorch.trainer.states import TrainerFn
 from lightning_utilities.test.warning import no_warning_call
 from torch import Tensor
 import tests_pytorch.helpers.pipelines as tpipes
 import tests_pytorch.helpers.utils as tutils
+from lightning.fabric import seed_everything
+from lightning.pytorch import Callback, Trainer
+from lightning.pytorch.callbacks import ModelCheckpoint
+from lightning.pytorch.demos.boring_classes import BoringModel
+from lightning.pytorch.trainer.states import TrainerFn
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index 993085729e545..8f9151265d21a 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -18,11 +18,11 @@
 import pytest
 import torch
 from fsspec.implementations.local import LocalFileSystem
+
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch.core.module import LightningModule
 from lightning.pytorch.demos.boring_classes import BoringModel
-
 from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleRNN
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/models/test_tpu.py b/tests/tests_pytorch/models/test_tpu.py
index af927e0d5596a..8067fd63b6562 100644
--- a/tests/tests_pytorch/models/test_tpu.py
+++ b/tests/tests_pytorch/models/test_tpu.py
@@ -17,6 +17,9 @@
 import pytest
 import torch
+from torch.utils.data import DataLoader
+
+import tests_pytorch.helpers.pipelines as tpipes
 from lightning.pytorch import Trainer
 from lightning.pytorch.accelerators import XLAAccelerator
 from lightning.pytorch.callbacks import EarlyStopping
@@ -25,9 +28,6 @@
 from lightning.pytorch.strategies.launchers.xla import _XLALauncher
 from
lightning.pytorch.trainer.connectors.logger_connector.result import _Sync from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/overrides/test_distributed.py b/tests/tests_pytorch/overrides/test_distributed.py index 3e2fba54bcd03..dd5f8f1504af7 100644 --- a/tests/tests_pytorch/overrides/test_distributed.py +++ b/tests/tests_pytorch/overrides/test_distributed.py @@ -15,11 +15,11 @@ import pytest import torch +from torch.utils.data import BatchSampler, SequentialSampler + from lightning.fabric.utilities.data import has_len from lightning.pytorch import LightningModule, Trainer, seed_everything from lightning.pytorch.overrides.distributed import UnrepeatedDistributedSampler, _IndexBatchSamplerWrapper -from torch.utils.data import BatchSampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_all.py b/tests/tests_pytorch/plugins/precision/test_all.py index 2668311c8b452..2a11b7c66c772 100644 --- a/tests/tests_pytorch/plugins/precision/test_all.py +++ b/tests/tests_pytorch/plugins/precision/test_all.py @@ -1,5 +1,6 @@ import pytest import torch + from lightning.pytorch.plugins import ( DeepSpeedPrecision, DoublePrecision, diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py index 90ecc703c8945..cb061c540b2be 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp.py +++ b/tests/tests_pytorch/plugins/precision/test_amp.py @@ -14,9 +14,10 @@ from unittest.mock import Mock import pytest +from torch.optim import Optimizer + from lightning.pytorch.plugins import MixedPrecision from lightning.pytorch.utilities import GradClipAlgorithmType -from torch.optim import Optimizer def test_clip_gradients(): diff --git a/tests/tests_pytorch/plugins/precision/test_amp_integration.py b/tests/tests_pytorch/plugins/precision/test_amp_integration.py index bc9f77907919a..f231e3ce91e0b 100644 --- a/tests/tests_pytorch/plugins/precision/test_amp_integration.py +++ b/tests/tests_pytorch/plugins/precision/test_amp_integration.py @@ -14,12 +14,12 @@ from unittest.mock import Mock import torch + from lightning.fabric import seed_everything from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins.precision import MixedPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py index a88e38d6303e4..a478a2b9831a1 100644 --- a/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py +++ b/tests/tests_pytorch/plugins/precision/test_bitsandbytes.py @@ -11,19 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License +import platform import sys from unittest.mock import Mock -import lightning.fabric import pytest import torch import torch.distributed + +import lightning.fabric from lightning.fabric.plugins.precision.bitsandbytes import _BITSANDBYTES_AVAILABLE from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins.precision.bitsandbytes import BitsandbytesPrecision @pytest.mark.skipif(_BITSANDBYTES_AVAILABLE, reason="bitsandbytes needs to be unavailable") +@pytest.mark.skipif(platform.system() == "Darwin", reason="Bitsandbytes is only supported on CUDA GPUs") # skip on Mac def test_bitsandbytes_plugin(monkeypatch): module = lightning.fabric.plugins.precision.bitsandbytes monkeypatch.setattr(module, "_BITSANDBYTES_AVAILABLE", lambda: True) diff --git a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py index 3e1aaa17763e9..da4ce6b89aaab 100644 --- a/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py +++ b/tests/tests_pytorch/plugins/precision/test_deepspeed_precision.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch.plugins.precision.deepspeed import DeepSpeedPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_double.py b/tests/tests_pytorch/plugins/precision/test_double.py index 1ee89752fcbae..74c178c65d05d 100644 --- a/tests/tests_pytorch/plugins/precision/test_double.py +++ b/tests/tests_pytorch/plugins/precision/test_double.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.plugins.precision.double import DoublePrecision -from torch.utils.data import DataLoader, Dataset - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_fsdp.py b/tests/tests_pytorch/plugins/precision/test_fsdp.py index 3ad3af1f1b56b..0389d364dcb79 100644 --- a/tests/tests_pytorch/plugins/precision/test_fsdp.py +++ b/tests/tests_pytorch/plugins/precision/test_fsdp.py @@ -15,9 +15,9 @@ import pytest import torch + from lightning.fabric.plugins.precision.utils import _DtypeContextManager from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/precision/test_half.py b/tests/tests_pytorch/plugins/precision/test_half.py index d51392a00e3d0..9597e01ea428b 100644 --- a/tests/tests_pytorch/plugins/precision/test_half.py +++ b/tests/tests_pytorch/plugins/precision/test_half.py @@ -14,6 +14,7 @@ import pytest import torch + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import HalfPrecision diff --git a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py index 7c92ff47d909a..a9967280e3f23 100644 --- a/tests/tests_pytorch/plugins/precision/test_transformer_engine.py +++ b/tests/tests_pytorch/plugins/precision/test_transformer_engine.py @@ -15,9 +15,10 @@ from contextlib import nullcontext from unittest.mock import ANY, Mock -import lightning.fabric import pytest import torch + +import lightning.fabric from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.plugins import TransformerEnginePrecision from 
lightning.pytorch.trainer.connectors.accelerator_connector import _AcceleratorConnector diff --git a/tests/tests_pytorch/plugins/precision/test_xla.py b/tests/tests_pytorch/plugins/precision/test_xla.py index 97990b6380dab..b456c49e8ff50 100644 --- a/tests/tests_pytorch/plugins/precision/test_xla.py +++ b/tests/tests_pytorch/plugins/precision/test_xla.py @@ -18,8 +18,8 @@ import pytest import torch -from lightning.pytorch.plugins import XLAPrecision +from lightning.pytorch.plugins import XLAPrecision from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_amp_plugins.py b/tests/tests_pytorch/plugins/test_amp_plugins.py index e8adb59c39e51..0b68c098cc713 100644 --- a/tests/tests_pytorch/plugins/test_amp_plugins.py +++ b/tests/tests_pytorch/plugins/test_amp_plugins.py @@ -18,11 +18,11 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import MixedPrecision -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py index 58baa47e7a620..cae26fc1fe775 100644 --- a/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py +++ b/tests/tests_pytorch/plugins/test_checkpoint_io_plugin.py @@ -17,6 +17,7 @@ from unittest.mock import MagicMock, Mock import torch + from lightning.fabric.plugins import CheckpointIO, TorchCheckpointIO from lightning.fabric.utilities.types import _PATH from lightning.pytorch import Trainer diff --git a/tests/tests_pytorch/plugins/test_cluster_integration.py b/tests/tests_pytorch/plugins/test_cluster_integration.py index 026465ac8b17b..08bd1707b5cfd 100644 --- a/tests/tests_pytorch/plugins/test_cluster_integration.py +++ b/tests/tests_pytorch/plugins/test_cluster_integration.py @@ -16,11 +16,11 @@ import pytest import torch + from lightning.fabric.plugins.environments import LightningEnvironment, SLURMEnvironment, TorchElasticEnvironment from lightning.pytorch import Trainer from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy from lightning.pytorch.utilities.rank_zero import rank_zero_only - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/profilers/test_profiler.py b/tests/tests_pytorch/profilers/test_profiler.py index 5b0c13e605ee2..d0221d12e317f 100644 --- a/tests/tests_pytorch/profilers/test_profiler.py +++ b/tests/tests_pytorch/profilers/test_profiler.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch + from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import EarlyStopping, StochasticWeightAveraging @@ -30,7 +31,6 @@ from lightning.pytorch.profilers import AdvancedProfiler, PassThroughProfiler, PyTorchProfiler, SimpleProfiler from lightning.pytorch.profilers.pytorch import _KINETO_AVAILABLE, RegisterRecordFunction, warning_cache from lightning.pytorch.utilities.exceptions import MisconfigurationException - from tests_pytorch.helpers.runif import RunIf PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0005 @@ -55,7 +55,7 @@ def _sleep_generator(durations): yield duration -@pytest.fixture() +@pytest.fixture def simple_profiler(): return SimpleProfiler() @@ -69,7 +69,7 @@ def test_simple_profiler_durations(simple_profiler, action: str, expected: list) time.sleep(duration) # different environments have 
different precision when it comes to time.sleep() - # see: https://github.com/Lightning-AI/lightning/issues/796 + # see: https://github.com/Lightning-AI/pytorch-lightning/issues/796 np.testing.assert_allclose(simple_profiler.recorded_durations[action], expected, rtol=0.2) @@ -264,7 +264,7 @@ def test_simple_profiler_summary(tmp_path, extended): assert expected_text == summary -@pytest.fixture() +@pytest.fixture def advanced_profiler(tmp_path): return AdvancedProfiler(dirpath=tmp_path, filename="profiler") @@ -277,7 +277,7 @@ def test_advanced_profiler_durations(advanced_profiler, action: str, expected: l time.sleep(duration) # different environments have different precision when it comes to time.sleep() - # see: https://github.com/Lightning-AI/lightning/issues/796 + # see: https://github.com/Lightning-AI/pytorch-lightning/issues/796 recorded_total_duration = _get_python_cprofile_total_duration(advanced_profiler.profiled_actions[action]) expected_total_duration = np.sum(expected) np.testing.assert_allclose(recorded_total_duration, expected_total_duration, rtol=0.2) @@ -336,7 +336,7 @@ def test_advanced_profiler_deepcopy(advanced_profiler): assert deepcopy(advanced_profiler) -@pytest.fixture() +@pytest.fixture def pytorch_profiler(tmp_path): return PyTorchProfiler(dirpath=tmp_path, filename="profiler") diff --git a/tests/tests_pytorch/profilers/test_xla_profiler.py b/tests/tests_pytorch/profilers/test_xla_profiler.py index 980a4dac74731..80337382ddcb5 100644 --- a/tests/tests_pytorch/profilers/test_xla_profiler.py +++ b/tests/tests_pytorch/profilers/test_xla_profiler.py @@ -16,10 +16,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.profilers import XLAProfiler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/serve/test_servable_module_validator.py b/tests/tests_pytorch/serve/test_servable_module_validator.py index ec4dd8825c8ea..ba90949132ba2 100644 --- a/tests/tests_pytorch/serve/test_servable_module_validator.py +++ b/tests/tests_pytorch/serve/test_servable_module_validator.py @@ -1,9 +1,10 @@ import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.serve.servable_module_validator import ServableModule, ServableModuleValidator -from torch import Tensor class ServableBoringModel(BoringModel, ServableModule): diff --git a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py index 394d827058987..d26f6c4d2c3ef 100644 --- a/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py +++ b/tests/tests_pytorch/strategies/launchers/test_multiprocessing.py @@ -18,13 +18,13 @@ import pytest import torch + from lightning.fabric.plugins import ClusterEnvironment from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.strategies.launchers.multiprocessing import _GlobalStateSnapshot, _MultiProcessingLauncher from lightning.pytorch.trainer.states import TrainerFn - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py index b8a5ddb29de23..dd8576ec0cafe 100644 --- 
a/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py +++ b/tests/tests_pytorch/strategies/launchers/test_subprocess_script.py @@ -4,9 +4,9 @@ from unittest.mock import Mock import pytest -from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from lightning_utilities.core.imports import RequirementCache +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher from tests_pytorch.helpers.runif import RunIf _HYDRA_WITH_RUN_PROCESS = RequirementCache("hydra-core>=1.0.7") diff --git a/tests/tests_pytorch/strategies/test_common.py b/tests/tests_pytorch/strategies/test_common.py index 699424b3c53b9..6ab4f49374c27 100644 --- a/tests/tests_pytorch/strategies/test_common.py +++ b/tests/tests_pytorch/strategies/test_common.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import SingleDeviceStrategy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_custom_strategy.py b/tests/tests_pytorch/strategies/test_custom_strategy.py index 347dacbd9a811..8a297db217943 100644 --- a/tests/tests_pytorch/strategies/test_custom_strategy.py +++ b/tests/tests_pytorch/strategies/test_custom_strategy.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import SingleDeviceStrategy diff --git a/tests/tests_pytorch/strategies/test_ddp.py b/tests/tests_pytorch/strategies/test_ddp.py index b23d306b9d907..915e57440b40f 100644 --- a/tests/tests_pytorch/strategies/test_ddp.py +++ b/tests/tests_pytorch/strategies/test_ddp.py @@ -17,14 +17,14 @@ import pytest import torch +from torch.nn.parallel import DistributedDataParallel + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.plugins import DoublePrecision, HalfPrecision, Precision from lightning.pytorch.strategies import DDPStrategy from lightning.pytorch.trainer.states import TrainerFn -from torch.nn.parallel import DistributedDataParallel - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_ddp_integration.py b/tests/tests_pytorch/strategies/test_ddp_integration.py index 836072d36be83..048403366ebc7 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration.py @@ -15,9 +15,14 @@ from unittest import mock from unittest.mock import Mock -import lightning.pytorch as pl import pytest import torch +from torch.distributed.optim import ZeroRedundancyOptimizer +from torch.multiprocessing import ProcessRaisedException +from torch.nn.parallel.distributed import DistributedDataParallel + +import lightning.pytorch as pl +import tests_pytorch.helpers.pipelines as tpipes from lightning.fabric.plugins.environments import ClusterEnvironment, LightningEnvironment from lightning.fabric.utilities.distributed import _distributed_is_initialized from lightning.pytorch import Trainer @@ -27,11 +32,6 @@ from lightning.pytorch.strategies.launchers import _SubprocessScriptLauncher 
from lightning.pytorch.strategies.launchers.multiprocessing import _MultiProcessingLauncher from lightning.pytorch.trainer import seed_everything -from torch.distributed.optim import ZeroRedundancyOptimizer -from torch.multiprocessing import ProcessRaisedException -from torch.nn.parallel.distributed import DistributedDataParallel - -import tests_pytorch.helpers.pipelines as tpipes from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.simple_models import ClassificationModel diff --git a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py index 3723ee2fc5a8f..c4d40e1dfa7d2 100644 --- a/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py +++ b/tests/tests_pytorch/strategies/test_ddp_integration_comm_hook.py @@ -15,10 +15,10 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import DDPStrategy - from tests_pytorch.helpers.runif import RunIf if torch.distributed.is_available(): diff --git a/tests/tests_pytorch/strategies/test_deepspeed.py b/tests/tests_pytorch/strategies/test_deepspeed.py index 73697ea131545..7e7d2eacd0617 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed.py +++ b/tests/tests_pytorch/strategies/test_deepspeed.py @@ -22,6 +22,10 @@ import pytest import torch import torch.nn.functional as F +from torch import Tensor, nn +from torch.utils.data import DataLoader +from torchmetrics import Accuracy + from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.accelerators import CUDAAccelerator from lightning.pytorch.callbacks import Callback, LearningRateMonitor, ModelCheckpoint @@ -31,10 +35,6 @@ from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from torch import Tensor, nn -from torch.utils.data import DataLoader -from torchmetrics import Accuracy - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.runif import RunIf @@ -81,7 +81,7 @@ def automatic_optimization(self) -> bool: return False -@pytest.fixture() +@pytest.fixture def deepspeed_config(): return { "optimizer": {"type": "SGD", "params": {"lr": 3e-5}}, @@ -92,7 +92,7 @@ def deepspeed_config(): } -@pytest.fixture() +@pytest.fixture def deepspeed_zero_config(deepspeed_config): return {**deepspeed_config, "zero_allow_untested_optimizer": True, "zero_optimization": {"stage": 2}} diff --git a/tests/tests_pytorch/strategies/test_fsdp.py b/tests/tests_pytorch/strategies/test_fsdp.py index 2aee68f7ae733..f3e88ca356764 100644 --- a/tests/tests_pytorch/strategies/test_fsdp.py +++ b/tests/tests_pytorch/strategies/test_fsdp.py @@ -12,6 +12,10 @@ import pytest import torch import torch.nn as nn +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision +from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap +from torchmetrics import Accuracy + from lightning.fabric.plugins.environments import LightningEnvironment from lightning.fabric.strategies.fsdp import _is_sharded_checkpoint from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 @@ -24,10 +28,6 @@ from 
lightning.pytorch.strategies import FSDPStrategy from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.consolidate_checkpoint import _format_checkpoint -from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, FullyShardedDataParallel, MixedPrecision -from torch.distributed.fsdp.wrap import ModuleWrapPolicy, always_wrap_policy, size_based_auto_wrap_policy, wrap -from torchmetrics import Accuracy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 731da66d4a61f..86a95944ac20d 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -20,11 +20,11 @@ import pytest import torch import torch.nn as nn + from lightning.fabric.strategies.model_parallel import _is_sharded_checkpoint from lightning.pytorch import LightningModule from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import ModelParallelStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 9dcbcc802834b..00600183f4293 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -18,12 +18,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from lightning.pytorch import LightningModule, Trainer, seed_everything -from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset -from lightning.pytorch.strategies import ModelParallelStrategy from torch.utils.data import DataLoader, DistributedSampler from torchmetrics.classification import Accuracy +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.strategies import ModelParallelStrategy from tests_pytorch.helpers.runif import RunIf @@ -86,7 +86,7 @@ def fn(model, device_mesh): return fn -@pytest.fixture() +@pytest.fixture def distributed(): yield if torch.distributed.is_initialized(): diff --git a/tests/tests_pytorch/strategies/test_registry.py b/tests/tests_pytorch/strategies/test_registry.py index 90e15638bfd06..d2c580fd28c0f 100644 --- a/tests/tests_pytorch/strategies/test_registry.py +++ b/tests/tests_pytorch/strategies/test_registry.py @@ -14,10 +14,10 @@ from unittest import mock import pytest + from lightning.pytorch import Trainer from lightning.pytorch.plugins import CheckpointIO from lightning.pytorch.strategies import DDPStrategy, DeepSpeedStrategy, FSDPStrategy, StrategyRegistry, XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_single_device.py b/tests/tests_pytorch/strategies/test_single_device.py index 5b10c4a17d726..7582dfe86dd3c 100644 --- a/tests/tests_pytorch/strategies/test_single_device.py +++ b/tests/tests_pytorch/strategies/test_single_device.py @@ -16,12 +16,12 @@ import pytest import torch +from torch.utils.data import DataLoader + from lightning.pytorch import Trainer from lightning.pytorch.core.optimizer import LightningOptimizer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.strategies import SingleDeviceStrategy -from torch.utils.data import 
DataLoader - from tests_pytorch.helpers.dataloaders import CustomNotImplementedErrorDataloader from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/strategies/test_xla.py b/tests/tests_pytorch/strategies/test_xla.py index b4f0e2c37ec05..3fde2600c9483 100644 --- a/tests/tests_pytorch/strategies/test_xla.py +++ b/tests/tests_pytorch/strategies/test_xla.py @@ -16,11 +16,11 @@ from unittest.mock import Mock import torch + from lightning.pytorch import Trainer from lightning.pytorch.accelerators import XLAAccelerator from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.strategies import XLAStrategy - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index de89d094cdfcf..5c33a8539b693 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -27,6 +27,14 @@ import pytest import torch import yaml +from lightning_utilities import compare_version +from lightning_utilities.test.warning import no_warning_call +from packaging.version import Version +from tensorboard.backend.event_processing import event_accumulator +from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData +from torch.optim import SGD +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint @@ -46,14 +54,6 @@ from lightning.pytorch.trainer.states import TrainerFn from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHVISION_AVAILABLE -from lightning_utilities import compare_version -from lightning_utilities.test.warning import no_warning_call -from packaging.version import Version -from tensorboard.backend.event_processing import event_accumulator -from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData -from torch.optim import SGD -from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR - from tests_pytorch.helpers.runif import RunIf if _JSONARGPARSE_SIGNATURES_AVAILABLE: @@ -84,7 +84,7 @@ def mock_subclasses(baseclass, *subclasses): yield None -@pytest.fixture() +@pytest.fixture def cleandir(tmp_path, monkeypatch): monkeypatch.chdir(tmp_path) return @@ -666,7 +666,7 @@ class MyLightningCLI(LightningCLI): def add_arguments_to_parser(self, parser): parser.add_optimizer_args(torch.optim.Adam) - match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`" + match = "BoringModel.configure_optimizers` will be overridden by `MyLightningCLI.configure_optimizers`" argv = ["fit", "--trainer.fast_dev_run=1"] if run else [] with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match): cli = MyLightningCLI(BoringModel, run=run) @@ -1647,8 +1647,13 @@ def _test_logger_init_args(logger_name, init, unresolved=None): def test_comet_logger_init_args(): _test_logger_init_args( "CometLogger", - init={"save_dir": "comet"}, # Resolve from CometLogger.__init__ - unresolved={"workspace": "comet"}, # Resolve from Comet{,Existing,Offline}Experiment.__init__ + init={ + "experiment_key": "some_key", # Resolve from CometLogger.__init__ + "workspace": "comet", + }, + unresolved={ + "save_dir": "comet", # Resolve from CometLogger.__init__ as kwarg + }, ) 
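Note: the import reshuffling repeated across these test modules follows a standard-library / third-party / first-party grouping, with lightning and tests_pytorch both treated as first-party packages (an assumption inferred from the hunks above, e.g. an isort/ruff known-first-party setting). A minimal sketch of the resulting layout in a touched test file, with illustrative stdlib and third-party names:

# standard library
import os
from unittest import mock

# third-party
import pytest
import torch
from torch.utils.data import DataLoader

# first-party (lightning and the test package itself)
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from tests_pytorch.helpers.runif import RunIf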
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py index 9e947e0723dcd..b8517a0303015 100644 --- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py @@ -19,11 +19,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric -import lightning.pytorch import pytest import torch import torch.distributed + +import lightning.fabric +import lightning.pytorch from lightning.fabric.plugins.environments import ( KubeflowEnvironment, LightningEnvironment, @@ -61,7 +62,6 @@ from lightning.pytorch.utilities.imports import ( _LIGHTNING_HABANA_AVAILABLE, ) - from tests_pytorch.conftest import mock_cuda_count, mock_mps_count, mock_tpu_available, mock_xla_available from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py index eb09413cafcce..94b5fcba652be 100644 --- a/tests/tests_pytorch/trainer/connectors/test_callback_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_callback_connector.py @@ -18,6 +18,7 @@ import pytest import torch + from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0 from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch.callbacks import ( diff --git a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py index d29e2285e983c..722742a3ccae0 100644 --- a/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_checkpoint_connector.py @@ -17,6 +17,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import Callback, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/connectors/test_data_connector.py b/tests/tests_pytorch/trainer/connectors/test_data_connector.py index ca5690ed20f41..ceb0418f2cb1d 100644 --- a/tests/tests_pytorch/trainer/connectors/test_data_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_data_connector.py @@ -16,8 +16,12 @@ from unittest import mock from unittest.mock import Mock -import lightning.fabric import pytest +from lightning_utilities.test.warning import no_warning_call +from torch import Tensor +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler + +import lightning.fabric from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer @@ -33,10 +37,6 @@ from lightning.pytorch.utilities.combined_loader import CombinedLoader from lightning.pytorch.utilities.data import _update_dataloader from lightning.pytorch.utilities.exceptions import MisconfigurationException -from lightning_utilities.test.warning import no_warning_call -from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py index 8825db3727e86..83c5c2bb7e02b 100644 --- 
a/tests/tests_pytorch/trainer/connectors/test_signal_connector.py +++ b/tests/tests_pytorch/trainer/connectors/test_signal_connector.py @@ -18,13 +18,13 @@ from unittest.mock import Mock import pytest + from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector from lightning.pytorch.utilities.exceptions import SIGTERMException - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py index f7b7d3c8ac8a7..89c842f2bbf0f 100644 --- a/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py +++ b/tests/tests_pytorch/trainer/dynamic_args/test_multiple_eval_dataloaders.py @@ -13,9 +13,10 @@ # limitations under the License. import pytest import torch +from torch.utils.data import Dataset + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from torch.utils.data import Dataset class RandomDatasetA(Dataset): diff --git a/tests/tests_pytorch/trainer/flags/test_barebones.py b/tests/tests_pytorch/trainer/flags/test_barebones.py index 329fcf915d751..875aaef40a123 100644 --- a/tests/tests_pytorch/trainer/flags/test_barebones.py +++ b/tests/tests_pytorch/trainer/flags/test_barebones.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelSummary from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py index fba938d82e761..301721e4b28f5 100644 --- a/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py +++ b/tests/tests_pytorch/trainer/flags/test_check_val_every_n_epoch.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.trainer import Trainer -from torch.utils.data import DataLoader @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py index 62087619ae42f..63b05a0a131f2 100644 --- a/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py +++ b/tests/tests_pytorch/trainer/flags/test_fast_dev_run.py @@ -3,6 +3,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel diff --git a/tests/tests_pytorch/trainer/flags/test_inference_mode.py b/tests/tests_pytorch/trainer/flags/test_inference_mode.py index bae7b66dbbd55..802d9bf30c59b 100644 --- a/tests/tests_pytorch/trainer/flags/test_inference_mode.py +++ b/tests/tests_pytorch/trainer/flags/test_inference_mode.py @@ -16,6 +16,7 @@ import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loops import _Loop diff --git a/tests/tests_pytorch/trainer/flags/test_limit_batches.py b/tests/tests_pytorch/trainer/flags/test_limit_batches.py index e190a0b380377..ef405a8ee95b1 100644 --- a/tests/tests_pytorch/trainer/flags/test_limit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_limit_batches.py @@ -14,6 +14,7 @@ import logging import pytest + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.trainer.states import TrainerFn diff --git a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py index 25aaeb8cff77e..3315c328b6249 100644 --- a/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py +++ b/tests/tests_pytorch/trainer/flags/test_min_max_epochs.py @@ -1,8 +1,9 @@ import pytest +from lightning_utilities.test.warning import no_warning_call + from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel -from lightning_utilities.test.warning import no_warning_call @pytest.mark.parametrize( diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index dfe8a63a8bfc4..050818287ba45 100644 --- a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py +++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py @@ -16,11 +16,11 @@ import pytest import torch +from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.trainer.states import RunningStage -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, Sampler, SequentialSampler - from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import SklearnDataset from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py index b776263e9953d..b6cc446cb0840 100644 --- a/tests/tests_pytorch/trainer/flags/test_val_check_interval.py +++ 
b/tests/tests_pytorch/trainer/flags/test_val_check_interval.py @@ -14,10 +14,11 @@ import logging import pytest +from torch.utils.data import DataLoader + from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset from lightning.pytorch.trainer.trainer import Trainer from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch.utils.data import DataLoader @pytest.mark.parametrize("max_epochs", [1, 2, 3]) diff --git a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py index af7cecdb21a08..90f9a3e697535 100644 --- a/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_distributed_logging.py @@ -19,7 +19,6 @@ from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers.logger import Logger - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py index 40c82bec2fd10..be6de37ddff3a 100644 --- a/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py +++ b/tests/tests_pytorch/trainer/logging_/test_eval_loop_logging.py @@ -24,6 +24,8 @@ import numpy as np import pytest import torch +from torch import Tensor + from lightning.pytorch import Trainer, callbacks from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,8 +33,6 @@ from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.utilities.exceptions import MisconfigurationException -from torch import Tensor - from tests_pytorch.helpers.runif import RunIf if _RICH_AVAILABLE: diff --git a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py index e96857a6c192d..faf88a09f6499 100644 --- a/tests/tests_pytorch/trainer/logging_/test_logger_connector.py +++ b/tests/tests_pytorch/trainer/logging_/test_logger_connector.py @@ -18,6 +18,11 @@ import pytest import torch +from lightning_utilities.core.imports import compare_version +from torch.utils.data import DataLoader +from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection +from torchmetrics import AveragePrecision as AvgPre + from lightning.pytorch import LightningModule from lightning.pytorch.callbacks.callback import Callback from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -28,11 +33,6 @@ from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11 -from lightning_utilities.core.imports import compare_version -from torch.utils.data import DataLoader -from torchmetrics import Accuracy, MeanAbsoluteError, MeanSquaredError, MetricCollection -from torchmetrics import AveragePrecision as AvgPre - from tests_pytorch.helpers.runif import RunIf from tests_pytorch.models.test_hooks import get_members diff --git a/tests/tests_pytorch/trainer/logging_/test_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py index e43f3b8d6bffd..6bc9dd1de0587 100644 --- 
+++ b/tests/tests_pytorch/trainer/logging_/test_loop_logging.py
@@ -18,6 +18,7 @@
 from unittest.mock import ANY
 import torch
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator
diff --git a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
index e48f80d2d1680..be99489cfdf89 100644
--- a/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
+++ b/tests/tests_pytorch/trainer/logging_/test_train_loop_logging.py
@@ -22,6 +22,11 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
+from torch.utils.data import DataLoader
+from torchmetrics import Accuracy
+
 from lightning.pytorch import Trainer, callbacks
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint, TQDMProgressBar
 from lightning.pytorch.core.module import LightningModule
@@ -30,11 +35,6 @@
 from lightning.pytorch.trainer.states import RunningStage
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_11 as _TM_GE_0_11
-from lightning_utilities.test.warning import no_warning_call
-from torch import Tensor
-from torch.utils.data import DataLoader
-from torchmetrics import Accuracy
-
 from tests_pytorch.helpers.runif import RunIf
@@ -563,7 +563,7 @@ def training_step(self, *args):
 def test_log_tensor_and_clone_no_torch_warning(tmp_path):
-    """Regression test for issue https://github.com/Lightning-AI/lightning/issues/14594."""
+    """Regression test for issue https://github.com/Lightning-AI/pytorch-lightning/issues/14594."""
 class TestModel(BoringModel):
     def training_step(self, *args):
diff --git a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
index b91dbff8c6d09..42332bc05580f 100644
--- a/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
+++ b/tests/tests_pytorch/trainer/optimization/test_backward_calls.py
@@ -2,6 +2,7 @@
 import pytest
 import torch
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
diff --git a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
index f0ab8fe401633..3f89e1459298d 100644
--- a/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
+++ b/tests/tests_pytorch/trainer/optimization/test_manual_optimization.py
@@ -21,11 +21,11 @@
 import torch
 import torch.distributed as torch_distrib
 import torch.nn.functional as F
+
 from lightning.fabric.utilities.exceptions import MisconfigurationException
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.demos.boring_classes import BoringModel, ManualOptimBoringModel
 from lightning.pytorch.strategies import Strategy
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
index 319eafeb0d0bb..dcbae32827c50 100644
--- a/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_multiple_optimizers.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 """Tests to ensure that the behaviours related to multiple optimizers works."""
-import lightning.pytorch as pl
 import pytest
 import torch
+
+import lightning.pytorch as pl
 from lightning.pytorch.demos.boring_classes import BoringModel
diff --git a/tests/tests_pytorch/trainer/optimization/test_optimizers.py b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
index ac660b6651be5..6b88534f3430d 100644
--- a/tests/tests_pytorch/trainer/optimization/test_optimizers.py
+++ b/tests/tests_pytorch/trainer/optimization/test_optimizers.py
@@ -16,6 +16,8 @@
 import pytest
 import torch
+from torch import optim
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.core.optimizer import (
@@ -26,8 +28,6 @@
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import LRSchedulerConfig
-from torch import optim
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
index 76c0c695b3c02..aa60db594447d 100644
--- a/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
+++ b/tests/tests_pytorch/trainer/properties/test_estimated_stepping_batches.py
@@ -18,11 +18,11 @@
 import pytest
 import torch
+from torch.utils.data import DataLoader
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomIterableDataset
 from lightning.pytorch.strategies import SingleDeviceXLAStrategy
-from torch.utils.data import DataLoader
-
 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/properties/test_get_model.py b/tests/tests_pytorch/trainer/properties/test_get_model.py
index 72967ce929eea..47d2bffcfca22 100644
--- a/tests/tests_pytorch/trainer/properties/test_get_model.py
+++ b/tests/tests_pytorch/trainer/properties/test_get_model.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 import pytest
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/properties/test_log_dir.py b/tests/tests_pytorch/trainer/properties/test_log_dir.py
index 1fc4f3454f9d0..0f045c2e815fd 100644
--- a/tests/tests_pytorch/trainer/properties/test_log_dir.py
+++ b/tests/tests_pytorch/trainer/properties/test_log_dir.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import pytest
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
diff --git a/tests/tests_pytorch/trainer/properties/test_loggers.py b/tests/tests_pytorch/trainer/properties/test_loggers.py
index 7fcf663f59122..1d07f3e99d412 100644
--- a/tests/tests_pytorch/trainer/properties/test_loggers.py
+++ b/tests/tests_pytorch/trainer/properties/test_loggers.py
@@ -13,9 +13,9 @@
 # limitations under the License.
 import pytest
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.loggers import TensorBoardLogger
-
 from tests_pytorch.loggers.test_logger import CustomLogger
diff --git a/tests/tests_pytorch/trainer/test_config_validator.py b/tests/tests_pytorch/trainer/test_config_validator.py
index 8963b1f76186d..cfca98e04c8c8 100644
--- a/tests/tests_pytorch/trainer/test_config_validator.py
+++ b/tests/tests_pytorch/trainer/test_config_validator.py
@@ -15,6 +15,7 @@
 import pytest
 import torch
+
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
diff --git a/tests/tests_pytorch/trainer/test_dataloaders.py b/tests/tests_pytorch/trainer/test_dataloaders.py
index a2d29baa9fa6f..7fbe55030770e 100644
--- a/tests/tests_pytorch/trainer/test_dataloaders.py
+++ b/tests/tests_pytorch/trainer/test_dataloaders.py
@@ -14,10 +14,17 @@
 import os
 from unittest.mock import Mock, call, patch
-import lightning.pytorch
 import numpy
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
+from torch.utils.data import RandomSampler
+from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data.distributed import DistributedSampler
+from torch.utils.data.sampler import SequentialSampler
+
+import lightning.pytorch
 from lightning.fabric.utilities.data import _auto_add_worker_init_fn, has_iterable_dataset
 from lightning.pytorch import Callback, Trainer, seed_everything
 from lightning.pytorch.callbacks import ModelCheckpoint
@@ -33,13 +40,6 @@
 from lightning.pytorch.utilities.combined_loader import CombinedLoader
 from lightning.pytorch.utilities.data import has_len_all_ranks
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning_utilities.test.warning import no_warning_call
-from torch.utils.data import RandomSampler
-from torch.utils.data.dataloader import DataLoader
-from torch.utils.data.dataset import Dataset, IterableDataset
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import SequentialSampler
-
 from tests_pytorch.helpers.dataloaders import CustomInfDataloader, CustomNotImplementedErrorDataloader
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/trainer/test_states.py b/tests/tests_pytorch/trainer/test_states.py
index d89e99c9319c6..fff2d0b464d42 100644
--- a/tests/tests_pytorch/trainer/test_states.py
+++ b/tests/tests_pytorch/trainer/test_states.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+
 from lightning.pytorch import Callback, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
diff --git a/tests/tests_pytorch/trainer/test_trainer.py b/tests/tests_pytorch/trainer/test_trainer.py
index d66f3aafee5df..18ae7ce77bdfc 100644
--- a/tests/tests_pytorch/trainer/test_trainer.py
+++ b/tests/tests_pytorch/trainer/test_trainer.py
@@ -27,6 +27,12 @@
 import pytest
 import torch
 import torch.nn as nn
+from torch.multiprocessing import ProcessRaisedException
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.optim import SGD
+from torch.utils.data import DataLoader, IterableDataset
+
+import tests_pytorch.helpers.utils as tutils
 from lightning.fabric.utilities.cloud_io import _load as pl_load
 from lightning.fabric.utilities.imports import _IS_WINDOWS
 from lightning.fabric.utilities.seed import seed_everything
@@ -50,12 +56,6 @@
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
-from torch.multiprocessing import ProcessRaisedException
-from torch.nn.parallel.distributed import DistributedDataParallel
-from torch.optim import SGD
-from torch.utils.data import DataLoader, IterableDataset
-
-import tests_pytorch.helpers.utils as tutils
 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
@@ -335,9 +335,9 @@ def mock_save_function(filepath, *args):
     file_lists = set(os.listdir(tmp_path))
-    assert len(file_lists) == len(
-        expected_files
-    ), f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
+    assert len(file_lists) == len(expected_files), (
+        f"Should save {len(expected_files)} models when save_top_k={save_top_k} but found={file_lists}"
+    )
     # verify correct naming
     for fname in expected_files:
diff --git a/tests/tests_pytorch/tuner/test_lr_finder.py b/tests/tests_pytorch/tuner/test_lr_finder.py
index a31be67911409..ec894688ccb6c 100644
--- a/tests/tests_pytorch/tuner/test_lr_finder.py
+++ b/tests/tests_pytorch/tuner/test_lr_finder.py
@@ -20,6 +20,8 @@
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
+
 from lightning.pytorch import Trainer, seed_everything
 from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -27,8 +29,6 @@
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.types import STEP_OUTPUT
-from lightning_utilities.test.warning import no_warning_call
-
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf
 from tests_pytorch.helpers.simple_models import ClassificationModel
@@ -73,9 +73,9 @@ def test_model_reset_correctly(tmp_path):
     after_state_dict = model.state_dict()
     for key in before_state_dict:
-        assert torch.all(
-            torch.eq(before_state_dict[key], after_state_dict[key])
-        ), "Model was not reset correctly after learning rate finder"
+        assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), (
+            "Model was not reset correctly after learning rate finder"
+        )
     assert not any(f for f in os.listdir(tmp_path) if f.startswith(".lr_find"))
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 8dd66fe9bfcff..e4ed533c6fa83 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -18,14 +18,14 @@
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
+from torch.utils.data import DataLoader
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from lightning.pytorch.tuner.tuning import Tuner
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning_utilities.test.warning import no_warning_call
-from torch.utils.data import DataLoader
-
 from tests_pytorch.helpers.runif import RunIf
@@ -114,9 +114,9 @@ def test_trainer_reset_correctly(tmp_path, trainer_fn):
     after_state_dict = model.state_dict()
     for key in before_state_dict:
-        assert torch.all(
-            torch.eq(before_state_dict[key], after_state_dict[key])
-        ), "Model was not reset correctly after scaling batch size"
+        assert torch.all(torch.eq(before_state_dict[key], after_state_dict[key])), (
+            "Model was not reset correctly after scaling batch size"
+        )
     assert not any(f for f in os.listdir(tmp_path) if f.startswith(".scale_batch_size_temp_model"))
diff --git a/tests/tests_pytorch/tuner/test_tuning.py b/tests/tests_pytorch/tuner/test_tuning.py
index dda08354575c0..e3b24a69b7999 100644
--- a/tests/tests_pytorch/tuner/test_tuning.py
+++ b/tests/tests_pytorch/tuner/test_tuning.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import pytest
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import BatchSizeFinder, LearningRateFinder
 from lightning.pytorch.demos.boring_classes import BoringModel
diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py
index 9680c90a94c5b..f9c921f6f1bfd 100644
--- a/tests/tests_pytorch/utilities/migration/test_migration.py
+++ b/tests/tests_pytorch/utilities/migration/test_migration.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 from unittest.mock import ANY, MagicMock
-import lightning.pytorch as pl
 import pytest
 import torch
+
+import lightning.pytorch as pl
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
diff --git a/tests/tests_pytorch/utilities/migration/test_utils.py b/tests/tests_pytorch/utilities/migration/test_utils.py
index 56b8701cfcfc2..41ca1a779f8a5 100644
--- a/tests/tests_pytorch/utilities/migration/test_utils.py
+++ b/tests/tests_pytorch/utilities/migration/test_utils.py
@@ -18,16 +18,16 @@
 import sys
 from unittest.mock import ANY
-import lightning.pytorch as pl
 import pytest
 import torch
-from lightning.fabric.utilities.warnings import PossibleUserWarning
-from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch
-from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler
 from lightning_utilities.core.imports import module_available
 from lightning_utilities.test.warning import no_warning_call
 from packaging.version import Version
+import lightning.pytorch as pl
+from lightning.fabric.utilities.warnings import PossibleUserWarning
+from lightning.pytorch.utilities.migration import migrate_checkpoint, pl_legacy_patch
+from lightning.pytorch.utilities.migration.utils import _pl_migrate_checkpoint, _RedirectingUnpickler
 from tests_pytorch.checkpointing.test_legacy_checkpoints import (
     CHECKPOINT_EXTENSION,
     LEGACY_BACK_COMPATIBLE_PL_VERSIONS,
@@ -75,9 +75,9 @@ def _list_sys_modules(pattern: str) -> str:
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
 @pytest.mark.skipif(module_available("lightning"), reason="This test is ONLY relevant for the STANDALONE package")
 def test_test_patch_legacy_imports_standalone(pl_version):
-    assert any(
-        key.startswith("pytorch_lightning") for key in sys.modules
-    ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}"
+    assert any(key.startswith("pytorch_lightning") for key in sys.modules), (
+        f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}"
+    )
     path_legacy = os.path.join(LEGACY_CHECKPOINTS_PATH, pl_version)
     path_ckpts = sorted(glob.glob(os.path.join(path_legacy, f"*{CHECKPOINT_EXTENSION}")))
     assert path_ckpts, f'No checkpoints found in folder "{path_legacy}"'
@@ -86,9 +86,9 @@ def test_test_patch_legacy_imports_standalone(pl_version):
     with no_warning_call(match="Redirecting import of*"), pl_legacy_patch():
         torch.load(path_ckpt, weights_only=False)
-    assert any(
-        key.startswith("pytorch_lightning") for key in sys.modules
-    ), f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}"
+    assert any(key.startswith("pytorch_lightning") for key in sys.modules), (
+        f"Imported PL, so it has to be in sys.modules: {_list_sys_modules('pytorch_lightning')}"
+    )
     assert not any(key.startswith("lightning." + "pytorch") for key in sys.modules), (
         "Did not import the unified package,"
         f" so it should not be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}"
@@ -98,9 +98,9 @@ def test_test_patch_legacy_imports_standalone(pl_version):
 @pytest.mark.parametrize("pl_version", LEGACY_BACK_COMPATIBLE_PL_VERSIONS)
 @pytest.mark.skipif(not module_available("lightning"), reason="This test is ONLY relevant for the UNIFIED package")
 def test_patch_legacy_imports_unified(pl_version):
-    assert any(
-        key.startswith("lightning." + "pytorch") for key in sys.modules
+ "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" @@ -119,9 +119,9 @@ def test_patch_legacy_imports_unified(pl_version): with context, pl_legacy_patch(): torch.load(path_ckpt, weights_only=False) - assert any( - key.startswith("lightning." + "pytorch") for key in sys.modules - ), f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + assert any(key.startswith("lightning." + "pytorch") for key in sys.modules), ( + f"Imported unified package, so it has to be in sys.modules: {_list_sys_modules('lightning' + '.pytorch')}" + ) assert not any(key.startswith("pytorch_lightning") for key in sys.modules), ( "Should not import standalone package, all imports should be redirected to the unified package;\n" f" environment: {_list_sys_modules('pytorch_lightning')}" diff --git a/tests/tests_pytorch/utilities/test_all_gather_grad.py b/tests/tests_pytorch/utilities/test_all_gather_grad.py index 9b034cdcd34e2..82ca15fd87432 100644 --- a/tests/tests_pytorch/utilities/test_all_gather_grad.py +++ b/tests/tests_pytorch/utilities/test_all_gather_grad.py @@ -15,9 +15,9 @@ import numpy as np import pytest import torch + from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel - from tests_pytorch.core.test_results import spawn_launch from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index d4b2fdaf82834..4da7bfd098a0b 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -14,13 +14,13 @@ import inspect import pytest +from torch.utils.data.dataloader import DataLoader + from lightning.fabric.utilities.seed import seed_everything from lightning.pytorch import Callback, Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.utilities.exceptions import SIGTERMException -from torch.utils.data.dataloader import DataLoader - from tests_pytorch.helpers.runif import RunIf diff --git a/tests/tests_pytorch/utilities/test_combined_loader.py b/tests/tests_pytorch/utilities/test_combined_loader.py index 43a146c6eb089..da168be1e3e8a 100644 --- a/tests/tests_pytorch/utilities/test_combined_loader.py +++ b/tests/tests_pytorch/utilities/test_combined_loader.py @@ -19,6 +19,13 @@ import pytest import torch +from torch import Tensor +from torch.utils._pytree import tree_flatten +from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import RandomSampler, SequentialSampler + from lightning.fabric.utilities.types import _Stateful from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset @@ -31,13 +38,6 @@ _MinSize, 
     _Sequential,
 )
-from torch import Tensor
-from torch.utils._pytree import tree_flatten
-from torch.utils.data import DataLoader, TensorDataset
-from torch.utils.data.dataset import Dataset, IterableDataset
-from torch.utils.data.distributed import DistributedSampler
-from torch.utils.data.sampler import RandomSampler, SequentialSampler
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py
index 67f992421f7ce..a053c847dfd6c 100644
--- a/tests/tests_pytorch/utilities/test_compile.py
+++ b/tests/tests_pytorch/utilities/test_compile.py
@@ -18,12 +18,12 @@
 import pytest
 import torch
+from lightning_utilities.core.imports import RequirementCache
+
 from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
-from lightning_utilities.core.imports import RequirementCache
-
 from tests_pytorch.conftest import mock_cuda_count
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py
index d79e9e24383a0..65a3a47715bfe 100644
--- a/tests/tests_pytorch/utilities/test_data.py
+++ b/tests/tests_pytorch/utilities/test_data.py
@@ -4,6 +4,10 @@
 import numpy as np
 import pytest
 import torch
+from lightning_utilities.test.warning import no_warning_call
+from torch import Tensor
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
+
 from lightning.fabric.utilities.data import _replace_dunder_methods
 from lightning.fabric.utilities.warnings import PossibleUserWarning
 from lightning.pytorch import Trainer
@@ -19,9 +23,6 @@
     warning_cache,
 )
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning_utilities.test.warning import no_warning_call
-from torch import Tensor
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 def test_extract_batch_size():
diff --git a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
index a44ed655c2e61..05657854eb0b1 100644
--- a/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_deepspeed_collate_checkpoint.py
@@ -14,11 +14,11 @@
 import os
 import torch
+
 from lightning.pytorch import Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.strategies import DeepSpeedStrategy
 from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
index c8a138ba0a02a..256233e01fa98 100644
--- a/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_deepspeed_model_summary.py
@@ -17,7 +17,6 @@
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.strategies import DeepSpeedStrategy
 from lightning.pytorch.utilities.model_summary import DeepSpeedSummary
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
index 1ba3aff359609..171656d072076 100644
--- a/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
+++ b/tests/tests_pytorch/utilities/test_dtype_device_mixin.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch.nn as nn
+
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
 from lightning.pytorch import Callback, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
-
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_grads.py b/tests/tests_pytorch/utilities/test_grads.py
index ade88f767670c..277971d64b180 100644
--- a/tests/tests_pytorch/utilities/test_grads.py
+++ b/tests/tests_pytorch/utilities/test_grads.py
@@ -16,6 +16,7 @@
 import pytest
 import torch
 import torch.nn as nn
+
 from lightning.pytorch.utilities import grad_norm
diff --git a/tests/tests_pytorch/utilities/test_imports.py b/tests/tests_pytorch/utilities/test_imports.py
index 56ee326f076dc..301c97d756899 100644
--- a/tests/tests_pytorch/utilities/test_imports.py
+++ b/tests/tests_pytorch/utilities/test_imports.py
@@ -19,10 +19,10 @@
 from unittest import mock
 import pytest
-from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from lightning_utilities.core.imports import RequirementCache
 from torch.distributed import is_available
+from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 from tests_pytorch.helpers.runif import RunIf
@@ -60,7 +60,7 @@ def new_fn(*args, **kwargs):
     return new_fn
-@pytest.fixture()
+@pytest.fixture
 def clean_import():
     """This fixture allows test to import {pytorch_}lightning* modules completely cleanly, regardless of the current state of the imported modules.
diff --git a/tests/tests_pytorch/utilities/test_memory.py b/tests/tests_pytorch/utilities/test_memory.py
index c1ebff03a3a4f..336a9fafa3243 100644
--- a/tests/tests_pytorch/utilities/test_memory.py
+++ b/tests/tests_pytorch/utilities/test_memory.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import torch
+
 from lightning.pytorch.utilities.memory import recursive_detach
diff --git a/tests/tests_pytorch/utilities/test_model_helpers.py b/tests/tests_pytorch/utilities/test_model_helpers.py
index 78a63a7e9d2a7..e7a9d9275a484 100644
--- a/tests/tests_pytorch/utilities/test_model_helpers.py
+++ b/tests/tests_pytorch/utilities/test_model_helpers.py
@@ -16,10 +16,11 @@
 import pytest
 import torch.nn
+from lightning_utilities import module_available
+
 from lightning.pytorch import LightningDataModule
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
 from lightning.pytorch.utilities.model_helpers import _ModuleMode, _restricted_classmethod, is_overridden
-from lightning_utilities import module_available
 def test_is_overridden():
diff --git a/tests/tests_pytorch/utilities/test_model_summary.py b/tests/tests_pytorch/utilities/test_model_summary.py
index cced6546aab75..54c5572d01767 100644
--- a/tests/tests_pytorch/utilities/test_model_summary.py
+++ b/tests/tests_pytorch/utilities/test_model_summary.py
@@ -18,6 +18,7 @@
 import pytest
 import torch
 import torch.nn as nn
+
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.model_summary.model_summary import (
@@ -27,7 +28,6 @@
     ModelSummary,
     summarize,
 )
-
 from tests_pytorch.helpers.advanced_models import ParityModuleRNN
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_parameter_tying.py b/tests/tests_pytorch/utilities/test_parameter_tying.py
index 9dc9b5648ff01..e45fb39f81b34 100644
--- a/tests/tests_pytorch/utilities/test_parameter_tying.py
+++ b/tests/tests_pytorch/utilities/test_parameter_tying.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 import pytest
 import torch
+from torch import nn
+
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities import find_shared_parameters, set_shared_parameters
-from torch import nn
 class ParameterSharingModule(BoringModel):
diff --git a/tests/tests_pytorch/utilities/test_parsing.py b/tests/tests_pytorch/utilities/test_parsing.py
index 1c126723d89a6..a2671eb8790de 100644
--- a/tests/tests_pytorch/utilities/test_parsing.py
+++ b/tests/tests_pytorch/utilities/test_parsing.py
@@ -15,6 +15,8 @@
 import threading
 import pytest
+from torch.jit import ScriptModule
+
 from lightning.pytorch import LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.utilities.parsing import (
     _get_init_args,
@@ -26,7 +28,6 @@
     lightning_setattr,
     parse_class_init_keys,
 )
-from torch.jit import ScriptModule
 unpicklable_function = lambda: None
@@ -103,12 +104,12 @@ def test_lightning_hasattr():
     assert lightning_hasattr(model3, "learning_rate"), "lightning_hasattr failed to find hparams dict variable"
     assert not lightning_hasattr(model4, "learning_rate"), "lightning_hasattr found variable when it should not"
     assert lightning_hasattr(model5, "batch_size"), "lightning_hasattr failed to find batch_size in datamodule"
-    assert lightning_hasattr(
-        model6, "batch_size"
-    ), "lightning_hasattr failed to find batch_size in datamodule w/ hparams present"
-    assert lightning_hasattr(
-        model7, "batch_size"
-    ), "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
+    assert lightning_hasattr(model6, "batch_size"), (
+        "lightning_hasattr failed to find batch_size in datamodule w/ hparams present"
+    )
+    assert lightning_hasattr(model7, "batch_size"), (
+        "lightning_hasattr failed to find batch_size in hparams w/ datamodule present"
+    )
     assert lightning_hasattr(model8, "batch_size")
     for m in models:
diff --git a/tests/tests_pytorch/utilities/test_pytree.py b/tests/tests_pytorch/utilities/test_pytree.py
index afd198919e23f..c87a83f85f6ea 100644
--- a/tests/tests_pytorch/utilities/test_pytree.py
+++ b/tests/tests_pytorch/utilities/test_pytree.py
@@ -1,7 +1,8 @@
 import torch
-from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten
 from torch.utils.data import DataLoader, TensorDataset
+from lightning.pytorch.utilities._pytree import _tree_flatten, tree_unflatten
+
 def assert_tree_flatten_unflatten(pytree, leaves):
     flat, spec = _tree_flatten(pytree)
diff --git a/tests/tests_pytorch/utilities/test_seed.py b/tests/tests_pytorch/utilities/test_seed.py
index 282009f6b93cf..00484009481e9 100644
--- a/tests/tests_pytorch/utilities/test_seed.py
+++ b/tests/tests_pytorch/utilities/test_seed.py
@@ -4,8 +4,8 @@
 import numpy as np
 import pytest
 import torch
-from lightning.pytorch.utilities.seed import isolate_rng
+from lightning.pytorch.utilities.seed import isolate_rng
 from tests_pytorch.helpers.runif import RunIf
diff --git a/tests/tests_pytorch/utilities/test_signature_utils.py b/tests/tests_pytorch/utilities/test_signature_utils.py
index e453459670360..23f9258b0d56b 100644
--- a/tests/tests_pytorch/utilities/test_signature_utils.py
+++ b/tests/tests_pytorch/utilities/test_signature_utils.py
@@ -1,4 +1,5 @@
 import torch
+
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
diff --git a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
index 2ef1ecd4fe3e5..15db0becb9551 100644
--- a/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
+++ b/tests/tests_pytorch/utilities/test_upgrade_checkpoint.py
@@ -18,6 +18,7 @@
 import pytest
 import torch
+
 from lightning.pytorch.utilities.upgrade_checkpoint import main as upgrade_main
diff --git a/tests/tests_pytorch/utilities/test_warnings.py b/tests/tests_pytorch/utilities/test_warnings.py
index 8c385a2d2e49f..ab7b7d2c15a4d 100644
--- a/tests/tests_pytorch/utilities/test_warnings.py
+++ b/tests/tests_pytorch/utilities/test_warnings.py
@@ -24,11 +24,12 @@
 from io import StringIO
 from unittest import mock
-import lightning.pytorch
 import pytest
-from lightning.pytorch.utilities.warnings import PossibleUserWarning
 from lightning_utilities.test.warning import no_warning_call
+import lightning.pytorch
+from lightning.pytorch.utilities.warnings import PossibleUserWarning
+
 if __name__ == "__main__":
     # check that logging is properly configured
     import logging