diff --git a/.azure/gpu-tests-pytorch.yml b/.azure/gpu-tests-pytorch.yml index 56c0ace195ed0..4605e824426e9 100644 --- a/.azure/gpu-tests-pytorch.yml +++ b/.azure/gpu-tests-pytorch.yml @@ -126,7 +126,7 @@ jobs: - bash: | set -e python requirements/collect_env_details.py - python -c "import torch ; mgpu = torch.cuda.device_count() ; assert mgpu == 2, f'GPU: {mgpu}'" + python -c "import torch ; nb_gpus = torch.cuda.device_count() ; assert nb_gpus == 2, f'GPU: {nb_gpus}'" python requirements/pytorch/check-avail-extras.py python -c "import bitsandbytes" displayName: "Env details" @@ -138,10 +138,12 @@ jobs: displayName: "Testing: PyTorch doctests" - bash: | - python .actions/assistant.py copy_replace_imports --source_dir="./tests/tests_pytorch" \ + python .actions/assistant.py copy_replace_imports \ + --source_dir="./tests/tests_pytorch" \ --source_import="lightning.fabric,lightning.pytorch" \ --target_import="lightning_fabric,pytorch_lightning" - python .actions/assistant.py copy_replace_imports --source_dir="./examples/pytorch/basics" \ + python .actions/assistant.py copy_replace_imports \ + --source_dir="./examples/pytorch/basics" \ --source_import="lightning.fabric,lightning.pytorch" \ --target_import="lightning_fabric,pytorch_lightning" # without succeeded this could run even if the job has already failed @@ -183,8 +185,10 @@ jobs: # https://docs.codecov.com/docs/codecov-uploader curl -Os https://uploader.codecov.io/latest/linux/codecov chmod +x codecov - ./codecov --token=$(CODECOV_TOKEN) --commit=$(Build.SourceVersion) \ - --flags=gpu,pytest,${COVERAGE_SOURCE} --name="GPU-coverage" --env=linux,azure + ./codecov --token=$(CODECOV_TOKEN) \ + --commit=$(Build.SourceVersion) \ + --flags=gpu,pytest,${COVERAGE_SOURCE} \ + --name="GPU-coverage" --env=linux,azure ls -l workingDirectory: tests/tests_pytorch displayName: "Statistics" diff --git a/.github/workflows/_build-packages.yml b/.github/workflows/_build-packages.yml index e0262ac63b685..48f7257674ec6 100644 --- a/.github/workflows/_build-packages.yml +++ b/.github/workflows/_build-packages.yml @@ -33,6 +33,7 @@ jobs: name: ${{ inputs.artifact-name }} path: dist retention-days: ${{ steps.keep-artifact.outputs.DAYS }} + include-hidden-files: true build-packages: needs: init @@ -66,3 +67,4 @@ jobs: with: name: ${{ inputs.artifact-name }} path: pypi + include-hidden-files: true diff --git a/.github/workflows/_legacy-checkpoints.yml b/.github/workflows/_legacy-checkpoints.yml index 16072112b80a9..15d226eed7fec 100644 --- a/.github/workflows/_legacy-checkpoints.yml +++ b/.github/workflows/_legacy-checkpoints.yml @@ -109,6 +109,7 @@ jobs: name: checkpoints-${{ github.sha }} path: ${{ env.LEGACY_FOLDER }}/checkpoints/ retention-days: ${{ env.KEEP_DAYS }} + include-hidden-files: true - run: pip install -r requirements/ci.txt - name: Upload checkpoints to S3 @@ -138,7 +139,7 @@ jobs: run: echo ${PL_VERSION} >> back-compatible-versions.txt - name: Create Pull Request - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: Adding test for legacy checkpoint created with ${{ env.PL_VERSION }} committer: GitHub diff --git a/.github/workflows/call-clear-cache.yml b/.github/workflows/call-clear-cache.yml index f1f0404299568..091e6a002ab3c 100644 --- a/.github/workflows/call-clear-cache.yml +++ b/.github/workflows/call-clear-cache.yml @@ -23,18 +23,18 @@ on: jobs: cron-clear: if: github.event_name == 'schedule' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 with: - scripts-ref: v0.11.6 + scripts-ref: v0.11.7 dry-run: ${{ github.event_name == 'pull_request' }} pattern: "latest|docs" age-days: 7 direct-clear: if: github.event_name == 'workflow_dispatch' || github.event_name == 'pull_request' - uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/cleanup-caches.yml@v0.11.7 with: - scripts-ref: v0.11.6 + scripts-ref: v0.11.7 dry-run: ${{ github.event_name == 'pull_request' }} pattern: ${{ inputs.pattern || 'pypi_wheels' }} # setting str in case of PR / debugging age-days: ${{ fromJSON(inputs.age-days) || 0 }} # setting 0 in case of PR / debugging diff --git a/.github/workflows/ci-check-md-links.yml b/.github/workflows/ci-check-md-links.yml index d60d4f1cfa322..53b06c207482d 100644 --- a/.github/workflows/ci-check-md-links.yml +++ b/.github/workflows/ci-check-md-links.yml @@ -14,7 +14,7 @@ on: jobs: check-md-links: - uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/check-md-links.yml@v0.11.7 with: config-file: ".github/markdown-links-config.json" base-branch: "master" diff --git a/.github/workflows/ci-schema.yml b/.github/workflows/ci-schema.yml index 632366a211177..e5ae526f196b7 100644 --- a/.github/workflows/ci-schema.yml +++ b/.github/workflows/ci-schema.yml @@ -8,7 +8,7 @@ on: jobs: check: - uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.6 + uses: Lightning-AI/utilities/.github/workflows/check-schema.yml@v0.11.7 with: # skip azure due to the wrong schema file by MSFT # https://github.com/Lightning-AI/lightning-flash/pull/1455#issuecomment-1244793607 diff --git a/.github/workflows/docs-build.yml b/.github/workflows/docs-build.yml index 8f385fcb39fd7..adbc4613f4ca1 100644 --- a/.github/workflows/docs-build.yml +++ b/.github/workflows/docs-build.yml @@ -134,6 +134,7 @@ jobs: name: docs-${{ matrix.pkg-name }}-${{ github.sha }} path: docs/build/html/ retention-days: ${{ env.ARTIFACT_DAYS }} + include-hidden-files: true #- name: Dump handy wheels # if: github.event_name == 'push' && github.ref == 'refs/heads/master' diff --git a/.github/workflows/docs-tutorials.yml b/.github/workflows/docs-tutorials.yml index e4d78483fa81b..5879a7dd58744 100644 --- a/.github/workflows/docs-tutorials.yml +++ b/.github/workflows/docs-tutorials.yml @@ -48,7 +48,7 @@ jobs: - name: Create Pull Request if: ${{ github.event_name != 'pull_request' && env.SHA_ACTUAL != env.SHA_LATEST }} - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: "docs: update ref to latest tutorials" committer: GitHub diff --git a/.github/workflows/release-nightly.yml b/.github/workflows/release-nightly.yml index 9578f84b87093..396e485b90065 100644 --- a/.github/workflows/release-nightly.yml +++ b/.github/workflows/release-nightly.yml @@ -44,6 +44,7 @@ jobs: with: name: nightly-packages-${{ github.sha }} path: dist + include-hidden-files: true publish-packages: runs-on: ubuntu-22.04 diff --git a/.github/workflows/release-pkg.yml b/.github/workflows/release-pkg.yml index a11751c13790e..39f02676305f8 100644 --- a/.github/workflows/release-pkg.yml +++ b/.github/workflows/release-pkg.yml @@ -104,7 +104,7 @@ jobs: - name: Create Pull Request if: github.event_name != 'pull_request' - uses: peter-evans/create-pull-request@v6 + uses: peter-evans/create-pull-request@v7 with: title: "Bump lightning ver `${{ env.TAG }}`" committer: GitHub diff --git a/README.md b/README.md index f0d3e8baa6034..d0c5a26bcab7e 100644 --- a/README.md +++ b/README.md @@ -7,18 +7,18 @@ **The deep learning framework to pretrain, finetune and deploy AI models.** -**NEW- Lightning 2.0 features a clean and stable API!!** +**NEW- Deploying models? Check out [LitServe](https://github.com/Lightning-AI/litserve), the PyTorch Lightning for model serving** ______________________________________________________________________

