Commit 14b9f29

feat: add to_tensorrt in the LightningModule.
1 parent 23f02ce

File tree

1 file changed (+74, -0 lines)

src/lightning/pytorch/core/module.py

Lines changed: 74 additions & 0 deletions
@@ -76,6 +76,7 @@
 from torch.distributed.device_mesh import DeviceMesh

 _ONNX_AVAILABLE = RequirementCache("onnx")
+_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
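
The new `_TORCH_TRT_AVAILABLE` flag follows the same optional-dependency pattern as the existing `_ONNX_AVAILABLE` guard: `RequirementCache` evaluates truthy only when the named package is importable, so the heavy `torch_tensorrt` import can be deferred until export time. A minimal sketch of that gating pattern in isolation (the function name and error text below are illustrative, not part of this commit):

    from lightning_utilities.core.imports import RequirementCache

    _TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt")

    def compile_with_tensorrt(module, inputs):
        # Fail fast with a clear message when the optional backend is missing.
        if not _TORCH_TRT_AVAILABLE:
            raise ModuleNotFoundError("`torch_tensorrt` is required for TensorRT export.")
        import torch_tensorrt  # imported lazily, only once the guard has passed
        return torch_tensorrt.compile(module=module.eval(), inputs=inputs)
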
@@ -1489,6 +1490,79 @@ def forward(self, x):

         return torchscript_module

+    @torch.no_grad()
+    def to_tensorrt(
+        self,
+        file_path: str | Path | BytesIO | None = None,
+        inputs: Any | None = None,
+        ir: Literal["default", "dynamo", "ts"] = "default",
+        output_format: Literal["exported_program", "torchscript"] = "exported_program",
+        retrace: bool = False,
+        **compile_kwargs,
+    ) -> torch.ScriptModule | torch.fx.GraphModule:
+        """Export the model to ScriptModule or GraphModule using the TensorRT compile backend.
+
+        Args:
+            file_path: Path where to save the TensorRT model. Default: None (no file saved).
+            inputs: Inputs to be used during `torch_tensorrt.compile`. Default: None (uses `self.example_input_array`).
+            ir: The IR mode to use for TensorRT compilation. Default: "default".
+            output_format: The format of the output model. Default: "exported_program".
+            retrace: Whether to retrace the model. Default: False.
+            **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.
+
+        Example::
+
+            class SimpleModel(LightningModule):
+                def __init__(self):
+                    super().__init__()
+                    self.l1 = torch.nn.Linear(in_features=64, out_features=4)
+
+                def forward(self, x):
+                    return torch.relu(self.l1(x.view(x.size(0), -1)))
+
+            model = SimpleModel()
+            input_sample = torch.randn(1, 64)
+            exported_program = model.to_tensorrt(
+                file_path="export.ep",
+                inputs=input_sample,
+            )
+
+        """
+
+        if not _TORCH_TRT_AVAILABLE:
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed."
+            )
+
+        import torch_tensorrt
+
+        mode = self.training
+
+        if inputs is None:
+            if self.example_input_array is None:
+                raise ValueError("Please provide an example input for the model.")
+            inputs = self.example_input_array
+        inputs = self._on_before_batch_transfer(inputs)
+        inputs = self._apply_batch_transfer_handler(inputs)
+
+        trt_obj = torch_tensorrt.compile(
+            module=self.eval(),
+            ir=ir,
+            inputs=inputs,
+            **compile_kwargs,
+        )
+        self.train(mode)
+
+        if file_path is not None:
+            torch_tensorrt.save(
+                trt_obj,
+                file_path,
+                inputs=inputs,
+                output_format=output_format,
+                retrace=retrace,
+            )
+        return trt_obj
+
     @_restricted_classmethod
     def load_from_checkpoint(
         cls,
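
For reference, a minimal end-to-end sketch of how the new method could be exercised once this lands. `SimpleModel` mirrors the docstring example above; the reload step assumes `torch_tensorrt.load` as the counterpart of the `torch_tensorrt.save` call used in the implementation, and a CUDA-capable GPU, since TensorRT builds engines for NVIDIA hardware:

    import torch
    import torch_tensorrt
    from lightning.pytorch import LightningModule


    class SimpleModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(in_features=64, out_features=4)

        def forward(self, x):
            return torch.relu(self.l1(x.view(x.size(0), -1)))


    # Place the module and the sample input on the GPU before compiling.
    model = SimpleModel().cuda()
    input_sample = torch.randn(1, 64, device="cuda")

    # Compile with the TensorRT backend and save the exported program to disk.
    exported_program = model.to_tensorrt(file_path="export.ep", inputs=input_sample)

    # Reload and run the saved artifact (assumes torch_tensorrt.load and the
    # default "exported_program" output format).
    loaded = torch_tensorrt.load("export.ep")
    output = loaded.module()(input_sample)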
