Skip to content

Commit 9d6dab6

Browse files
authored
hook torch.jit.trace (#95)
1 parent 5ce9c59 commit 9d6dab6

File tree

3 files changed

+112
-71
lines changed

3 files changed

+112
-71
lines changed

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
from .reshape import *
66
from .mlp import *
77
from .linear_fuse_relu import *
8-
from .jit_script import *
8+
from .jit import *

intel_pytorch_extension_py/ops/jit_script.py renamed to intel_pytorch_extension_py/ops/jit.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
torch._C._jit_set_profiling_executor(False)
88

99
orig_script = torch.jit.script
10+
orig_trace = torch.jit.trace
1011

1112
def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1213
torch.jit.script = orig_script
@@ -20,8 +21,21 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
2021
ipex.enable_auto_mix_precision(None)
2122
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
2223
ipex.enable_auto_mix_precision(orig_mixed_type)
24+
return jit_m
25+
26+
def trace_(func, example_inputs, *args, **kwargs):
27+
# Disable mix precision. torch.jit.trace will check the traced output
28+
# against what is expected. Since mix precision will lead to
29+
# loss of accuracy, this will raise a warning during torch.jit.trace
30+
orig_mixed_type = ipex.get_auto_mix_precision()
31+
ipex.enable_auto_mix_precision(None)
32+
jit_m = orig_trace(func, example_inputs, *args, **kwargs)
2333

34+
if core.get_jit_opt():
35+
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
36+
ipex.enable_auto_mix_precision(orig_mixed_type)
2437
return jit_m
2538

2639

2740
torch.jit.script = script_
41+
torch.jit.trace = trace_

tests/cpu/test_jit.py

Lines changed: 97 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def forward(self, x):
152152

153153
class Tester(TestCase):
154154

155-
def _test_output(self, model, x, kind=None):
155+
def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
156156
modelName = model.__class__.__name__
157157
core.disable_jit_opt()
158158
core.disable_mix_bf16_fp32()
@@ -164,180 +164,207 @@ def _test_output(self, model, x, kind=None):
164164

165165
script_model = torch.jit.script(model)
166166
script_model.eval()
167+
168+
trace_model = torch.jit.trace(model, x)
169+
trace_model.eval()
167170
with torch.no_grad():
168171
sresult = script_model(x)
172+
tresult = trace_model(x)
169173

170174
self.assertEqual(result, sresult)
175+
self.assertEqual(result, tresult)
171176

172177
core.enable_jit_opt()
173-
fused_model = torch.jit.script(model)
178+
script_fused_model = torch.jit.script(model)
179+
trace_fused_model = torch.jit.trace(model, x)
174180
with torch.no_grad():
175181
# conv relu fusion, conv sum fusion or conv sum relu fusion
176-
graph = fused_model.graph_for(x)
177-
# print(graph)
178-
fresult = fused_model(x)
182+
script_graph = script_fused_model.graph_for(x)
183+
fused_sresult = script_fused_model(x)
179184

180-
# print(result)
181-
# print(sresult)
182-
# print(fresult)
185+
trace_graph = trace_fused_model.graph_for(x)
186+
fused_tresult = trace_fused_model(x)
183187

184-
self.assertEqual(result, fresult)
188+
self.assertEqual(result, fused_sresult)
189+
self.assertEqual(result, fused_tresult)
185190

186191
# check if the fused node exists in the graph
187-
if kind is not None:
188-
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
192+
if kind_in_graph is not None:
193+
self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
194+
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
195+
196+
# check if certain node does not exist in the graph
197+
if kind_not_in_graph is not None:
198+
self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
199+
self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
200+
189201

190-
def _test_output_bf16(self, model, x, kind=None, prec=None):
202+
def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None, prec=None):
191203
modelName = model.__class__.__name__
192204

193205
core.enable_auto_dnnl()
194206
core.enable_jit_opt()
195-
core.disable_mix_bf16_fp32()
196-
207+
core.enable_mix_bf16_fp32()
208+
197209
model = model.to(ipex.DEVICE).eval()
198210
x = x.to(ipex.DEVICE)
199211
x2 = x.clone()
212+
x3 = x.clone()
213+
214+
script_fused_model = torch.jit.script(copy.deepcopy(model))
215+
trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
200216

