Skip to content

Commit 775e658

Browse files
fix torch._C.ScriptFunction' object has no attribute '_c' problem when convert a function to scripted or traced module (#103)
1 parent 9d6dab6 commit 775e658

File tree

1 file changed

+4
-4
lines changed
  • intel_pytorch_extension_py/ops

1 file changed

+4
-4
lines changed

intel_pytorch_extension_py/ops/jit.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1414
jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
1515
torch.jit.script = script_
1616

17-
if core.get_jit_opt():
17+
if core.get_jit_opt() and isinstance(jit_m, torch._C.ScriptModule):
1818
# Disable mix precision in model fusion, since mixed precision cannot
1919
# bring any benefits for inference, but will lead to loss of accuracy
2020
orig_mixed_type = ipex.get_auto_mix_precision()
@@ -24,14 +24,14 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
2424
return jit_m
2525

2626
def trace_(func, example_inputs, *args, **kwargs):
27-
# Disable mix precision. torch.jit.trace will check the traced output
28-
# against what is expected. Since mix precision will lead to
27+
# Disable mix precision. torch.jit.trace will check the traced output
28+
# against what is expected. Since mix precision will lead to
2929
# loss of accuracy, this will raise warning during torch.jit.trace
3030
orig_mixed_type = ipex.get_auto_mix_precision()
3131
ipex.enable_auto_mix_precision(None)
3232
jit_m = orig_trace(func, example_inputs, *args, **kwargs)
3333

34-
if core.get_jit_opt():
34+
if core.get_jit_opt() and isinstance(jit_m, torch._C.ScriptModule):
3535
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
3636
ipex.enable_auto_mix_precision(orig_mixed_type)
3737
return jit_m

0 commit comments

Comments
 (0)