Skip to content

Commit 57c6d46

Browse files
committed
docs: Update the docs to include new device API for to_backend
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]> docs: Update docs for to_backend API for new device API and new PyTorch API Changes the docs to show the new device dictionary API and how to use the new to backend api (changed from PyTorch 1.6.0) Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 0618b6b commit 57c6d46

File tree

1 file changed

+20
-16
lines changed

1 file changed

+20
-16
lines changed

docsrc/tutorials/use_from_pytorch.rst

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -32,31 +32,35 @@ at the documentation for the TRTorch ``TensorRTCompileSpec`` API.
3232
.. code-block:: python
3333
3434
spec = {
35-
"forward": trtorch.TensorRTCompileSpec({
36-
"input_shapes": [[1, 3, 300, 300]],
37-
"op_precision": torch.half,
38-
"refit": False,
39-
"debug": False,
40-
"strict_types": False,
41-
"allow_gpu_fallback": True,
42-
"device_type": "gpu",
43-
"capability": trtorch.EngineCapability.default,
44-
"num_min_timing_iters": 2,
45-
"num_avg_timing_iters": 1,
46-
"max_batch_size": 0,
47-
})
48-
}
35+
"forward":
36+
trtorch.TensorRTCompileSpec({
37+
"input_shapes": [[1, 3, 300, 300]],
38+
"op_precision": torch.half,
39+
"refit": False,
40+
"debug": False,
41+
"strict_types": False,
42+
"device": {
43+
"device_type": trtorch.DeviceType.GPU,
44+
"gpu_id": 0,
45+
"allow_gpu_fallback": True
46+
},
47+
"capability": trtorch.EngineCapability.default,
48+
"num_min_timing_iters": 2,
49+
"num_avg_timing_iters": 1,
50+
"max_batch_size": 0,
51+
})
52+
}
4953
5054
Now to compile with TRTorch, provide the target module objects and the spec dictionary to ``torch._C._jit_to_tensorrt``
5155

5256
.. code-block:: python
5357
54-
trt_model = torch._C._jit_to_tensorrt(script_model._c, spec)
58+
trt_model = torch._C._jit_to_backend("tensorrt", script_model, spec)
5559
5660
To run explicitly call the function of the method you want to run (vs. how you can just call on the module itself in standard PyTorch)
5761

5862
.. code-block:: python
5963
60-
input = torch.randn((1, 3, 300, 300).to("cuda").to(torch.half)
64+
input = torch.randn((1, 3, 300, 300)).to("cuda").to(torch.half)
6165
print(trt_model.forward(input))
6266

0 commit comments

Comments
 (0)