Skip to content

Commit 94a3cf7

Browse files
committed
disable int8 path when doing tracing
1 parent 9d65c96 commit 94a3cf7

File tree

1 file changed

+10
-4
lines changed
  • intel_pytorch_extension_py/ops

1 file changed

+10
-4
lines changed

intel_pytorch_extension_py/ops/jit.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,33 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
1313
jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up+1, _rcb=_rcb)
1414
torch.jit.script = script_
1515

16-
mix_state = core.get_mix_bf16_fp32()
16+
mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
1717
# Disable mix precision in model fusion, since mixed precision cannot
1818
# bring any benefits for inference, but will lead to loss of accuracy
1919
core.disable_mix_bf16_fp32()
20+
core.disable_mix_int8_fp32()
2021
if core.get_jit_opt() and hasattr(jit_m, '_c'):
2122
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
22-
if mix_state:
23+
if mix_state == torch.bfloat16:
2324
core.enable_mix_bf16_fp32()
25+
elif mix_state == torch.int8:
26+
core.enable_mix_int8_fp32()
2427
return jit_m
2528

2629
def trace_(func, example_inputs, *args, **kwargs):
2730
# Disable mix precision. torch.jit.trace will check the traced output
2831
# against what is expected. Since mix precision will lead to
2932
# loss of accuracy, this will raise warning during torch.jit.trace
30-
mix_state = core.get_mix_bf16_fp32()
33+
mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
3134
core.disable_mix_bf16_fp32()
35+
core.disable_mix_int8_fp32()
3236
jit_m = orig_trace(func, example_inputs, *args, **kwargs)
3337
if core.get_jit_opt() and hasattr(jit_m, '_c'):
3438
jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
35-
if mix_state:
39+
if mix_state == torch.bfloat16:
3640
core.enable_mix_bf16_fp32()
41+
elif mix_state == torch.int8:
42+
core.enable_mix_int8_fp32()
3743
return jit_m
3844

3945

0 commit comments

Comments
 (0)