|
76 | 76 | from torch.distributed.device_mesh import DeviceMesh
|
77 | 77 |
|
78 | 78 | _ONNX_AVAILABLE = RequirementCache("onnx")
|
| 79 | +_TORCH_TRT_AVAILABLE = RequirementCache("torch_tensorrt") |
79 | 80 |
|
80 | 81 | warning_cache = WarningCache()
|
81 | 82 | log = logging.getLogger(__name__)
|
@@ -1489,6 +1490,79 @@ def forward(self, x):
|
1489 | 1490 |
|
1490 | 1491 | return torchscript_module
|
1491 | 1492 |
|
@torch.no_grad()
def to_tensorrt(
    self,
    file_path: str | Path | BytesIO | None = None,
    inputs: Any | None = None,
    ir: Literal["default", "dynamo", "ts"] = "default",
    output_format: Literal["exported_program", "torchscript"] = "exported_program",
    retrace: bool = False,
    **compile_kwargs: Any,
) -> torch.ScriptModule | torch.fx.GraphModule:
    """Export the model to ScriptModule or GraphModule using TensorRT compile backend.

    The module is switched to eval mode for the duration of the compilation and its
    previous training mode is restored afterwards, even if compilation fails.

    Args:
        file_path: Path where to save the tensorrt model. Default: None (no file saved).
        inputs: inputs to be used during `torch_tensorrt.compile`.
            Default: None (Use self.example_input_array).
        ir: The IR mode to use for TensorRT compilation. Default: "default".
        output_format: The format of the output model, forwarded to `torch_tensorrt.save`.
            Only used when ``file_path`` is given. Default: "exported_program".
        retrace: Whether to retrace the model, forwarded to `torch_tensorrt.save`.
            Only used when ``file_path`` is given. Default: False.
        **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.

    Returns:
        The object returned by `torch_tensorrt.compile`.

    Raises:
        ModuleNotFoundError: If `torch_tensorrt` is not installed.
        ValueError: If ``inputs`` is None and the model has no ``example_input_array``.

    Example::

        class SimpleModel(LightningModule):
            def __init__(self):
                super().__init__()
                self.l1 = torch.nn.Linear(in_features=64, out_features=4)

            def forward(self, x):
                return torch.relu(self.l1(x.view(x.size(0), -1)))

        model = SimpleModel()
        input_sample = torch.randn(1, 64)
        exported_program = model.to_tensorrt(
            file_path="export.ep",
            inputs=input_sample,
        )

    """
    if not _TORCH_TRT_AVAILABLE:
        raise ModuleNotFoundError(
            f"`{type(self).__name__}.to_tensorrt` requires `torch_tensorrt` to be installed. "
        )

    # Imported lazily so the module stays importable without the optional dependency.
    import torch_tensorrt

    if inputs is None:
        if self.example_input_array is None:
            raise ValueError("Please provide an example input for the model.")
        inputs = self.example_input_array

    # NOTE(review): the batch-transfer hooks are applied to user-supplied inputs as well
    # as to the default example input -- confirm this matches the intended behavior.
    inputs = self._on_before_batch_transfer(inputs)
    inputs = self._apply_batch_transfer_handler(inputs)

    # Remember the current mode so it can be restored whether or not compilation succeeds.
    was_training = self.training
    try:
        trt_obj = torch_tensorrt.compile(
            module=self.eval(),
            ir=ir,
            inputs=inputs,
            **compile_kwargs,
        )
    finally:
        self.train(was_training)

    if file_path is not None:
        torch_tensorrt.save(
            trt_obj,
            file_path,
            inputs=inputs,
            output_format=output_format,
            retrace=retrace,
        )
    return trt_obj
| 1565 | + |
1492 | 1566 | @_restricted_classmethod
|
1493 | 1567 | def load_from_checkpoint(
|
1494 | 1568 | cls,
|
|
0 commit comments