- Lightning AI • + Quick startExamples • - PyTorch Lightning • + PyTorch LightningFabric • - Docs • + Lightning AICommunity • - Contribute • + Docs

@@ -53,9 +53,24 @@ ______________________________________________________________________ -## Install Lightning +  + +# Lightning has 2 core packages + +[PyTorch Lightning: Train and deploy PyTorch at scale](#why-pytorch-lightning). +
+[Lightning Fabric: Expert control](#lightning-fabric-expert-control). + +Lightning gives you granular control over how much abstraction you want to add over PyTorch. + +
+ +
+ +  -Simple installation from PyPI +# Quick start +Install Lightning: ```bash pip install lightning @@ -64,7 +79,7 @@ pip install lightning
- Other installation options + Advanced install options #### Install with optional dependencies @@ -104,48 +119,8 @@ pip install -iU https://test.pypi.org/simple/ pytorch-lightning
-______________________________________________________________________ - -## Lightning has 2 core packages - -[PyTorch Lightning: Train and deploy PyTorch at scale](#pytorch-lightning-train-and-deploy-pytorch-at-scale). -
-[Lightning Fabric: Expert control](#lightning-fabric-expert-control). - -Lightning gives you granular control over how much abstraction you want to add over PyTorch. - -
- -
- -  -  - - -# PyTorch Lightning: Train and Deploy PyTorch at Scale - -PyTorch Lightning is just organized PyTorch - Lightning disentangles PyTorch code to decouple the science from the engineering. - -![PT to PL](docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif) - -______________________________________________________________________ - -### Examples -Explore various types of training possible with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task like classification, segmentation, summarization and more: - -| Task | Description | Run | -|-------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------|---| -| [Hello world](#hello-simple-model) | Pretrain - Hello world example | Open In Studio | -| [Image classification](https://lightning.ai/lightning-ai/studios/image-classification-with-pytorch-lightning) | Finetune - ResNet-34 model to classify images of cars | Open In Studio | -| [Image segmentation](https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning) | Finetune - ResNet-50 model to segment images | Open In Studio | -| [Object detection](https://lightning.ai/lightning-ai/studios/object-detection-with-pytorch-lightning) | Finetune - Faster R-CNN model to detect objects | Open In Studio | -| [Text classification](https://lightning.ai/lightning-ai/studios/text-classification-with-pytorch-lightning) | Finetune - text classifier (BERT model) | Open In Studio | -| [Text summarization](https://lightning.ai/lightning-ai/studios/text-summarization-with-pytorch-lightning) | Finetune - text summarization (Hugging Face transformer model) | Open In Studio | -| [Audio generation](https://lightning.ai/lightning-ai/studios/finetune-a-personal-ai-music-generator) | Finetune - audio generator (transformer model) | Open In Studio | -| [LLM finetuning](https://lightning.ai/lightning-ai/studios/finetune-an-llm-with-pytorch-lightning) | Finetune - LLM (Meta Llama 3.1 8B) | Open In Studio | -| [Image generation](https://lightning.ai/lightning-ai/studios/train-a-diffusion-model-with-pytorch-lightning) | Pretrain - Image generator (diffusion model) | Open In Studio | - -### Hello simple model +### PyTorch Lightning example +Define the training workflow. Here's a toy example ([explore real examples](https://lightning.ai/lightning-ai/studios?view=public§ion=featured&query=pytorch+lightning)): ```python # main.py @@ -207,6 +182,36 @@ pip install torchvision python main.py ``` +  + + +# Why PyTorch Lightning? + +PyTorch Lightning is just organized PyTorch - Lightning disentangles PyTorch code to decouple the science from the engineering. + +![PT to PL](docs/source-pytorch/_static/images/general/pl_quick_start_full_compressed.gif) + +  + +---- + +### Examples +Explore various types of training possible with PyTorch Lightning. Pretrain and finetune ANY kind of model to perform ANY task like classification, segmentation, summarization and more: + +| Task | Description | Run | +|-------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------|---| +| [Hello world](#hello-simple-model) | Pretrain - Hello world example | Open In Studio | +| [Image classification](https://lightning.ai/lightning-ai/studios/image-classification-with-pytorch-lightning) | Finetune - ResNet-34 model to classify images of cars | Open In Studio | +| [Image segmentation](https://lightning.ai/lightning-ai/studios/image-segmentation-with-pytorch-lightning) | Finetune - ResNet-50 model to segment images | Open In Studio | +| [Object detection](https://lightning.ai/lightning-ai/studios/object-detection-with-pytorch-lightning) | Finetune - Faster R-CNN model to detect objects | Open In Studio | +| [Text classification](https://lightning.ai/lightning-ai/studios/text-classification-with-pytorch-lightning) | Finetune - text classifier (BERT model) | Open In Studio | +| [Text summarization](https://lightning.ai/lightning-ai/studios/text-summarization-with-pytorch-lightning) | Finetune - text summarization (Hugging Face transformer model) | Open In Studio | +| [Audio generation](https://lightning.ai/lightning-ai/studios/finetune-a-personal-ai-music-generator) | Finetune - audio generator (transformer model) | Open In Studio | +| [LLM finetuning](https://lightning.ai/lightning-ai/studios/finetune-an-llm-with-pytorch-lightning) | Finetune - LLM (Meta Llama 3.1 8B) | Open In Studio | +| [Image generation](https://lightning.ai/lightning-ai/studios/train-a-diffusion-model-with-pytorch-lightning) | Pretrain - Image generator (diffusion model) | Open In Studio | +| [Recommendation system](https://lightning.ai/lightning-ai/studios/recommendation-system-with-pytorch-lightning) | Train - recommendation system (factorization and embedding) | Open In Studio | +| [Time-series forecasting](https://lightning.ai/lightning-ai/studios/time-series-forecasting-with-pytorch-lightning) | Train - Time-series forecasting with LSTM | Open In Studio | + ______________________________________________________________________ ## Advanced features diff --git a/_notebooks b/_notebooks index e0720299da014..d527353491441 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit e0720299da014bfaaeb50dea6778b962e28ca69d +Subproject commit d5273534914411886ed45d59536f6042d24f6fe0 diff --git a/docs/source-fabric/_static/images/icon.svg b/docs/source-fabric/_static/images/icon.svg index e88fc19036178..3272f7f87d0fc 100644 --- a/docs/source-fabric/_static/images/icon.svg +++ b/docs/source-fabric/_static/images/icon.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-fabric/_static/images/logo-large.svg b/docs/source-fabric/_static/images/logo-large.svg index 39531f95e9dba..b4814805e2ddf 100644 --- a/docs/source-fabric/_static/images/logo-large.svg +++ b/docs/source-fabric/_static/images/logo-large.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-fabric/_static/images/logo-small.svg b/docs/source-fabric/_static/images/logo-small.svg index 1f523a57c4a16..aac0b9618ab37 100644 --- a/docs/source-fabric/_static/images/logo-small.svg +++ b/docs/source-fabric/_static/images/logo-small.svg @@ -1,9 +1,12 @@ - - - - - - - - + + + + + + + + + + + diff --git a/docs/source-pytorch/_static/images/icon.svg b/docs/source-pytorch/_static/images/icon.svg index 481762a961dda..aac0b9618ab37 100644 --- a/docs/source-pytorch/_static/images/icon.svg +++ b/docs/source-pytorch/_static/images/icon.svg @@ -1,3 +1,12 @@ - + + + + + + + + + + diff --git a/docs/source-pytorch/visualize/supported_exp_managers.rst b/docs/source-pytorch/visualize/supported_exp_managers.rst index 42a0e6c9a85ed..e26514e9747c4 100644 --- a/docs/source-pytorch/visualize/supported_exp_managers.rst +++ b/docs/source-pytorch/visualize/supported_exp_managers.rst @@ -134,7 +134,7 @@ Here's the full documentation for the :class:`~lightning.pytorch.loggers.TensorB Weights and Biases ================== -To use `Weights and Biases `_ (wandb) first install the wandb package: +To use `Weights and Biases `_ (wandb) first install the wandb package: .. code-block:: bash diff --git a/pyproject.toml b/pyproject.toml index 6edd6d1a8f11f..da4cd7f197d5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ ignore = [ "S108", "E203", # conflicts with black ] -ignore-init-module-imports = true [tool.ruff.lint.per-file-ignores] ".actions/*" = ["S101", "S310"] diff --git a/requirements/typing.txt b/requirements/typing.txt index 9f1952605babc..0323edfd6098a 100644 --- a/requirements/typing.txt +++ b/requirements/typing.txt @@ -1,5 +1,5 @@ mypy==1.11.0 -torch==2.4.0 +torch==2.4.1 types-Markdown types-PyYAML diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py index 4dbd57e531859..a1c5a6f6dcd1b 100644 --- a/src/lightning/fabric/utilities/imports.py +++ b/src/lightning/fabric/utilities/imports.py @@ -31,7 +31,9 @@ _TORCH_GREATER_EQUAL_2_2 = compare_version("torch", operator.ge, "2.2.0") _TORCH_GREATER_EQUAL_2_3 = compare_version("torch", operator.ge, "2.3.0") +_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0") _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0") +_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1") _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 647f6e6e41af7..782fc40d928ef 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -17,6 +17,7 @@ import numbers import weakref from contextlib import contextmanager +from io import BytesIO from pathlib import Path from typing import ( IO, @@ -1364,7 +1365,7 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None: ) @torch.no_grad() - def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None: + def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None: """Saves the model in ONNX format. Args: @@ -1403,7 +1404,8 @@ def forward(self, x): input_sample = self._on_before_batch_transfer(input_sample) input_sample = self._apply_batch_transfer_handler(input_sample) - torch.onnx.export(self, input_sample, str(file_path), **kwargs) + file_path = str(file_path) if isinstance(file_path, Path) else file_path + torch.onnx.export(self, input_sample, file_path, **kwargs) self.train(mode) @torch.no_grad() diff --git a/src/lightning/pytorch/loggers/wandb.py b/src/lightning/pytorch/loggers/wandb.py index c5d995bff35a5..20f8d02a7ab9b 100644 --- a/src/lightning/pytorch/loggers/wandb.py +++ b/src/lightning/pytorch/loggers/wandb.py @@ -48,7 +48,7 @@ class WandbLogger(Logger): - r"""Log using `Weights and Biases `_. + r"""Log using `Weights and Biases `_. **Installation and set-up** @@ -253,7 +253,7 @@ def any_lightning_module_function_or_hook(self): See Also: - `Demo in Google Colab `__ with hyperparameter search and model logging - - `W&B Documentation `__ + - `W&B Documentation `__ Args: name: Display name for the run. diff --git a/tests/tests_pytorch/callbacks/test_early_stopping.py b/tests/tests_pytorch/callbacks/test_early_stopping.py index 633c1dc0853e0..b7e52ee549bcc 100644 --- a/tests/tests_pytorch/callbacks/test_early_stopping.py +++ b/tests/tests_pytorch/callbacks/test_early_stopping.py @@ -23,7 +23,7 @@ import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -193,12 +193,12 @@ def test_pickling(): early_stopping = EarlyStopping(monitor="foo") early_stopping_pickled = pickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = pickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) early_stopping_pickled = cloudpickle.dumps(early_stopping) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): early_stopping_loaded = cloudpickle.loads(early_stopping_pickled) assert vars(early_stopping) == vars(early_stopping_loaded) diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index 8ef78a742f9a7..97d8d3c4d0e4a 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -32,7 +32,7 @@ import yaml from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.demos.boring_classes import BoringModel @@ -352,12 +352,12 @@ def test_pickling(tmp_path): ckpt = ModelCheckpoint(dirpath=tmp_path) ckpt_pickled = pickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = pickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) ckpt_pickled = cloudpickle.dumps(ckpt) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): ckpt_loaded = cloudpickle.loads(ckpt_pickled) assert vars(ckpt) == vars(ckpt_loaded) diff --git a/tests/tests_pytorch/core/test_metric_result_integration.py b/tests/tests_pytorch/core/test_metric_result_integration.py index 9818f9807ae6d..ef340d1e17ea9 100644 --- a/tests/tests_pytorch/core/test_metric_result_integration.py +++ b/tests/tests_pytorch/core/test_metric_result_integration.py @@ -19,7 +19,7 @@ import lightning.pytorch as pl import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.pytorch import Trainer from lightning.pytorch.callbacks import OnExceptionCheckpoint @@ -254,7 +254,7 @@ def lightning_log(fx, *args, **kwargs): } # make sure can be pickled - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(pickle.dumps(result)) # make sure can be torch.loaded filepath = str(tmp_path / "result") diff --git a/tests/tests_pytorch/helpers/test_datasets.py b/tests/tests_pytorch/helpers/test_datasets.py index ddc20c29e62e8..98d77a6d9a8ad 100644 --- a/tests/tests_pytorch/helpers/test_datasets.py +++ b/tests/tests_pytorch/helpers/test_datasets.py @@ -17,7 +17,7 @@ import cloudpickle import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from tests_pytorch import _PATH_DATASETS from tests_pytorch.helpers.datasets import MNIST, AverageDataset, TrialMNIST @@ -44,9 +44,9 @@ def test_pickling_dataset_mnist(dataset_cls, args): mnist = dataset_cls(**args) mnist_pickled = pickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): pickle.loads(mnist_pickled) mnist_pickled = cloudpickle.dumps(mnist) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): cloudpickle.loads(mnist_pickled) diff --git a/tests/tests_pytorch/loggers/test_all.py b/tests/tests_pytorch/loggers/test_all.py index 503e49fe6cdad..c5b07562afb0a 100644 --- a/tests/tests_pytorch/loggers/test_all.py +++ b/tests/tests_pytorch/loggers/test_all.py @@ -20,7 +20,7 @@ import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0, _TORCH_GREATER_EQUAL_2_4_1 from lightning.pytorch import Callback, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.loggers import ( @@ -163,7 +163,7 @@ def test_loggers_pickle_all(tmp_path, monkeypatch, logger_class): pytest.xfail(f"pickle test requires {logger_class.__class__} dependencies to be installed.") -def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): +def _test_loggers_pickle(tmp_path, monkeypatch, logger_class: Logger): """Verify that pickling trainer with logger works.""" _patch_comet_atexit(monkeypatch) @@ -184,7 +184,11 @@ def _test_loggers_pickle(tmp_path, monkeypatch, logger_class): trainer = Trainer(max_epochs=1, logger=logger) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with ( + pytest.warns(FutureWarning, match="`weights_only=False`") + if _TORCH_EQUAL_2_4_0 or (_TORCH_GREATER_EQUAL_2_4_1 and logger_class not in (CSVLogger, TensorBoardLogger)) + else nullcontext() + ): trainer2 = pickle.loads(pkl_bytes) trainer2.logger.log_metrics({"acc": 1.0}) diff --git a/tests/tests_pytorch/loggers/test_logger.py b/tests/tests_pytorch/loggers/test_logger.py index 7b384890f6148..de0028000cd9f 100644 --- a/tests/tests_pytorch/loggers/test_logger.py +++ b/tests/tests_pytorch/loggers/test_logger.py @@ -21,7 +21,7 @@ import numpy as np import pytest import torch -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.fabric.utilities.logger import _convert_params, _sanitize_params from lightning.pytorch import Trainer from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel @@ -124,7 +124,7 @@ def test_multiple_loggers_pickle(tmp_path): trainer = Trainer(logger=[logger1, logger2]) pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) for logger in trainer2.loggers: logger.log_metrics({"acc": 1.0}, 0) diff --git a/tests/tests_pytorch/loggers/test_wandb.py b/tests/tests_pytorch/loggers/test_wandb.py index e9195f628348b..4e3fbb287a1f9 100644 --- a/tests/tests_pytorch/loggers/test_wandb.py +++ b/tests/tests_pytorch/loggers/test_wandb.py @@ -19,7 +19,7 @@ import pytest import yaml -from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_4 +from lightning.fabric.utilities.imports import _TORCH_EQUAL_2_4_0 from lightning.pytorch import Trainer from lightning.pytorch.callbacks import ModelCheckpoint from lightning.pytorch.cli import LightningCLI @@ -162,7 +162,7 @@ def name(self): assert trainer.logger.experiment, "missing experiment" assert trainer.log_dir == logger.save_dir pkl_bytes = pickle.dumps(trainer) - with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_GREATER_EQUAL_2_4 else nullcontext(): + with pytest.warns(FutureWarning, match="`weights_only=False`") if _TORCH_EQUAL_2_4_0 else nullcontext(): trainer2 = pickle.loads(pkl_bytes) assert os.environ["WANDB_MODE"] == "dryrun" diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 15d06355946fc..ee670cd66e871 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -45,6 +46,10 @@ def test_model_saves_with_input_sample(tmp_path): assert os.path.isfile(file_path) assert os.path.getsize(file_path) > 4e2 + file_path = BytesIO() + model.to_onnx(file_path=file_path, input_sample=input_sample) + assert len(file_path.getvalue()) > 4e2 + @pytest.mark.parametrize( "accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]