Skip to content

Commit 1f486a5

Browse files
Checking hasattr _c for script (#105)
1 parent 775e658 commit 1f486a5

File tree

2 files changed

+27
-10
lines changed

2 files changed

+27
-10
lines changed

intel_pytorch_extension_py/ops/jit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1414
jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
1515
torch.jit.script = script_
1616

17-
if core.get_jit_opt() and isinstance(jit_m, torch._C.ScriptModule):
17+
if core.get_jit_opt() and hasattr(jit_m, '_c'):
1818
# Disable mix precision in model fusion, since mixed precision cannot
1919
# bring any benefits for inference, but will lead to loss of accuracy
2020
orig_mixed_type = ipex.get_auto_mix_precision()
@@ -31,7 +31,7 @@ def trace_(func, example_inputs, *args, **kwargs):
3131
ipex.enable_auto_mix_precision(None)
3232
jit_m = orig_trace(func, example_inputs, *args, **kwargs)
3333

34-
if core.get_jit_opt() and isinstance(jit_m, torch._C.ScriptModule):
34+
if core.get_jit_opt() and hasattr(jit_m, '_c'):
3535
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
3636
ipex.enable_auto_mix_precision(orig_mixed_type)
3737
return jit_m

tests/cpu/test_jit.py

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def _test_output(self, model, x, kind_in_graph=None, kind_not_in_graph=None):
192192
if kind_in_graph is not None:
193193
self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
194194
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
195-
195+
196196
# check if certain node does not exist in the graph
197197
if kind_not_in_graph is not None:
198198
self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
@@ -205,12 +205,12 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
205205
core.enable_auto_dnnl()
206206
core.enable_jit_opt()
207207
core.enable_mix_bf16_fp32()
208-
208+
209209
model = model.to(ipex.DEVICE).eval()
210210
x = x.to(ipex.DEVICE)
211211
x2 = x.clone()
212212
x3 = x.clone()
213-
213+
214214
script_fused_model = torch.jit.script(copy.deepcopy(model))
215215
trace_fused_model = torch.jit.trace(copy.deepcopy(model), x3)
216216

@@ -224,24 +224,24 @@ def _test_output_bf16(self, model, x, kind_in_graph=None, kind_not_in_graph=None
224224
# bf 16, jit trace path
225225
trace_graph = trace_fused_model.graph_for(x3)
226226
fused_tresult = trace_fused_model(x3)
227-
228-
# disable mix_bf16_fp32 when the calculation is done
227+
228+
# disable mix_bf16_fp32 when the calculation is done
229229
# to avoid affecting other scripts
230230
core.disable_mix_bf16_fp32()
231-
231+
232232
self.assertEqual(fused_sresult, result, prec=prec)
233233
self.assertEqual(fused_tresult, result, prec=prec)
234234

235235
# check if the fused node exists in the graph
236236
if kind_in_graph is not None:
237237
self.assertTrue(any(n.kind() == kind_in_graph for n in script_graph.nodes()))
238238
self.assertTrue(any(n.kind() == kind_in_graph for n in trace_graph.nodes()))
239-
239+
240240
# check if certain node does not exist in the graph
241241
if kind_not_in_graph is not None:
242242
self.assertTrue(all(n.kind() != kind_not_in_graph for n in script_graph.nodes()))
243243
self.assertTrue(all(n.kind() != kind_not_in_graph for n in trace_graph.nodes()))
244-
244+
245245

246246
def test_output_conv_bn_2d(self):
247247
self._test_output(
@@ -364,6 +364,23 @@ def test_output_linear_relu(self):
364364
kind_in_graph="ipex::linear_relu")
365365

366366

367+
def test_jit_function(self):
368+
# test that the hooked torch.jit.trace and torch.jit.script work for plain functions
369+
def fn(input, weight, bias):
370+
return F.linear(input, weight, bias)
371+
372+
input = torch.randn(2, 4)
373+
weight = torch.randn(5, 4)
374+
bias = torch.randn(5)
375+
result = fn(input, weight, bias)
376+
377+
scripted_fn = torch.jit.script(fn)
378+
traced_fn = torch.jit.trace(fn, (input, weight, bias))
379+
380+
self.assertEqual(scripted_fn(input, weight, bias), result)
381+
self.assertEqual(traced_fn(input, weight, bias), result)
382+
383+
367384
if __name__ == '__main__':
368385
torch.manual_seed(2020)
369386
core.enable_auto_dnnl()

0 commit comments

Comments
 (0)