Commit 7e4b07c

committed

docs: New documentation on the to_backend api integration

Signed-off-by: Naren Dasan <[email protected]>

1 parent a720f91 commit 7e4b07c

File tree

3 files changed: +67 −1 lines changed
docsrc/index.rst

Lines changed: 2 additions & 0 deletions

@@ -25,6 +25,7 @@ Getting Started
 * :ref:`getting_started`
 * :ref:`ptq`
 * :ref:`trtorchc`
+* :ref:`use_from_pytorch`


 .. toctree::
@@ -36,6 +37,7 @@ Getting Started
    tutorials/getting_started
    tutorials/ptq
    tutorials/trtorchc
+   tutorials/use_from_pytorch
    _notebooks/lenet

 .. toctree::

docsrc/py_api/trtorch.rst

Lines changed: 3 additions & 1 deletion

@@ -17,9 +17,11 @@ Functions

 .. autofunction:: check_method_op_support

+.. autofunction:: get_build_info
+
 .. autofunction:: dump_build_info

-.. autofunction:: get_build_info
+.. autofunction:: TensorRTCompileSpec

 Enums
 -------

docsrc/tutorials/use_from_pytorch.rst

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+.. _use_from_pytorch:
+
+Using TRTorch Directly From PyTorch
+====================================
+
+Starting in TRTorch 0.1.0, you can access TensorRT directly from PyTorch APIs. The process of using this feature
+is very similar to the compilation workflow described in :ref:`getting_started`.
+
+Start by loading ``trtorch`` into your application.
+
+.. code-block:: python
+
+    import torch
+    import trtorch
+
+
+Then, given a TorchScript module, you can lower it to TensorRT using the ``torch._C._jit_to_tensorrt`` API.
+
+.. code-block:: python
+
+    import torchvision.models as models
+
+    model = models.mobilenet_v2(pretrained=True)
+    script_model = torch.jit.script(model)
+
+Unlike the ``compile`` API in TRTorch, which assumes you are trying to compile the ``forward`` method of a module,
+or ``convert_method_to_trt_engine``, which converts a specified method to a TensorRT engine, the backend API
+takes a dictionary mapping the names of the methods to compile to compile spec objects, which wrap the same
+sort of dictionary you would provide to ``compile``. For more information on the compile spec dictionary, take a look
+at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
+
+.. code-block:: python
+
+    spec = {
+        "forward": trtorch.TensorRTCompileSpec({
+            "input_shapes": [[1, 3, 300, 300]],
+            "op_precision": torch.half,
+            "refit": False,
+            "debug": False,
+            "strict_types": False,
+            "allow_gpu_fallback": True,
+            "device_type": "gpu",
+            "capability": trtorch.EngineCapability.default,
+            "num_min_timing_iters": 2,
+            "num_avg_timing_iters": 1,
+            "max_batch_size": 0,
+        })
+    }
+
+Now, to compile with TRTorch, provide the target module object and the spec dictionary to ``torch._C._jit_to_tensorrt``.
+
+.. code-block:: python
+
+    trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
+
+To run, explicitly call the method you want to execute (as opposed to calling the module itself, as you can in standard PyTorch).
+
+.. code-block:: python
+
+    input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
+    print(trt_model.forward(input))
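The compile spec in the new tutorial is an ordinary Python dictionary before ``trtorch.TensorRTCompileSpec`` wraps it, keyed by the name of each method to compile. As a rough illustration of that structure, here is a sketch in plain Python that builds and sanity-checks such a dictionary. It needs neither ``torch`` nor ``trtorch`` (so ``op_precision`` and ``capability``, whose values are torch/trtorch objects, are omitted), and ``validate_compile_spec`` is a hypothetical helper invented for this sketch, not part of the TRTorch API.

```python
# Sketch: sanity-check the per-method compile-spec dictionary shown in the
# tutorial above. Pure Python; "validate_compile_spec" is a hypothetical
# helper for illustration, not part of the TRTorch API.

REQUIRED_KEYS = {"input_shapes"}
KNOWN_KEYS = REQUIRED_KEYS | {
    "op_precision", "refit", "debug", "strict_types", "allow_gpu_fallback",
    "device_type", "capability", "num_min_timing_iters",
    "num_avg_timing_iters", "max_batch_size",
}

def validate_compile_spec(spec):
    """Check a dict mapping method names (e.g. "forward") to settings."""
    for method, settings in spec.items():
        missing = REQUIRED_KEYS - settings.keys()
        unknown = settings.keys() - KNOWN_KEYS
        if missing:
            raise ValueError(f"{method}: missing keys {sorted(missing)}")
        if unknown:
            raise ValueError(f"{method}: unknown keys {sorted(unknown)}")
        # Every input shape must be a list of positive integer dimensions.
        for shape in settings["input_shapes"]:
            if not all(isinstance(d, int) and d > 0 for d in shape):
                raise ValueError(f"{method}: bad input shape {shape}")
    return True

# Mirrors the tutorial's spec, minus the torch/trtorch-valued entries.
spec = {
    "forward": {
        "input_shapes": [[1, 3, 300, 300]],
        "refit": False,
        "debug": False,
        "strict_types": False,
        "allow_gpu_fallback": True,
        "device_type": "gpu",
        "num_min_timing_iters": 2,
        "num_avg_timing_iters": 1,
        "max_batch_size": 0,
    }
}

print(validate_compile_spec(spec))  # → True
```

In the real workflow, each inner dictionary would be passed to ``trtorch.TensorRTCompileSpec`` before handing the outer dictionary to ``torch._C._jit_to_tensorrt``.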
