From f589e31d5f04100e7ef8655999d300c956ab0dee Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 4 May 2025 18:09:03 +0800 Subject: [PATCH 01/29] feat: add `to_tensorrt` in the `LightningModule`. --- src/lightning/pytorch/core/module.py | 74 ++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..805ba6284b94a 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -76,6 +76,7 @@ from torch.distributed.device_mesh import DeviceMesh _ONNX_AVAILABLE = RequirementCache("onnx") +_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt") warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -1489,6 +1490,79 @@ def forward(self, x): return torchscript_module + @torch.no_grad() + def to_tensorrt( + self, + file_path: str | Path | BytesIO | None = None, + inputs: Any | None = None, + ir: Literal["default", "dynamo", "ts"] = "default", + output_format: Literal["exported_program", "torchscript"] = "exported_program", + retrace: bool = False, + **compile_kwargs, + ) -> torch.ScriptModule | torch.fx.GraphModule: + """Export the model to ScriptModule or GraphModule using TensorRT compile backend. + + Args: + file_path: Path where to save the tensorrt model. Default: None (no file saved). + inputs: inputs to be used during `torch_tensorrt.compile`. Default: None (Use self.example_input_array). + ir: The IR mode to use for TensorRT compilation. Default: "default". + output_format: The format of the output model. Default: "exported_program". + retrace: Whether to retrace the model. Default: False. + **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function. + + Example:: + + class SimpleModel(LightningModule): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(in_features=64, out_features=4) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1) + + model = SimpleModel() + input_sample = torch.randn(1, 64) + exported_program = model.to_tensorrt( + file_path="export.ep", + inputs=input_sample, + ) + + """ + + if not _TORCH_TRT_AVAILABLE: + raise ModuleNotFoundError( + f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. " + ) + + import torch_tensorrt + + mode = self.training + + if inputs is None: + if self.example_input_array is None: + raise ValueError("Please provide an example input for the model.") + inputs = self.example_input_array + inputs = self._on_before_batch_transfer(inputs) + inputs = self._apply_batch_transfer_handler(inputs) + + trt_obj = torch_tensorrt.compile( + module=self.eval(), + ir=ir, + inputs=inputs, + **compile_kwargs, + ) + self.train(mode) + + if file_path is not None: + torch_tensorrt.save( + trt_obj, + file_path, + inputs=inputs, + output_format=output_format, + retrace=retrace, + ) + return trt_obj + @_restricted_classmethod def load_from_checkpoint( cls, From 14b9f294825b8d66094d8972a7e1086df42379af Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 4 May 2025 18:09:03 +0800 Subject: [PATCH 02/29] feat: add `to_tensorrt` in the `LightningModule`. 
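A minimal usage sketch of the new `to_tensorrt` API (illustrative only: it assumes `torch_tensorrt` is installed and a CUDA GPU is available, and it uses the demo `BoringModel` purely as a stand-in for any `LightningModule`; the output file name is arbitrary):

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel().cuda()          # TensorRT compilation targets a CUDA device
    input_sample = torch.randn(1, 32)     # BoringModel.forward expects (N, 32) inputs

    # compile with the default IR and write the exported program to disk
    trt_module = model.to_tensorrt("model.ep", input_sample)

    # the compiled module is directly callable on CUDA tensors
    output = trt_module(input_sample.cuda())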
--- src/lightning/pytorch/core/module.py | 74 ++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b8624daac3fa3..805ba6284b94a 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -76,6 +76,7 @@ from torch.distributed.device_mesh import DeviceMesh _ONNX_AVAILABLE = RequirementCache("onnx") +_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt") warning_cache = WarningCache() log = logging.getLogger(__name__) @@ -1489,6 +1490,79 @@ def forward(self, x): return torchscript_module + @torch.no_grad() + def to_tensorrt( + self, + file_path: str | Path | BytesIO | None = None, + inputs: Any | None = None, + ir: Literal["default", "dynamo", "ts"] = "default", + output_format: Literal["exported_program", "torchscript"] = "exported_program", + retrace: bool = False, + **compile_kwargs, + ) -> torch.ScriptModule | torch.fx.GraphModule: + """Export the model to ScriptModule or GraphModule using TensorRT compile backend. + + Args: + file_path: Path where to save the tensorrt model. Default: None (no file saved). + inputs: inputs to be used during `torch_tensorrt.compile`. Default: None (Use self.example_input_array). + ir: The IR mode to use for TensorRT compilation. Default: "default". + output_format: The format of the output model. Default: "exported_program". + retrace: Whether to retrace the model. Default: False. + **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function. + + Example:: + + class SimpleModel(LightningModule): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(in_features=64, out_features=4) + + def forward(self, x): + return torch.relu(self.l1(x.view(x.size(0), -1) + + model = SimpleModel() + input_sample = torch.randn(1, 64) + exported_program = model.to_tensorrt( + file_path="export.ep", + inputs=input_sample, + ) + + """ + + if not _TORCH_TRT_AVAILABLE: + raise ModuleNotFoundError( + f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. " + ) + + import torch_tensorrt + + mode = self.training + + if inputs is None: + if self.example_input_array is None: + raise ValueError("Please provide an example input for the model.") + inputs = self.example_input_array + inputs = self._on_before_batch_transfer(inputs) + inputs = self._apply_batch_transfer_handler(inputs) + + trt_obj = torch_tensorrt.compile( + module=self.eval(), + ir=ir, + inputs=inputs, + **compile_kwargs, + ) + self.train(mode) + + if file_path is not None: + torch_tensorrt.save( + trt_obj, + file_path, + inputs=inputs, + output_format=output_format, + retrace=retrace, + ) + return trt_obj + @_restricted_classmethod def load_from_checkpoint( cls, From 314c4630481e1c43f2ee2c02536b718966e89304 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 10 May 2025 15:58:11 +0800 Subject: [PATCH 03/29] refactor: fix `to_tensorrt` impl --- src/lightning/pytorch/core/module.py | 59 ++++++++++++++++++++-------- 1 file changed, 42 insertions(+), 17 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 805ba6284b94a..ff9ed9b8d810b 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -13,11 +13,12 @@ # limitations under the License. 
"""The LightningModule - an nn.Module with many additional features.""" +import copy import logging import numbers import weakref from collections.abc import Generator, Mapping, Sequence -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from io import BytesIO from pathlib import Path from typing import ( @@ -1494,20 +1495,23 @@ def forward(self, x): def to_tensorrt( self, file_path: str | Path | BytesIO | None = None, - inputs: Any | None = None, + input_sample: Any | None = None, ir: Literal["default", "dynamo", "ts"] = "default", output_format: Literal["exported_program", "torchscript"] = "exported_program", retrace: bool = False, + default_device: str | torch.device = "cuda", **compile_kwargs, - ) -> torch.ScriptModule | torch.fx.GraphModule: + ) -> ScriptModule | torch.fx.GraphModule: """Export the model to ScriptModule or GraphModule using TensorRT compile backend. Args: file_path: Path where to save the tensorrt model. Default: None (no file saved). - inputs: inputs to be used during `torch_tensorrt.compile`. Default: None (Use self.example_input_array). + input_sample: inputs to be used during `torch_tensorrt.compile`. + Default: None (Use :attr:`example_input_array`). ir: The IR mode to use for TensorRT compilation. Default: "default". output_format: The format of the output model. Default: "exported_program". retrace: Whether to retrace the model. Default: False. + default_device: The device to use for the model when the current model is not in CUDA. Default: "cuda". **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function. Example:: @@ -1537,27 +1541,48 @@ def forward(self, x): import torch_tensorrt mode = self.training + device = self.device + if self.device.type != "cuda": + default_device = torch.device(default_device) if isinstance(default_device, str) else default_device + if default_device.type != "cuda": + raise ValueError( + f"TensorRT only supports CUDA devices. The current device is {self.device}." + f" Please set the `default_device` argument to a CUDA device." + ) + + self.to(default_device) - if inputs is None: + if input_sample is None: if self.example_input_array is None: - raise ValueError("Please provide an example input for the model.") - inputs = self.example_input_array - inputs = self._on_before_batch_transfer(inputs) - inputs = self._apply_batch_transfer_handler(inputs) - - trt_obj = torch_tensorrt.compile( - module=self.eval(), - ir=ir, - inputs=inputs, - **compile_kwargs, - ) + raise ValueError( + "Could not export to TensorRT since neither `input_sample` nor" + " `model.example_input_array` attribute is set." + ) + input_sample = self.example_input_array + input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample) + input_sample = self._on_before_batch_transfer(input_sample) + input_sample = self._apply_batch_transfer_handler(input_sample) + + with _jit_is_scripting() if ir == "ts" else nullcontext(): + trt_obj = torch_tensorrt.compile( + module=self.eval(), + ir=ir, + inputs=input_sample, + **compile_kwargs, + ) self.train(mode) + self.to(device) if file_path is not None: + if ir == "ts" and output_format != "torchscript": + raise ValueError( + "TensorRT with IR mode 'ts' only supports output format 'torchscript'." + f" The current output format is {output_format}." 
+ ) torch_tensorrt.save( trt_obj, file_path, - inputs=inputs, + inputs=input_sample, output_format=output_format, retrace=retrace, ) From 534c6c43ce402fbbecbe994d27fd402575ab4f2d Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 10 May 2025 16:00:19 +0800 Subject: [PATCH 04/29] test: add test_torch_tensorrt.py --- .../pytorch/utilities/testing/_runif.py | 6 +- .../models/test_torch_tensorrt.py | 121 ++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/tests_pytorch/models/test_torch_tensorrt.py diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9c46913681143..19e267d67e724 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -18,7 +18,7 @@ from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE -from lightning.pytorch.core.module import _ONNX_AVAILABLE +from lightning.pytorch.core.module import _ONNX_AVAILABLE, _TORCH_TRT_AVAILABLE from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE _SKLEARN_AVAILABLE = RequirementCache("scikit-learn") @@ -42,6 +42,7 @@ def _runif_reasons( psutil: bool = False, sklearn: bool = False, onnx: bool = False, + tensorrt: bool = False, ) -> tuple[list[str], dict[str, bool]]: """Construct reasons for pytest skipif. @@ -96,4 +97,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if onnx and not _TORCH_TRT_AVAILABLE: + reasons.append("torch-tensorrt") + return reasons, kwargs diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py new file mode 100644 index 0000000000000..a413b0b9a09f1 --- /dev/null +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -0,0 +1,121 @@ +import os +from io import BytesIO +from pathlib import Path + +import pytest +import torch + +import tests_pytorch.helpers.pipelines as tpipes +from lightning.pytorch.demos.boring_classes import BoringModel +from tests_pytorch.helpers.runif import RunIf + + +@RunIf(tensorrt=True, min_cuda_gpus=1) +def test_tensorrt_saves_with_input_sample(tmp_path): + model = BoringModel() + ori_device = model.device + input_sample = torch.randn((1, 32)) + + file_path = os.path.join(tmp_path, "model.trt") + model.to_tensorrt(file_path, input_sample) + + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + assert model.device == ori_device + + file_path = Path(tmp_path) / "model.trt" + model.to_tensorrt(file_path, input_sample) + assert os.path.isfile(file_path) + assert os.path.getsize(file_path) > 4e2 + assert model.device == ori_device + + file_path = BytesIO() + model.to_tensorrt(file_path, input_sample) + assert len(file_path.getvalue()) > 4e2 + + +def test_tensorrt_error_if_no_input(tmp_path): + model = BoringModel() + model.example_input_array = None + file_path = os.path.join(tmp_path, "model.trt") + + with pytest.raises( + ValueError, + match=r"Could not export to TensorRT since neither `input_sample` nor " + r"`model.example_input_array` attribute is set.", + ): + model.to_tensorrt(file_path) + + +@RunIf(tensorrt=True, min_cuda_gpus=2) +def test_tensorrt_saves_on_multi_gpu(tmp_path): + trainer_options = { + "default_root_dir": tmp_path, + "max_epochs": 1, + "limit_train_batches": 10, + "limit_val_batches": 10, + "accelerator": 
"gpu", + "devices": [0, 1], + "strategy": "ddp_spawn", + "enable_progress_bar": False, + } + + model = BoringModel() + model.example_input_array = torch.randn((4, 32)) + + tpipes.run_model_test(trainer_options, model, min_acc=0.08) + + file_path = os.path.join(tmp_path, "model.trt") + model.to_tensorrt(file_path) + + assert os.path.exists(file_path) + + +@pytest.mark.parametrize( + ("ir", "export_type"), + [ + ("default", torch.fx.GraphModule), + ("dynamo", torch.fx.GraphModule), + ("ts", torch.jit.ScriptModule), + ], +) +@RunIf(tensorrt=True, min_cuda_gpus=1) +def test_tensorrt_save_ir_type(ir, export_type): + model = BoringModel() + model.example_input_array = torch.randn((4, 32)) + + ret = model.to_tensorrt(ir=ir) + assert isinstance(ret, export_type) + + +@pytest.mark.parametrize( + "output_format", + ["exported_program", "torchscript"], +) +@pytest.mark.parametrize( + "ir", + ["default", "dynamo", "ts"], +) +@RunIf(tensorrt=True, min_cuda_gpus=1) +def test_tensorrt_export_reload(output_format, ir, tmp_path): + import torch_tensorrt + + if ir == "ts" and output_format == "exported_program": + pytest.skip("TorchScript cannot be exported as exported_program") + + model = BoringModel() + model.cuda().eval() + model.example_input_array = torch.randn((4, 32)) + + file_path = os.path.join(tmp_path, "model.trt") + model.to_tensorrt(file_path, output_format=output_format, ir=ir) + + loaded_model = torch_tensorrt.load(file_path) + if output_format == "exported_program": + loaded_model = loaded_model.module() + + with torch.no_grad(), torch.inference_mode(): + model_output = model(model.example_input_array.to(model.device)) + + jit_output = loaded_model(model.example_input_array.to("cuda")) + assert torch.allclose(model_output, jit_output) From 5c01acb76d91b2d954c70a7e109f1286f7f72369 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 10 May 2025 16:13:50 +0800 Subject: [PATCH 05/29] add dependency in test requirement. --- requirements/pytorch/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 412a8f270bf47..e2fe8581e745e 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -17,3 +17,5 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` + +torch-tensorrt >=2.1.0, <2.8.0 From 26d2788c89f1c77d010758dbdf4be28cdfb6296d Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 10 May 2025 16:24:48 +0800 Subject: [PATCH 06/29] update dependency in test requirement. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index e2fe8581e745e..2b218de660643 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` -torch-tensorrt >=2.1.0, <2.8.0 +torch-tensorrt >=1.4.0, <2.8.0 From d42ebe4a8688d4c891e76dd6c57aca85582ba6b3 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Wed, 21 May 2025 22:24:29 +0800 Subject: [PATCH 07/29] fix mypy error. 
--- src/lightning/pytorch/core/module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index ff9ed9b8d810b..8a93e537a9a48 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1500,8 +1500,8 @@ def to_tensorrt( output_format: Literal["exported_program", "torchscript"] = "exported_program", retrace: bool = False, default_device: str | torch.device = "cuda", - **compile_kwargs, - ) -> ScriptModule | torch.fx.GraphModule: + **compile_kwargs: Any, + ) -> Union[ScriptModule, torch.fx.GraphModule]: """Export the model to ScriptModule or GraphModule using TensorRT compile backend. Args: From 958968da6bea70e70239548aed295f5ac7615642 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 22 May 2025 01:56:50 +0800 Subject: [PATCH 08/29] fix: fix unittest. --- tests/tests_pytorch/models/test_torch_tensorrt.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index a413b0b9a09f1..aee773c610ddc 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -5,12 +5,12 @@ import pytest import torch -import tests_pytorch.helpers.pipelines as tpipes +import tests_pytorch.helpers.pipelines as pipes from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf -@RunIf(tensorrt=True, min_cuda_gpus=1) +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_saves_with_input_sample(tmp_path): model = BoringModel() ori_device = model.device @@ -34,6 +34,7 @@ def test_tensorrt_saves_with_input_sample(tmp_path): assert len(file_path.getvalue()) > 4e2 +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_error_if_no_input(tmp_path): model = BoringModel() model.example_input_array = None @@ -47,7 +48,7 @@ def test_tensorrt_error_if_no_input(tmp_path): model.to_tensorrt(file_path) -@RunIf(tensorrt=True, min_cuda_gpus=2) +@RunIf(tensorrt=True, min_cuda_gpus=2, min_torch="2.2.0") def test_tensorrt_saves_on_multi_gpu(tmp_path): trainer_options = { "default_root_dir": tmp_path, @@ -63,7 +64,7 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path): model = BoringModel() model.example_input_array = torch.randn((4, 32)) - tpipes.run_model_test(trainer_options, model, min_acc=0.08) + pipes.run_model_test(trainer_options, model, min_acc=0.08) file_path = os.path.join(tmp_path, "model.trt") model.to_tensorrt(file_path) @@ -79,7 +80,7 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path): ("ts", torch.jit.ScriptModule), ], ) -@RunIf(tensorrt=True, min_cuda_gpus=1) +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_save_ir_type(ir, export_type): model = BoringModel() model.example_input_array = torch.randn((4, 32)) @@ -96,7 +97,7 @@ def test_tensorrt_save_ir_type(ir, export_type): "ir", ["default", "dynamo", "ts"], ) -@RunIf(tensorrt=True, min_cuda_gpus=1) +@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_export_reload(output_format, ir, tmp_path): import torch_tensorrt From 53257f9635b5eaf07547f41c1fa2e3c342789a04 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 24 May 2025 01:37:58 +0800 Subject: [PATCH 09/29] fix: fix unittest. 
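TensorRT does not reproduce eager results bit-for-bit (kernel fusion and reduced-precision math paths can shift the low-order bits), so the reload comparison now uses a fixed all-ones input and explicit tolerances. The pattern, sketched with `model` and `loaded_model` set up as in the test (the fusion/precision rationale is a general expectation, not something measured in this series):

    import torch

    x = torch.ones((4, 32), device="cuda")
    with torch.no_grad(), torch.inference_mode():
        eager_out = model(x)          # the original LightningModule, in eval mode on CUDA
        trt_out = loaded_model(x)     # the module reloaded via torch_tensorrt.load(...)

    assert torch.allclose(eager_out, trt_out, rtol=1e-03, atol=1e-06)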
--- tests/tests_pytorch/models/test_torch_tensorrt.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index aee773c610ddc..8fc19531dcbea 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -99,14 +99,14 @@ def test_tensorrt_save_ir_type(ir, export_type): ) @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_export_reload(output_format, ir, tmp_path): - import torch_tensorrt - if ir == "ts" and output_format == "exported_program": pytest.skip("TorchScript cannot be exported as exported_program") + import torch_tensorrt + model = BoringModel() model.cuda().eval() - model.example_input_array = torch.randn((4, 32)) + model.example_input_array = torch.ones((4, 32)) file_path = os.path.join(tmp_path, "model.trt") model.to_tensorrt(file_path, output_format=output_format, ir=ir) @@ -116,7 +116,7 @@ def test_tensorrt_export_reload(output_format, ir, tmp_path): loaded_model = loaded_model.module() with torch.no_grad(), torch.inference_mode(): - model_output = model(model.example_input_array.to(model.device)) + model_output = model(model.example_input_array.to("cuda")) + jit_output = loaded_model(model.example_input_array.to("cuda")) - jit_output = loaded_model(model.example_input_array.to("cuda")) - assert torch.allclose(model_output, jit_output) + assert torch.allclose(model_output, jit_output, rtol=1e-03, atol=1e-06) From 0723071209bad06bb259011b14f5f51de3702a95 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 24 May 2025 01:47:00 +0800 Subject: [PATCH 10/29] fix: fix type annotation. --- src/lightning/pytorch/core/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 8a93e537a9a48..52b7ab7201cc9 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1494,12 +1494,12 @@ def forward(self, x): @torch.no_grad() def to_tensorrt( self, - file_path: str | Path | BytesIO | None = None, - input_sample: Any | None = None, + file_path: Optional[Union[str, Path, BytesIO]] = None, + input_sample: Optional[Any] = None, ir: Literal["default", "dynamo", "ts"] = "default", output_format: Literal["exported_program", "torchscript"] = "exported_program", retrace: bool = False, - default_device: str | torch.device = "cuda", + default_device: Union[str, torch.device] = "cuda", **compile_kwargs: Any, ) -> Union[ScriptModule, torch.fx.GraphModule]: """Export the model to ScriptModule or GraphModule using TensorRT compile backend. From c97ae0d0ce2b99b7008e370fad82956341e0af9a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 14:32:41 +0800 Subject: [PATCH 11/29] fix: fix runif tensorrt logic. 
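Previously the availability check was keyed off the `onnx` flag, so `RunIf(tensorrt=True)` on its own never produced a skip reason. The intended usage, as the new tests rely on it (sketch; the test name is illustrative):

    from tests_pytorch.helpers.runif import RunIf

    @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
    def test_uses_tensorrt(tmp_path):
        ...  # collected only when torch-tensorrt, a GPU, and torch >= 2.2 are all present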
--- src/lightning/pytorch/utilities/testing/_runif.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 19e267d67e724..9a14f0fe0d71e 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -97,7 +97,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") - if onnx and not _TORCH_TRT_AVAILABLE: + if tensorrt and not _TORCH_TRT_AVAILABLE: reasons.append("torch-tensorrt") return reasons, kwargs From 484f9cedfa41ada370a9f23d4b1a8ff67925589c Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 14:35:04 +0800 Subject: [PATCH 12/29] test: add test `test_missing_tensorrt_package`. --- tests/tests_pytorch/models/test_torch_tensorrt.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index 8fc19531dcbea..a09f081a2542a 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -1,4 +1,5 @@ import os +import re from io import BytesIO from pathlib import Path @@ -6,10 +7,21 @@ import torch import tests_pytorch.helpers.pipelines as pipes +from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel from tests_pytorch.helpers.runif import RunIf +@pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.") +def test_missing_tensorrt_package(): + model = BoringModel() + with pytest.raises( + ModuleNotFoundError, + match=re.escape(f"`{type(model).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "), + ): + model.to_tensorrt("model.trt") + + @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_saves_with_input_sample(tmp_path): model = BoringModel() From b35c60ca6d1de8093e7d5b9116264841828c3ea5 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 14:56:41 +0800 Subject: [PATCH 13/29] req: remove mac from the tensorrt dependency. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 76917f737ec8c..222c7e6237ebd 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` -torch-tensorrt >=1.4.0, <2.8.0 +torch-tensorrt >=1.4.0, <2.8.0; platform_system != "Darwin" From a9047fe7a5810f11dce22ffbbfb4ed6cf080cfb6 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 17:35:47 +0800 Subject: [PATCH 14/29] fix: fix default device logics. 
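The guard now raises a `MisconfigurationException` both when the requested `default_device` is not a CUDA device and when CUDA is simply unavailable. When it passes, a CPU-resident module is moved to that device only for the duration of the export and then restored, as sketched below (illustrative; assumes at least one CUDA GPU):

    import torch
    from lightning.pytorch.demos.boring_classes import BoringModel

    model = BoringModel()                  # stays on the CPU
    x = torch.randn(1, 32)
    model.to_tensorrt("model.ep", x, default_device="cuda:0")
    assert model.device.type == "cpu"      # the module is moved back after export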
--- src/lightning/pytorch/core/module.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b554ae907e17a..da8a37560068e 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1548,8 +1548,9 @@ def forward(self, x): device = self.device if self.device.type != "cuda": default_device = torch.device(default_device) if isinstance(default_device, str) else default_device - if default_device.type != "cuda": - raise ValueError( + + if not torch.cuda.is_available() or default_device.type != "cuda": + raise MisconfigurationException( f"TensorRT only supports CUDA devices. The current device is {self.device}." f" Please set the `default_device` argument to a CUDA device." ) From a36907f38d5ad2178eb153faac775281288f4e6a Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 17:36:25 +0800 Subject: [PATCH 15/29] test: add test `test_tensorrt_with_wrong_default_device`. --- tests/tests_pytorch/models/test_torch_tensorrt.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index a09f081a2542a..c403b329a7144 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -9,6 +9,7 @@ import tests_pytorch.helpers.pipelines as pipes from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.utilities.exceptions import MisconfigurationException from tests_pytorch.helpers.runif import RunIf @@ -22,6 +23,15 @@ def test_missing_tensorrt_package(): model.to_tensorrt("model.trt") +@RunIf(tensorrt=True, min_torch="2.2.0") +def test_tensorrt_with_wrong_default_device(tmp_path): + model = BoringModel() + input_sample = torch.randn((1, 32)) + file_path = os.path.join(tmp_path, "model.trt") + with pytest.raises(MisconfigurationException): + model.to_tensorrt(file_path, input_sample, default_device="cpu") + + @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0") def test_tensorrt_saves_with_input_sample(tmp_path): model = BoringModel() From 937dedf2d1c003096d60ac8d05884d99956b9bdf Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 7 Jun 2025 23:54:23 +0800 Subject: [PATCH 16/29] fix: reorder the import sequence. --- src/lightning/pytorch/core/module.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index da8a37560068e..1a72efeb4d6a1 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -1542,8 +1542,6 @@ def forward(self, x): f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. " ) - import torch_tensorrt - mode = self.training device = self.device if self.device.type != "cuda": @@ -1564,6 +1562,9 @@ def forward(self, x): " `model.example_input_array` attribute is set." 
) input_sample = self.example_input_array + + import torch_tensorrt + input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample) input_sample = self._on_before_batch_transfer(input_sample) input_sample = self._apply_batch_transfer_handler(input_sample) From 15f19619247f670815e8e1771c0020d6ef0b0f8b Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 13 Jun 2025 11:25:06 +0800 Subject: [PATCH 17/29] feat: add exception when torch is below 2.2.0. --- src/lightning/pytorch/core/module.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 1a72efeb4d6a1..5dec11e5788e2 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -48,6 +48,7 @@ from lightning.fabric.utilities.apply_func import convert_to_tensors from lightning.fabric.utilities.cloud_io import get_filesystem from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH from lightning.fabric.wrappers import _FabricOptimizer from lightning.pytorch.callbacks.callback import Callback @@ -1536,6 +1537,10 @@ def forward(self, x): ) """ + if not _TORCH_GREATER_EQUAL_2_2: + raise MisconfigurationException( + f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}." + ) if not _TORCH_TRT_AVAILABLE: raise ModuleNotFoundError( From 67546dc974585430b2273ea4ef34ba82d3e8d77f Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Fri, 13 Jun 2025 22:30:00 +0800 Subject: [PATCH 18/29] add unittests. --- tests/tests_pytorch/models/test_torch_tensorrt.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index c403b329a7144..3c08459fe367d 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -13,6 +13,16 @@ from tests_pytorch.helpers.runif import RunIf +@RunIf(max_torch="2.2.0") +def test_torch_minimum_version(): + model = BoringModel() + with pytest.raises( + MisconfigurationException, + match=re.escape(f"TensorRT export requires PyTorch 2.2 or higher. Current version is {torch.__version__}."), + ): + model.to_tensorrt("model.trt") + + @pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.") def test_missing_tensorrt_package(): model = BoringModel() From 9d92ddfad1c9929804fad2fc59e94d4f3e686db4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Thu, 3 Jul 2025 01:19:51 +0800 Subject: [PATCH 19/29] torch-tensorrt deps. 
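These markers only shape the test environments; at runtime the export stays optional and is guarded by a `RequirementCache` probe, roughly as sketched below (the import path for `RequirementCache` is an assumption based on how the existing ONNX availability check is typically wired up in Lightning):

    from lightning_utilities.core.imports import RequirementCache

    _TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

    if _TORCH_TRT_AVAILABLE:
        import torch_tensorrt  # resolves only when the optional dependency is installed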
--- requirements/pytorch/test.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 24e8cb896610f..ed23b61d6d743 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,6 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.20.0 # for `TensorBoardLogger` -torch-tensorrt >=1.4.0, <2.8.0; platform_system != "Darwin" +torch-tensorrt >=1.4.0, <2.8.0; platform_system == "Linux" and python_version < "3.11" # Initial linux support starts from 1.4.0 +torch-tensorrt >=2.3.0, <2.8.0; platform_system != "Darwin" and python_version < "3.11" # Initial windows support starts from 2.3.0 +torch-tensorrt >=2.5.0, <2.8.0; platform_system != "Darwin" and python_version >= "3.12" From f16deedcdf13d4228a635b5e3a35d6b615d70816 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 20 Jul 2025 18:25:22 +0800 Subject: [PATCH 20/29] fix: fix unittest `test_missing_tensorrt_package` to run only when min_torch is 2.2.0 or above. --- tests/tests_pytorch/models/test_torch_tensorrt.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests_pytorch/models/test_torch_tensorrt.py b/tests/tests_pytorch/models/test_torch_tensorrt.py index 3c08459fe367d..630e59f711348 100644 --- a/tests/tests_pytorch/models/test_torch_tensorrt.py +++ b/tests/tests_pytorch/models/test_torch_tensorrt.py @@ -24,6 +24,7 @@ def test_torch_minimum_version(): @pytest.mark.skipif(_TORCH_TRT_AVAILABLE, reason="Run this test only if tensorrt is not available.") +@RunIf(min_torch="2.2.0") def test_missing_tensorrt_package(): model = BoringModel() with pytest.raises( From 42fafb7de341bbaf5354c5b21b4fb8e44f79ba88 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 9 Aug 2025 16:02:33 +0800 Subject: [PATCH 21/29] revert test dependencies. --- requirements/pytorch/test.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 6b3a3318eb554..dec156747f714 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -17,7 +17,3 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` - -torch-tensorrt >=1.4.0, <2.8.0; platform_system == "Linux" and python_version < "3.11" # Initial linux support starts from 1.4.0 -torch-tensorrt >=2.3.0, <2.8.0; platform_system != "Darwin" and python_version < "3.11" # Initial windows support starts from 2.3.0 -torch-tensorrt >=2.5.0, <2.8.0; platform_system != "Darwin" and python_version >= "3.12" From 7016bdb89d21e3d3ec3fe7bc7b600abd85855ba2 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 9 Aug 2025 16:45:09 +0800 Subject: [PATCH 22/29] add the minimum support in torch-tensorrt when testing. 
--- requirements/pytorch/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index dec156747f714..04a70ed67c748 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -17,3 +17,5 @@ fastapi # for `ServableModuleValidator` # not setting version as re-defined in uvicorn # for `ServableModuleValidator` # not setting version as re-defined in App tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` + +torch-tensorrt >=1.4.0; platform_system == "Linux" From 69c87b35e5260911398d287c60040d41c9bc6406 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sat, 9 Aug 2025 17:04:13 +0800 Subject: [PATCH 23/29] add torch-tensorrt support when python is lower then py312. --- requirements/pytorch/test.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 04a70ed67c748..765822965ac97 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,5 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt >=1.4.0; platform_system == "Linux" +torch-tensorrt >=1.4.0; platform_system == "Linux" and python_version < "3.11" # Initial linux support starts from 1.4.0 +torch-tensorrt >=2.3.0; platform_system != "Darwin" and python_version < "3.11" # Initial windows support starts from 2.3.0 From 63e3cc72cdce64a43e7d2c7e1aeab3e634010839 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 10 Aug 2025 15:03:23 +0800 Subject: [PATCH 24/29] limit the torch-tensorrt condition again. --- requirements/pytorch/test.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 765822965ac97..978d5669a0495 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,5 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt >=1.4.0; platform_system == "Linux" and python_version < "3.11" # Initial linux support starts from 1.4.0 -torch-tensorrt >=2.3.0; platform_system != "Darwin" and python_version < "3.11" # Initial windows support starts from 2.3.0 +torch-tensorrt >=2.3.0; platform_system != "Darwin" and python_version >= "3.11" # Initial windows support starts from 2.3.0 From bf63faaeff4b47e08225747d2f1b4e9036c1dc36 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 10 Aug 2025 15:05:49 +0800 Subject: [PATCH 25/29] limit the torch-tensorrt condition again. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 978d5669a0495..b8299a787e331 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt >=2.3.0; platform_system != "Darwin" and python_version >= "3.11" # Initial windows support starts from 2.3.0 +torch-tensorrt >=2.2.0; platform_system == "Linux" and python_version >= "3.11" # Initial linux support starts from 1.4.0 From 02757a397332a72c5d4617b54a4b5e84b5d934a4 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 10 Aug 2025 16:29:35 +0800 Subject: [PATCH 26/29] docs: add tensorrt description in docstring. 
--- src/lightning/pytorch/utilities/testing/_runif.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/utilities/testing/_runif.py b/src/lightning/pytorch/utilities/testing/_runif.py index 9a14f0fe0d71e..6effe533a057c 100644 --- a/src/lightning/pytorch/utilities/testing/_runif.py +++ b/src/lightning/pytorch/utilities/testing/_runif.py @@ -65,6 +65,7 @@ def _runif_reasons( psutil: Require that psutil is installed. sklearn: Require that scikit-learn is installed. onnx: Require that onnx is installed. + tensorrt: Require that torch-tensorrt is installed. """ From a2b36cddfffb54aac24543ae35a9dc21b8ea9984 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 10 Aug 2025 16:30:15 +0800 Subject: [PATCH 27/29] loosen the torch-tensorrt version limits. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index b8299a787e331..602b7e949d950 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt >=2.2.0; platform_system == "Linux" and python_version >= "3.11" # Initial linux support starts from 1.4.0 +torch-tensorrt; platform_system == "Linux" and python_version >= "3.11" # Initial linux support starts from 1.4.0 From 87b2f0138e787f14d6ab36db6562687945a95908 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Sun, 10 Aug 2025 16:38:18 +0800 Subject: [PATCH 28/29] update tensorrt version. --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index 602b7e949d950..d9959c1cb1a2c 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt; platform_system == "Linux" and python_version >= "3.11" # Initial linux support starts from 1.4.0 +torch-tensorrt >= 2.5.0; platform_system == "Linux" and python_version >= "3.12" # Initial linux support starts from 1.4.0 From f9821f66d4d5d4de1fcc19853505bace79b30d45 Mon Sep 17 00:00:00 2001 From: GdoongMathew Date: Tue, 12 Aug 2025 18:02:06 +0800 Subject: [PATCH 29/29] Update test.txt description --- requirements/pytorch/test.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index b4e2ed1f98209..023ccf444c44b 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -18,4 +18,4 @@ uvicorn # for `ServableModuleValidator` # not setting version as re-defined in tensorboard >=2.9.1, <2.21.0 # for `TensorBoardLogger` -torch-tensorrt >= 2.5.0; platform_system == "Linux" and python_version >= "3.12" # Initial linux support starts from 1.4.0 +torch-tensorrt >= 2.5.0; platform_system == "Linux" and python_version >= "3.12" # Limit the torch-tensorrt version for testing.