@@ -14,7 +14,7 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1414 jit_m = orig_script (obj , optimize = optimize , _frames_up = _frames_up + 1 , _rcb = _rcb )
1515 torch .jit .script = script_
1616
17- if core .get_jit_opt ():
17+ if core .get_jit_opt () and isinstance ( jit_m , torch . _C . ScriptModule ) :
1818 # Disable mix precision in model fusion, since mixed precision cannot
1919 # bring any benefits for inference, but will lead to loss of accuracy
2020 orig_mixed_type = ipex .get_auto_mix_precision ()
@@ -24,14 +24,14 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
2424 return jit_m
2525
2626def trace_ (func , example_inputs , * args , ** kwargs ):
27- # Disable mix precision. torch.jit.trace will check the traced output
28- # against what is expected. Since mix precision will lead to
27+ # Disable mix precision. torch.jit.trace will check the traced output
28+ # against what is expected. Since mix precision will lead to
2929 # loss of accuracy, this will raise warning during torch.jit.trace
3030 orig_mixed_type = ipex .get_auto_mix_precision ()
3131 ipex .enable_auto_mix_precision (None )
3232 jit_m = orig_trace (func , example_inputs , * args , ** kwargs )
3333
34- if core .get_jit_opt ():
34+ if core .get_jit_opt () and isinstance ( jit_m , torch . _C . ScriptModule ):
3535 jit_m = wrap_cpp_module (torch ._C ._jit_pass_fold_convbn (jit_m ._c ))
3636 ipex .enable_auto_mix_precision (orig_mixed_type )
3737 return jit_m
0 commit comments