201-
fused_model = torch.jit.script(copy.deepcopy(model))
202-
203-
# bn folding, removing it after solve some issue, using mix_preci? to check
204-
core.disable_auto_dnnl()
205-
fused_model = wrap_cpp_module(torch._C._jit_pass_fold_convbn(fused_model._c))
206-
core.enable_auto_dnnl()
207-
208-
core.enable_mix_bf16_fp32()
209217

210218
with torch.no_grad():
211219
# bf16, native path
212220
result = model(x)
213-
# bf16, jit path
214-
graph = fused_model.graph_for(x2)
215-
# print(graph)
216-
fresult = fused_model(x2)
217-
218-
#print(result)
219-
#print(fresult)
220-
221-
self.assertEqual(fresult, result, prec=prec)
221+
# bf16, jit script path
222+
script_graph = script_fused_model.graph_for(x2)
223+
fused_sresult = script_fused_model(x2)
224+
# bf16, jit trace path
225+
trace_graph = trace_fused_model.graph_for(x3)
226+
fused_tresult = trace_fused_model(x3)
227+
228+
# disable mix_bf16_fp32 when the calculation is done
229+
# to avoid affecting other scripts
230+
core.disable_mix_bf16_fp32()
231+
232+
self.assertEqual(fused_sresult, result, prec=prec)
233+
self.assertEqual(fused_tresult, result, prec=prec)
222234

223235
# check if the fused node exists in the graph
224-
if kind is not None:
225-
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
226-
236+
if kind_in_graph is not None:
237+
self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
238+
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
239+
240+
# check if certain node does not exist in the graph
241+
if kind_not_in_graph is not None:
242+
self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
243+
self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
244+
227245

228246
def test_output_conv_bn_2d(self):
229247
self._test_output(
230248
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
231-
torch.randn(32, 3, 224, 224),
232-
kind="aten::conv2d")
249+
torch.randn(32, 3, 64, 64),
250+
kind_in_graph="aten::conv2d",
251+
kind_not_in_graph="aten::batch_norm",)
233252
self._test_output_bf16(
234253
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
235-
torch.randn(32, 3, 224, 224),
236-
kind="aten::conv2d",
254+
torch.randn(32, 3, 64, 64),
255+
kind_in_graph="aten::conv2d",
256+
kind_not_in_graph="aten::batch_norm",
237257
prec=0.02)
238258

239259

240260
def test_output_conv_bn_3d(self):
241261
self._test_output(
242262
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
243-
torch.randn(32, 3, 112, 112, 112),
244-
kind="aten::conv3d")
263+
torch.randn(32, 3, 32, 32, 32),
264+
kind_in_graph="aten::conv3d",
265+
kind_not_in_graph="aten::batch_norm",)
245266
self._test_output_bf16(
246267
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
247-
torch.randn(32, 3, 112, 112, 112),
248-
kind="aten::conv3d",
268+
torch.randn(32, 3, 32, 32, 32),
269+
kind_in_graph="aten::conv3d",
270+
kind_not_in_graph="aten::batch_norm",
249271
prec=0.02)
250272

251273

252274
def test_output_conv_relu_2d(self):
253275
self._test_output(
254276
ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
255-
torch.randn(32, 3, 224, 224),
256-
kind="ipex::conv2d_relu")
277+
torch.randn(32, 3, 64, 64),
278+
kind_in_graph="ipex::conv2d_relu")
257279
self._test_output_bf16(
258280
ConvRelu_Fixed(2, 3, 32, kernel_size=3, stride=1),
259-
torch.randn(32, 3, 224, 224),
260-
kind="ipex::conv2d_relu")
281+
torch.randn(32, 3, 64, 64),
282+
kind_in_graph="ipex::conv2d_relu")
261283

262284

263285
def test_output_conv_relu_3d(self):
264286
self._test_output(
265287
ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
266-
torch.randn(32, 3, 112, 112, 112),
267-
kind="ipex::conv3d_relu")
288+
torch.randn(32, 3, 32, 32, 32),
289+
kind_in_graph="ipex::conv3d_relu")
268290
self._test_output_bf16(
269291
ConvRelu_Fixed(3, 3, 32, kernel_size=3, stride=1),
270-
torch.randn(32, 3, 112, 112, 112),
271-
kind="ipex::conv3d_relu")
292+
torch.randn(32, 3, 32, 32, 32),
293+
kind_in_graph="ipex::conv3d_relu")
272294

