@@ -21,6 +21,9 @@ For example, we can define a LeNet module like this:

 .. code-block:: python
    :linenos:

+   import torch.nn as nn
+   import torch.nn.functional as F
+
    class LeNetFeatExtractor(nn.Module):
        def __init__(self):
            super(LeNetFeatExtractor, self).__init__()
@@ -56,6 +59,7 @@ For example, we can define a LeNet module like this:
            x = self.feat(x)
            x = self.classifer(x)
            return x
+

.

Obviously you may want to consolidate such a simple model into a single module but we can see the composability of PyTorch here
@@ -67,15 +71,20 @@ To trace an instance of our LeNet module, we can call ``torch.jit.trace`` with a

 .. code-block:: python

+   import torch.jit
+
    model = LeNet()
-   traced_model = torch.jit.trace(model, torch.empty([1, 1, 32, 32]))
+   input_data = torch.empty([1, 1, 32, 32])
+   traced_model = torch.jit.trace(model, input_data)
Scripting actually inspects your code with a compiler and generates an equivalent TorchScript program. The difference is that since tracing
follows the execution of your module, it cannot pick up control flow, for instance. By working from the Python code, the compiler can
include these components. We can run the script compiler on our LeNet module by calling ``torch.jit.script``

 .. code-block:: python

+   import torch.jit
+
    model = LeNet()
    script_model = torch.jit.script(model)
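To make the tracing/scripting distinction concrete, here is a minimal sketch; the ``Gate`` module is hypothetical (not part of the LeNet example) and exists only to show that scripting preserves data-dependent control flow while a trace bakes in the branch taken by its example input:

```python
import torch
import torch.nn as nn

# Hypothetical module with data-dependent control flow
# (illustration only; not part of the original example).
class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:   # which branch runs depends on the input's values
            return x * 2
        return x + 10

model = Gate()
pos = torch.ones(3)
neg = -torch.ones(3)

# Scripting compiles the Python source, so both branches survive.
script_model = torch.jit.script(model)
assert torch.equal(script_model(pos), pos * 2)
assert torch.equal(script_model(neg), neg + 10)

# Tracing records only the path taken by the example input: traced with a
# positive tensor, the module always multiplies by 2 (a TracerWarning is
# emitted because the control flow cannot be captured).
traced_model = torch.jit.trace(model, pos)
assert torch.equal(traced_model(neg), neg * 2)
```

This is exactly why the document recommends scripting for modules whose behavior depends on their inputs.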
@@ -138,20 +147,23 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
    import trtorch

    ...
+
+   script_model.eval() # torch module needs to be in eval (not training) mode
+
    compile_settings = {
        "input_shapes": [
            {
-               "min": [1, 3, 224, 224],
-               "opt": [1, 3, 512, 512],
-               "max": [1, 3, 1024, 1024]
-           }, # For static size [1, 3, 224, 224]
+               "min": [1, 1, 32, 32],
+               "opt": [1, 1, 32, 32],
+               "max": [1, 1, 32, 32]
+           },
        ],
-       "op_precision": torch.half # Run with FP16
+       "op_precision": torch.half # Run with fp16
    }

-   trt_ts_module = trtorch.compile(torch_script_module, compile_settings)
+   trt_ts_module = trtorch.compile(script_model, compile_settings)

-   input_data = input_data.half()
+   input_data = input_data.to('cuda').half()
    result = trt_ts_module(input_data)
    torch.jit.save(trt_ts_module, "trt_ts_module.ts")
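The diff above adds a ``script_model.eval()`` call before compilation. A standalone sketch (not from the original docs) of why eval mode matters, using dropout, whose behavior differs between train and eval mode:

```python
import torch
import torch.nn as nn

# Dropout is active in train mode and a no-op in eval mode, which is why
# a module must be switched to eval mode before compiling it for inference.
model = nn.Sequential(nn.Dropout(p=0.5))
x = torch.ones(1000)

model.train()
train_out = model(x)   # about half the elements zeroed, the rest scaled by 2

model.eval()
eval_out = model(x)    # dropout disabled: output equals the input
assert torch.equal(eval_out, x)
```

Layers such as ``nn.BatchNorm2d`` change behavior between the two modes in the same way, so calling ``.eval()`` ensures the compiled engine fixes the inference-time behavior.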
@@ -162,7 +174,7 @@ to load in a deployment application. In order to load a TensorRT/TorchScript mod
    import trtorch

    trt_ts_module = torch.jit.load("trt_ts_module.ts")
-   input_data = input_data.half()
+   input_data = input_data.to('cuda').half()
    result = trt_ts_module(input_data)

.. _ts_in_cc: