|
13 | 13 | # limitations under the License.
|
14 | 14 | """The LightningModule - an nn.Module with many additional features."""
|
15 | 15 |
|
| 16 | +import copy |
16 | 17 | import logging
|
17 | 18 | import numbers
|
18 | 19 | import weakref
|
19 | 20 | from collections.abc import Generator, Mapping, Sequence
|
20 |
| -from contextlib import contextmanager |
| 21 | +from contextlib import contextmanager, nullcontext |
21 | 22 | from io import BytesIO
|
22 | 23 | from pathlib import Path
|
23 | 24 | from typing import (
|
@@ -1494,20 +1495,23 @@ def forward(self, x):
|
1494 | 1495 | def to_tensorrt(
|
1495 | 1496 | self,
|
1496 | 1497 | file_path: str | Path | BytesIO | None = None,
|
1497 |
| - inputs: Any | None = None, |
| 1498 | + input_sample: Any | None = None, |
1498 | 1499 | ir: Literal["default", "dynamo", "ts"] = "default",
|
1499 | 1500 | output_format: Literal["exported_program", "torchscript"] = "exported_program",
|
1500 | 1501 | retrace: bool = False,
|
| 1502 | + default_device: str | torch.device = "cuda", |
1501 | 1503 | **compile_kwargs,
|
1502 |
| - ) -> torch.ScriptModule | torch.fx.GraphModule: |
| 1504 | + ) -> ScriptModule | torch.fx.GraphModule: |
1503 | 1505 | """Export the model to ScriptModule or GraphModule using TensorRT compile backend.
|
1504 | 1506 |
|
1505 | 1507 | Args:
|
1506 | 1508 | file_path: Path where to save the tensorrt model. Default: None (no file saved).
|
1507 |
| - inputs: inputs to be used during `torch_tensorrt.compile`. Default: None (Use self.example_input_array). |
| 1509 | + input_sample: inputs to be used during `torch_tensorrt.compile`. |
| 1510 | + Default: None (Use :attr:`example_input_array`). |
1508 | 1511 | ir: The IR mode to use for TensorRT compilation. Default: "default".
|
1509 | 1512 | output_format: The format of the output model. Default: "exported_program".
|
1510 | 1513 | retrace: Whether to retrace the model. Default: False.
|
| 1514 | + default_device: The CUDA device the model is moved to for compilation when it is not already on a CUDA device. Default: "cuda". |
1511 | 1515 | **compile_kwargs: Additional arguments that will be passed to the TensorRT compile function.
|
1512 | 1516 |
|
1513 | 1517 | Example::
|
@@ -1537,27 +1541,48 @@ def forward(self, x):
|
1537 | 1541 | import torch_tensorrt
|
1538 | 1542 |
|
1539 | 1543 | mode = self.training
|
| 1544 | + device = self.device |
| 1545 | + if self.device.type != "cuda": |
| 1546 | + default_device = torch.device(default_device) if isinstance(default_device, str) else default_device |
| 1547 | + if default_device.type != "cuda": |
| 1548 | + raise ValueError( |
| 1549 | + f"TensorRT only supports CUDA devices. The current device is {self.device}." |
| 1550 | + f" Please set the `default_device` argument to a CUDA device." |
| 1551 | + ) |
| 1552 | + |
| 1553 | + self.to(default_device) |
1540 | 1554 |
|
1541 |
| - if inputs is None: |
| 1555 | + if input_sample is None: |
1542 | 1556 | if self.example_input_array is None:
|
1543 |
| - raise ValueError("Please provide an example input for the model.") |
1544 |
| - inputs = self.example_input_array |
1545 |
| - inputs = self._on_before_batch_transfer(inputs) |
1546 |
| - inputs = self._apply_batch_transfer_handler(inputs) |
1547 |
| - |
1548 |
| - trt_obj = torch_tensorrt.compile( |
1549 |
| - module=self.eval(), |
1550 |
| - ir=ir, |
1551 |
| - inputs=inputs, |
1552 |
| - **compile_kwargs, |
1553 |
| - ) |
| 1557 | + raise ValueError( |
| 1558 | + "Could not export to TensorRT since neither `input_sample` nor" |
| 1559 | + " `model.example_input_array` attribute is set." |
| 1560 | + ) |
| 1561 | + input_sample = self.example_input_array |
| 1562 | + input_sample = copy.deepcopy((input_sample,) if isinstance(input_sample, torch.Tensor) else input_sample) |
| 1563 | + input_sample = self._on_before_batch_transfer(input_sample) |
| 1564 | + input_sample = self._apply_batch_transfer_handler(input_sample) |
| 1565 | + |
| 1566 | + with _jit_is_scripting() if ir == "ts" else nullcontext(): |
| 1567 | + trt_obj = torch_tensorrt.compile( |
| 1568 | + module=self.eval(), |
| 1569 | + ir=ir, |
| 1570 | + inputs=input_sample, |
| 1571 | + **compile_kwargs, |
| 1572 | + ) |
1554 | 1573 | self.train(mode)
|
| 1574 | + self.to(device) |
1555 | 1575 |
|
1556 | 1576 | if file_path is not None:
|
| 1577 | + if ir == "ts" and output_format != "torchscript": |
| 1578 | + raise ValueError( |
| 1579 | + "TensorRT with IR mode 'ts' only supports output format 'torchscript'." |
| 1580 | + f" The current output format is {output_format}." |
| 1581 | + ) |
1557 | 1582 | torch_tensorrt.save(
|
1558 | 1583 | trt_obj,
|
1559 | 1584 | file_path,
|
1560 |
| - inputs=inputs, |
| 1585 | + inputs=input_sample, |
1561 | 1586 | output_format=output_format,
|
1562 | 1587 | retrace=retrace,
|
1563 | 1588 | )
|
|
0 commit comments