273295

274296
def test_output_conv_sum_2d(self):
275297
self._test_output(
276298
ConvSum(2, 3, 32, kernel_size=3, stride=1),
277-
torch.randn(32, 3, 224, 224),
278-
kind="ipex::conv2d_sum")
299+
torch.randn(32, 3, 64, 64),
300+
kind_in_graph="ipex::conv2d_sum")
279301
self._test_output_bf16(
280302
ConvSum(2, 3, 32, kernel_size=3, stride=1),
281-
torch.randn(32, 3, 224, 224),
282-
kind="ipex::conv2d_sum",
303+
torch.randn(32, 3, 64, 64),
304+
kind_in_graph="ipex::conv2d_sum",
283305
prec=0.04)
284306

285307

286308
def test_output_conv_sum_3d(self):
287309
self._test_output(
288310
ConvSum(3, 3, 32, kernel_size=3, stride=1),
289-
torch.randn(32, 3, 112, 112, 112),
290-
kind="ipex::conv3d_sum")
311+
torch.randn(32, 3, 32, 32, 32),
312+
kind_in_graph="ipex::conv3d_sum")
291313
self._test_output_bf16(
292314
ConvSum(3, 3, 32, kernel_size=3, stride=1),
293-
torch.randn(32, 3, 112, 112, 112),
294-
kind="ipex::conv3d_sum",
315+
torch.randn(32, 3, 32, 32, 32),
316+
kind_in_graph="ipex::conv3d_sum",
295317
prec=0.04)
296318

297319

298320
def test_output_cascaded_conv_bn_sum_relu_2d(self):
299321
self._test_output(
300322
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
301-
torch.rand(32, 3, 224, 224),
302-
kind="ipex::conv2d_sum_relu")
323+
torch.rand(32, 3, 64, 64),
324+
kind_in_graph="ipex::conv2d_sum_relu",
325+
kind_not_in_graph="aten::batch_norm")
303326
self._test_output_bf16(
304327
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
305-
torch.rand(32, 3, 224, 224),
306-
kind="ipex::conv2d_sum_relu",
328+
torch.rand(32, 3, 64, 64),
329+
kind_in_graph="ipex::conv2d_sum_relu",
330+
kind_not_in_graph="aten::batch_norm",
307331
prec=0.02)
308332

309333

310334
def test_output_cascaded_conv_bn_sum_relu_3d(self):
311335
self._test_output(
312336
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
313-
torch.rand(32, 3, 112, 112, 112),
314-
kind="ipex::conv3d_sum_relu")
337+
torch.rand(32, 3, 32, 32, 32),
338+
kind_in_graph="ipex::conv3d_sum_relu",
339+
kind_not_in_graph="aten::batch_norm",)
315340
self._test_output_bf16(
316341
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
317-
torch.rand(32, 3, 112, 112, 112),
318-
kind="ipex::conv3d_sum_relu",
342+
torch.rand(32, 3, 32, 32, 32),
343+
kind_in_graph="ipex::conv3d_sum_relu",
344+
kind_not_in_graph="aten::batch_norm",
319345
prec=0.02)
320346

321347

322348
def test_output_linear_relu(self):
323349
self._test_output(
324350
LinearRelu(3, 32, bias=True),
325351
torch.rand(32, 3),
326-
kind="ipex::linear_relu")
352+
kind_in_graph="ipex::linear_relu")
327353
self._test_output_bf16(
328354
LinearRelu(3, 32, bias=True),
329355
torch.rand(32, 3),
330-
kind="ipex::linear_relu")
356+
kind_in_graph="ipex::linear_relu")
331357
self._test_output(
332358
LinearRelu(3, 32, bias=False),
333359
torch.rand(32, 3),
334-
kind="ipex::linear_relu")
360+
kind_in_graph="ipex::linear_relu")
335361
self._test_output_bf16(
336362
LinearRelu(3, 32, bias=False),
337363
torch.rand(32, 3),
338-
kind="ipex::linear_relu")
364+
kind_in_graph="ipex::linear_relu")
339365

340366

341367
if __name__ == '__main__':
368+
torch.manual_seed(2020)
342369
core.enable_auto_dnnl()
343370
test = unittest.main()

0 commit comments

Comments
 (0)