Skip to content

Commit 314c463

Browse files
committed
refactor: fix to_tensorrt impl
1 parent 14b9f29 commit 314c463

File tree

1 file changed

+42
-17
lines changed

1 file changed

+42
-17
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,12 @@
1313
# limitations under the License.
1414
"""The LightningModule - an nn.Module with many additional features."""
1515

16+
import copy
1617
import logging
1718
import numbers
1819
import weakref
1920
from collections.abc import Generator, Mapping, Sequence
20-
from contextlib import contextmanager
21+
from contextlib import contextmanager, nullcontext
2122
from io import BytesIO
2223
from pathlib import Path
2324
from typing import (
@@ -1494,20 +1495,23 @@ def forward(self, x):
14941495
def to_tensorrt(
14951496
self,
14961497
file_path: str | Path | BytesIO | None = None,
1497-
inputs: Any | None = None,
1498+
input_sample: Any | None = None,
14981499
ir: Literal["default", "dynamo", "ts"] = "default",
14991500
output_format: Literal["exported_program", "torchscript"] = "exported_program",
15001501
retrace: bool = False,
1502+
default_device: str | torch.device = "cuda",
15011503
**compile_kwargs,
1502-
) -> torch.ScriptModule | torch.fx.GraphModule:
1504+
) -> ScriptModule | torch.fx.GraphModule:
15031505
"""Export the model to ScriptModule or GraphModule using TensorRT compile backend.
15041506
15051507
Args:
15061508
file_path: Path where to save the tensorrt model. Default: None (no file saved).
1507-
inputs: inputs to be used during `torch_tensorrt.compile`. Default: None (Use self.example_input_array).
1509+
input_sample: inputs to be used during `torch_tensorrt.compile`.
1510+
Default: None (Use :attr:`example_input_array`).
15081511
ir: The IR mode to use for TensorRT compilation. Default: "default".
15091512
output_format: The format of the output model. Default: "exported_program".
15101513
retrace: Whether to retrace the model. Default: False.
1514+
default_device: The device to use for the model when the current model is not in CUDA. Default: "cuda".
15111515
**compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.
15121516
15131517
Example::
@@ -1537,27 +1541,48 @@ def forward(self, x):
15371541
import torch_tensorrt
15381542

15391543
mode = self.training
1544+
device = self.device
1545+
if self.device.type != "cuda":
1546+
default_device = torch.device(default_device) if isinstance(default_device, str) else default_device
1547+
if default_device.type != "cuda":
1548+
raise ValueError(
1549+
f"TensorRT only supports CUDA devices. The current device is {self.device}."
1550+
f" Please set the `default_device` argument to a CUDA device."
1551+
)
1552+
1553+
self.to(default_device)
15401554

1541-
if inputs is None:
1555+
if input_sample is None:
15421556
if self.example_input_array is None:
1543-
raise ValueError("Please provide an example input for the model.")
1544-
inputs = self.example_input_array
1545-
inputs = self._on_before_batch_transfer(inputs)
1546-
inputs = self._apply_batch_transfer_handler(inputs)
1547-
1548-
trt_obj = torch_tensorrt.compile(
1549-
module=self.eval(),
1550-
ir=ir,
1551-
inputs=inputs,
1552-
**compile_kwargs,
1553-
)
1557+
raise ValueError(
1558+
"Could not export to TensorRT since neither `input_sample` nor"
1559+
" `model.example_input_array` attribute is set."
1560+
)
1561+
input_sample = self.example_input_array
1562+
input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample)
1563+
input_sample = self._on_before_batch_transfer(input_sample)
1564+
input_sample = self._apply_batch_transfer_handler(input_sample)
1565+
1566+
with _jit_is_scripting() if ir == "ts" else nullcontext():
1567+
trt_obj = torch_tensorrt.compile(
1568+
module=self.eval(),
1569+
ir=ir,
1570+
inputs=input_sample,
1571+
**compile_kwargs,
1572+
)
15541573
self.train(mode)
1574+
self.to(device)
15551575

15561576
if file_path is not None:
1577+
if ir == "ts" and output_format != "torchscript":
1578+
raise ValueError(
1579+
"TensorRT with IR mode 'ts' only supports output format 'torchscript'."
1580+
f" The current output format is {output_format}."
1581+
)
15571582
torch_tensorrt.save(
15581583
trt_obj,
15591584
file_path,
1560-
inputs=inputs,
1585+
inputs=input_sample,
15611586
output_format=output_format,
15621587
retrace=retrace,
15631588
)

0 commit comments

Comments
 (0)