diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index 536081798c522..023ccf444c44b 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -17,3 +17,5 @@
 fastapi  # for `ServableModuleValidator`  # not setting version as re-defined in App
 uvicorn  # for `ServableModuleValidator`  # not setting version as re-defined in App
 tensorboard >=2.9.1, <2.21.0  # for `TensorBoardLogger`
+
+torch-tensorrt >= 2.5.0; platform_system == "Linux" and python_version >= "3.12"  # Limit the torch-tensorrt version for testing.
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 905cc9a67c85a..077de6d2b9629 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -13,11 +13,12 @@
 # limitations under the License.
 """The LightningModule - an nn.Module with many additional features."""
 
+import copy
 import logging
 import numbers
 import weakref
 from collections.abc import Generator, Mapping, Sequence
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from io import BytesIO
 from pathlib import Path
 from typing import (
@@ -47,6 +48,7 @@
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -76,6 +78,7 @@
     from torch.distributed.device_mesh import DeviceMesh
 
 _ONNX_AVAILABLE = RequirementCache("onnx")
+_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -1519,6 +1522,109 @@ def forward(self, x):
 
         return torchscript_module
 
+    @torch.no_grad()
+    def to_tensorrt(
+        self,
+        file_path: Optional[Union[str, Path, BytesIO]] = None,
+        input_sample: Optional[Any] = None,
+        ir: Literal["default", "dynamo", "ts"] = "default",
+        output_format: Literal["exported_program", "torchscript"] = "exported_program",
+        retrace: bool = False,
+        default_device: Union[str, torch.device] = "cuda",
+        **compile_kwargs: Any,
+    ) -> Union[ScriptModule, torch.fx.GraphModule]:
+        """Export the model to a ScriptModule or GraphModule using the TensorRT compile backend.
+
+        Args:
+            file_path: Path where to save the TensorRT model. Default: None (no file saved).
+            input_sample: Inputs to be used during `torch_tensorrt.compile`.
+                Default: None (uses :attr:`example_input_array`).
+            ir: The IR mode to use for TensorRT compilation. Default: "default".
+            output_format: The format of the output model. Default: "exported_program".
+            retrace: Whether to retrace the model. Default: False.
+            default_device: The device to move the model to when it is not already on a CUDA device. Default: "cuda".
+            **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.
+
+        Example::
+
+            class SimpleModel(LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+
+                def forward(self, x):
+                    return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+            model = SimpleModel()
+            input_sample = torch.randn(1, 64)
+            exported_program = model.to_tensorrt(
+                file_path="export.ep",
+                input_sample=input_sample,
+            )
+
+        """
+        if not _TORCH_GREATER_EQUAL_2_2:
+            raise MisconfigurationException(
+                f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."
+            )
+
+        if not _TORCH_TRT_AVAILABLE:
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed."
+            )
+
+        mode = self.training
+        device = self.device
+        if self.device.type != "cuda":
+            default_device = torch.device(default_device) if isinstance(default_device, str) else default_device
+
+            if not torch.cuda.is_available() or default_device.type != "cuda":
+                raise MisconfigurationException(
+                    f"TensorRT only supports CUDA devices. The current device is {self.device}."
+                    " Please set the `default_device` argument to a CUDA device."
+                )
+
+            self.to(default_device)
+
+        if input_sample is None:
+            if self.example_input_array is None:
+                raise ValueError(
+                    "Could not export to TensorRT since neither `input_sample` nor"
+                    " `model.example_input_array` attribute is set."
+                )
+            input_sample = self.example_input_array
+
+        import torch_tensorrt
+
+        input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample)
+        input_sample = self._on_before_batch_transfer(input_sample)
+        input_sample = self._apply_batch_transfer_handler(input_sample)
+
+        with _jit_is_scripting() if ir == "ts" else nullcontext():
+            trt_obj = torch_tensorrt.compile(
+                module=self.eval(),
+                ir=ir,
+                inputs=input_sample,
+                **compile_kwargs,
+            )
+        self.train(mode)
+        self.to(device)
+
+        if file_path is not None:
+            if ir == "ts" and output_format != "torchscript":
+                raise ValueError(
+                    "TensorRT with IR mode 'ts' only supports output format 'torchscript'."
+                    f" The current output format is {output_format}."
+                )
+            torch_tensorrt.save(
+                trt_obj,
+                file_path,
+                inputs=input_sample,
+                output_format=output_format,
+                retrace=retrace,
+            )
+        return trt_obj
+
     @_restricted_classmethod
     def load_from_checkpoint(
         cls,
diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py
index 9c46913681143..6effe533a057c 100644
--- a/src/lightning/pytorch/utilities/testing/_runif.py
+++ b/src/lightning/pytorch/utilities/testing/_runif.py
@@ -18,7 +18,7 @@
 from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
 from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
 from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
-from lightning.pytorch.core.module import _ONNX_AVAILABLE
+from lightning.pytorch.core.module import _ONNX_AVAILABLE, _TORCH_TRT_AVAILABLE
 from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
 
 _SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -42,6 +42,7 @@
     psutil: bool = False,
     sklearn: bool = False,
     onnx: bool = False,
+    tensorrt: bool = False,
 ) -> tuple[list[str], dict[str, bool]]:
     """Construct reasons for pytest skipif.
 
@@ -64,6 +65,7 @@
         psutil: Require that psutil is installed.
         sklearn: Require that scikit-learn is installed.
         onnx: Require that onnx is installed.
+        tensorrt: Require that torch-tensorrt is installed.
""" @@ -96,4 +98,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if tensorrt and not _TORCH_TRT_AVAILABLE: + reasons.append("torch-tensorrt") + return reasons, kwargs diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py new file mode 100644 index 0000000000000..630e59f711348 --- /dev/null +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -0,0 +1,155 @@ +import os +import re +from io import BytesIO +from pathlib import Path + +import pytest +import torch + +import tests_pytorch.helpers.pipelines as pipes +from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE +from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from tests_pytorch.helpers.runif import RunIf + + +@RunIf(max_torch="2.2.0") +def test_torch_minimum_version(): + model = BoringModel() + with pytest.raises( + MisconfigurationException, + match=re.escape(f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."), + ): + model.to_tensorrt("model.trt") + + +@pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.") +@RunIf(min_torch="2.2.0") +def test_missing_tensorrt_package(): + model = BoringModel() + with pytest.raises( + ModuleNotFoundError, + match=re.escape(f"`{type(model).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "), + ): + model.to_tensorrt("model.trt") + + +@RunIf(tensorrt=True, min_torch="2.2.0") +def test_tensorrt_with_wrong_default_device(tmp_path): + model = BoringModel() + input_sample = torch.randn((1, 32)) + file_path = os.path.join(tmp_path, "model.trt") + with pytest.raises(MisconfigurationException): + model.to_tensorrt(file_path, input_sample, default_device="cpu") + + +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") +def test_tensorrt_saves_with_input_sample(tmp_path): + model = BoringModel() + ori_device = model.device + input_sample = torch.randn((1, 32)) + + file_path = os.path.join(tmp_path, "model.trt") + model.to_tensorrt(file_path, input_sample) + + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + assert model.device == ori_device + + file_path = Path(tmp_path) / "model.trt" + model.to_tensorrt(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + assert model.device == ori_device + + file_path = BytesIO() + model.to_tensorrt(file_path, input_sample) + assert len(file_path.getvalue()) > 4e2 + + +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") +def test_tensorrt_error_if_no_input(tmp_path): + model = BoringModel() + model.example_input_array = None + file_path = os.path.join(tmp_path, "model.trt") + + with pytest.raises( + ValueError, + match=r"Could not export to TensorRT since neither `input_sample` nor " + r"`model.example_input_array` attribute is set.", + ): + model.to_tensorrt(file_path) + + +@RunIf(tensorrt=True, min_cuda_gpus=2, min_torch="2.2.0") +def test_tensorrt_saves_on_multi_gpu(tmp_path): + trainer_options = { + "default_root_dir": tmp_path, + "max_epochs": 1, + "limit_train_batches": 10, + "limit_val_batches": 10, + "accelerator": "gpu", + "devices": [0, 1], + "strategy": "ddp_spawn", + "enable_progress_bar": False, + } + + model = BoringModel() + model.example_input_array = torch.randn((4, 32)) + + pipes.run_model_test(trainer_options, model, min_acc=0.08) + + file_path = os.path.join(tmp_path, 
"model.trt") + model.to_tensorrt(file_path) + + assert os.path.exists(file_path) + + +@pytest.mark.parametrize( + ("ir", "export_type"), + [ + ("default", torch.fx.GraphModule), + ("dynamo", torch.fx.GraphModule), + ("ts", torch.jit.ScriptModule), + ], +) +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") +def test_tensorrt_save_ir_type(ir, export_type): + model = BoringModel() + model.example_input_array = torch.randn((4, 32)) + + ret = model.to_tensorrt(ir=ir) + assert isinstance(ret, export_type) + + +@pytest.mark.parametrize( + "output_format", + ["exported_program", "torchscript"], +) +@pytest.mark.parametrize( + "ir", + ["default", "dynamo", "ts"], +) +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") +def test_tensorrt_export_reload(output_format, ir, tmp_path): + if ir == "ts" and output_format == "exported_program": + pytest.skip("TorchScript cannot be exported as exported_program") + + import torch_tensorrt + + model = BoringModel() + model.cuda().eval() + model.example_input_array = torch.ones((4, 32)) + + file_path = os.path.join(tmp_path, "model.trt") + model.to_tensorrt(file_path, output_format=output_format, ir=ir) + + loaded_model = torch_tensorrt.load(file_path) + if output_format == "exported_program": + loaded_model = loaded_model.module() + + with torch.no_grad(), torch.inference_mode(): + model_output = model(model.example_input_array.to("cuda")) + jit_output = loaded_model(model.example_input_array.to("cuda")) + + assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06)