Skip to content

Commit 534c6c4

Browse files
committed
test: add test_torch_tensorrt.py
1 parent 314c463 commit 534c6c4

File tree

2 files changed

+126
-1
lines changed

2 files changed

+126
-1
lines changed

src/lightning/pytorch/utilities/testing/_runif.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from lightning.fabric.utilities.testing import _runif_reasons as fabric_run_if
1919
from lightning.pytorch.accelerators.cpu import _PSUTIL_AVAILABLE
2020
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
21-
from lightning.pytorch.core.module import _ONNX_AVAILABLE
21+
from lightning.pytorch.core.module import _ONNX_AVAILABLE, _TORCH_TRT_AVAILABLE
2222
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
2323

2424
_SKLEARN_AVAILABLE = RequirementCache("scikit-learn")
@@ -42,6 +42,7 @@ def _runif_reasons(
4242
psutil: bool = False,
4343
sklearn: bool = False,
4444
onnx: bool = False,
45+
tensorrt: bool = False,
4546
) -> tuple[list[str], dict[str, bool]]:
4647
"""Construct reasons for pytest skipif.
4748
@@ -96,4 +97,7 @@ def _runif_reasons(
9697
if onnx and not _ONNX_AVAILABLE:
9798
reasons.append("onnx")
9899

100+
if onnx and not _TORCH_TRT_AVAILABLE:
101+
reasons.append("torch-tensorrt")
102+
99103
return reasons, kwargs
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import os
2+
from io import BytesIO
3+
from pathlib import Path
4+
5+
import pytest
6+
import torch
7+
8+
import tests_pytorch.helpers.pipelines as tpipes
9+
from lightning.pytorch.demos.boring_classes import BoringModel
10+
from tests_pytorch.helpers.runif import RunIf
11+
12+
13+
@RunIf(tensorrt=True, min_cuda_gpus=1)
14+
def test_tensorrt_saves_with_input_sample(tmp_path):
15+
model = BoringModel()
16+
ori_device = model.device
17+
input_sample = torch.randn((1, 32))
18+
19+
file_path = os.path.join(tmp_path, "model.trt")
20+
model.to_tensorrt(file_path, input_sample)
21+
22+
assert os.path.isfile(file_path)
23+
assert os.path.getsize(file_path) > 4e2
24+
assert model.device == ori_device
25+
26+
file_path = Path(tmp_path) / "model.trt"
27+
model.to_tensorrt(file_path, input_sample)
28+
assert os.path.isfile(file_path)
29+
assert os.path.getsize(file_path) > 4e2
30+
assert model.device == ori_device
31+
32+
file_path = BytesIO()
33+
model.to_tensorrt(file_path, input_sample)
34+
assert len(file_path.getvalue()) > 4e2
35+
36+
37+
def test_tensorrt_error_if_no_input(tmp_path):
38+
model = BoringModel()
39+
model.example_input_array = None
40+
file_path = os.path.join(tmp_path, "model.trt")
41+
42+
with pytest.raises(
43+
ValueError,
44+
match=r"Could not export to TensorRT since neither `input_sample` nor "
45+
r"`model.example_input_array` attribute is set.",
46+
):
47+
model.to_tensorrt(file_path)
48+
49+
50+
@RunIf(tensorrt=True, min_cuda_gpus=2)
51+
def test_tensorrt_saves_on_multi_gpu(tmp_path):
52+
trainer_options = {
53+
"default_root_dir": tmp_path,
54+
"max_epochs": 1,
55+
"limit_train_batches": 10,
56+
"limit_val_batches": 10,
57+
"accelerator": "gpu",
58+
"devices": [0, 1],
59+
"strategy": "ddp_spawn",
60+
"enable_progress_bar": False,
61+
}
62+
63+
model = BoringModel()
64+
model.example_input_array = torch.randn((4, 32))
65+
66+
tpipes.run_model_test(trainer_options, model, min_acc=0.08)
67+
68+
file_path = os.path.join(tmp_path, "model.trt")
69+
model.to_tensorrt(file_path)
70+
71+
assert os.path.exists(file_path)
72+
73+
74+
@pytest.mark.parametrize(
75+
("ir", "export_type"),
76+
[
77+
("default", torch.fx.GraphModule),
78+
("dynamo", torch.fx.GraphModule),
79+
("ts", torch.jit.ScriptModule),
80+
],
81+
)
82+
@RunIf(tensorrt=True, min_cuda_gpus=1)
83+
def test_tensorrt_save_ir_type(ir, export_type):
84+
model = BoringModel()
85+
model.example_input_array = torch.randn((4, 32))
86+
87+
ret = model.to_tensorrt(ir=ir)
88+
assert isinstance(ret, export_type)
89+
90+
91+
@pytest.mark.parametrize(
92+
"output_format",
93+
["exported_program", "torchscript"],
94+
)
95+
@pytest.mark.parametrize(
96+
"ir",
97+
["default", "dynamo", "ts"],
98+
)
99+
@RunIf(tensorrt=True, min_cuda_gpus=1)
100+
def test_tensorrt_export_reload(output_format, ir, tmp_path):
101+
import torch_tensorrt
102+
103+
if ir == "ts" and output_format == "exported_program":
104+
pytest.skip("TorchScript cannot be exported as exported_program")
105+
106+
model = BoringModel()
107+
model.cuda().eval()
108+
model.example_input_array = torch.randn((4, 32))
109+
110+
file_path = os.path.join(tmp_path, "model.trt")
111+
model.to_tensorrt(file_path, output_format=output_format, ir=ir)
112+
113+
loaded_model = torch_tensorrt.load(file_path)
114+
if output_format == "exported_program":
115+
loaded_model = loaded_model.module()
116+
117+
with torch.no_grad(), torch.inference_mode():
118+
model_output = model(model.example_input_array.to(model.device))
119+
120+
jit_output = loaded_model(model.example_input_array.to("cuda"))
121+
assert torch.allclose(model_output, jit_output)

0 commit comments

Comments
 (0)