Dynamic shapes with Torch-TensorRT
====================================
By default, you can run a PyTorch model with varied input shapes, and the output shapes are determined eagerly.
However, Torch-TensorRT is an AOT compiler which requires some prior information about the input shapes to compile and optimize the model.
Dynamic shapes using torch.export (AOT)
----------------------------------------
In the case of dynamic input shapes, we must provide the ``(min_shape, opt_shape, max_shape)`` arguments so that the model can be optimized for
this range of input shapes. An example usage of static and dynamic shapes is as follows.
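
A minimal sketch of what this can look like (the model and shapes below are illustrative placeholders):

.. code-block:: python

    import torch
    import torch_tensorrt

    model = MyModel().eval().cuda()  # placeholder for your own module

    # Static shape: compile for a single fixed input shape.
    static_inputs = [torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)]
    trt_static = torch_tensorrt.compile(model, ir="dynamo", inputs=static_inputs)

    # Dynamic shape: compile for a range of batch sizes via (min, opt, max).
    dynamic_inputs = [
        torch_tensorrt.Input(
            min_shape=(1, 3, 224, 224),
            opt_shape=(4, 3, 224, 224),
            max_shape=(8, 3, 224, 224),
            dtype=torch.float32,
        )
    ]
    trt_dynamic = torch_tensorrt.compile(model, ir="dynamo", inputs=dynamic_inputs)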
Under the hood
--------------
There are two phases of compilation when we use the ``torch_tensorrt.compile`` API with ``ir=dynamo`` (default).
- torch_tensorrt.dynamo.trace (which uses torch.export to trace the graph with the given inputs)
We use the ``torch.export.export()`` API for tracing and exporting a PyTorch module into a ``torch.export.ExportedProgram``. In the case of
dynamic shaped inputs, the ``(min_shape, opt_shape, max_shape)`` range provided via the ``torch_tensorrt.Input`` API is used to construct ``torch.export.Dim`` objects,
which are used in the ``dynamic_shapes`` argument of the export API.
Please take a look at the ``_tracer.py`` file to understand how this works under the hood.
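
For intuition, a dynamic batch dimension declared through ``torch_tensorrt.Input`` roughly corresponds to an export call like the following sketch (the toy module and the argument name ``x`` are assumptions made for illustration):

.. code-block:: python

    import torch
    from torch.export import Dim, export

    class ToyModel(torch.nn.Module):  # stand-in module for illustration
        def forward(self, x):
            return torch.relu(x)

    # A range of min_shape=(1, 3, 224, 224) / max_shape=(8, 3, 224, 224) on dim 0
    # becomes a single Dim covering that range.
    batch = Dim("batch", min=1, max=8)

    ep = export(
        ToyModel(),
        (torch.randn(4, 3, 224, 224),),    # traced with a representative (opt) shape
        dynamic_shapes={"x": {0: batch}},  # dim 0 of argument "x" is dynamic
    )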
- torch_tensorrt.dynamo.compile (which compiles a torch.export.ExportedProgram object using TensorRT)
In the conversion to TensorRT, the graph already has the dynamic shape information in the node's metadata, which is used during the engine building phase.
Custom Dynamic Shape Constraints
---------------------------------
Given an input ``x = torch_tensorrt.Input(min_shape, opt_shape, max_shape, dtype)``,
Torch-TensorRT attempts to automatically set these constraints during ``torch.export`` tracing by constructing
``torch.export.Dim`` objects from the provided dynamic dimensions. Sometimes we might need to set additional constraints, and Torchdynamo errors out if we don't specify them.
If you have to set any custom constraints on your model (by using ``torch.export.Dim``), we recommend exporting your program first before compiling with Torch-TensorRT.
Please refer to this `documentation <https://pytorch.org/tutorials/intermediate/torch_export_tutorial.html#constraints-dynamic-shapes>`_ to export the PyTorch module with dynamic shapes.
Here's a simple example that exports a matmul layer with some restrictions on dynamic dimensions.
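
One way this can look, with illustrative shapes and dimension bounds:

.. code-block:: python

    import torch
    import torch_tensorrt

    class MatMul(torch.nn.Module):  # illustrative matmul layer
        def forward(self, query, key):
            return torch.matmul(query, key.transpose(-1, -2))

    model = MatMul().eval().cuda()
    inputs = (torch.randn(1, 12, 7, 64).cuda(), torch.randn(1, 12, 7, 64).cuda())

    # Both inputs must share the same dynamic sequence-length dimension (dim 2).
    seq_len = torch.export.Dim("seq_len", min=1, max=10)
    dynamic_shapes = ({2: seq_len}, {2: seq_len})

    # Export first with the custom constraints, then hand the resulting
    # ExportedProgram to Torch-TensorRT for compilation.
    exp_program = torch.export.export(model, inputs, dynamic_shapes=dynamic_shapes)
    trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs)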
Dynamic shapes using torch.compile (JIT)
-----------------------------------------

``torch_tensorrt.compile(model, inputs, ir="torch_compile")`` returns a torch.compile boxed function with the backend
configured to TensorRT. In the case of ``ir=torch_compile``, users can provide dynamic shape information for the inputs using the ``torch._dynamo.mark_dynamic`` API (https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html)
to avoid recompilation of TensorRT engines.
.. code-block:: python
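
    # Sketch of the torch.compile path with a dynamic batch dimension; the model
    # and the dimension bounds here are illustrative assumptions.
    import torch
    import torch_tensorrt  # registers the "tensorrt" backend for torch.compile

    model = torch.nn.Linear(10, 10).eval().cuda()
    inputs = torch.randn((8, 10)).cuda()

    # Mark dim 0 (batch) as dynamic within [1, 32] so that changing the batch size
    # does not force a rebuild of the TensorRT engine.
    torch._dynamo.mark_dynamic(inputs, 0, min=1, max=32)

    trt_model = torch.compile(model, backend="tensorrt")
    trt_model(inputs)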
@@ -200,10 +93,12 @@ TRT engine. This limitation will be addressed in future versions of Torch-TensorRT.
-    description="Load pre-trained VGG model and then tune with FP8 and PTQ"
+    description="Load pre-trained VGG model and then tune with FP8 and PTQ. For having a pre-trained VGG model, please refer to https://github.com/pytorch/TensorRT/tree/main/examples/int8/training/vgg16"
 )
 PARSER.add_argument(
     "--ckpt", type=str, required=True, help="Path to the pre-trained checkpoint"

@@ -226,6 +226,8 @@ def calibrate_loop(model):
     min_block_size=1,
     debug=False,
 )
+# You can also use torch compile path to compile the model with Torch-TensorRT: