Skip to content

Commit c4d2263

Browse files
committed
fix: fix bug in torch-tensorrt 2.8.0
1 parent a1bfa06 commit c4d2263

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1661,18 +1661,26 @@ def forward(self, x):
16611661
self.to(device)
16621662

16631663
if file_path is not None:
1664-
if ir == "ts" and output_format != "torchscript":
1665-
raise ValueError(
1666-
"TensorRT with IR mode 'ts' only supports output format 'torchscript'."
1667-
f" The current output format is {output_format}."
1664+
if ir == "ts":
1665+
if output_format != "torchscript":
1666+
raise ValueError(
1667+
"TensorRT with IR mode 'ts' only supports output format 'torchscript'."
1668+
f" The current output format is {output_format}."
1669+
)
1670+
assert isinstance(trt_obj, (torch.jit.ScriptModule, torch.jit.ScriptFunction)), (
1671+
f"Expected TensorRT object to be a ScriptModule, but got {type(trt_obj)}."
1672+
)
1673+
# Because of https://github.com/pytorch/TensorRT/issues/3775,
1674+
# we'll need to take special care for the ScriptModule
1675+
torch.jit.save(trt_obj, file_path)
1676+
else:
1677+
torch_tensorrt.save(
1678+
trt_obj,
1679+
file_path,
1680+
inputs=input_sample,
1681+
output_format=output_format,
1682+
retrace=retrace,
16681683
)
1669-
torch_tensorrt.save(
1670-
trt_obj,
1671-
file_path,
1672-
inputs=input_sample,
1673-
output_format=output_format,
1674-
retrace=retrace,
1675-
)
16761684
return trt_obj
16771685

16781686
@_restricted_classmethod

tests/tests_pytorch/models/test_torch_tensorrt.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,8 @@ def test_tensorrt_save_ir_type(ir, export_type):
130130
"ir",
131131
["default", "dynamo", "ts"],
132132
)
133-
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0", max_torch="2.8.0")
133+
@RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
134134
def test_tensorrt_export_reload(output_format, ir, tmp_path):
135-
# todo remove max_torch once https://github.com/pytorch/TensorRT/issues/3775 is fixed
136135
if ir == "ts" and output_format == "exported_program":
137136
pytest.skip("TorchScript cannot be exported as exported_program")
138137

0 commit comments

Comments
 (0)