@@ -13,27 +13,33 @@ def script_(obj, optimize=None, _frames_up=0, _rcb=None):
     jit_m = orig_script(obj, optimize=optimize, _frames_up=_frames_up + 1, _rcb=_rcb)
     torch.jit.script = script_
 
-    mix_state = core.get_mix_bf16_fp32()
+    mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
     # Disable mix precision in model fusion, since mixed precision cannot
     # bring any benefits for inference, but will lead to loss of accuracy
     core.disable_mix_bf16_fp32()
+    core.disable_mix_int8_fp32()
     if core.get_jit_opt() and hasattr(jit_m, '_c'):
         jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
-    if mix_state:
+    if mix_state == torch.bfloat16:
         core.enable_mix_bf16_fp32()
+    elif mix_state == torch.int8:
+        core.enable_mix_int8_fp32()
     return jit_m
 
 def trace_(func, example_inputs, *args, **kwargs):
     # Disable mix precision. torch.jit.trace will check the traced output
     # against what is expected. Since mix precision will lead to
     # loss of accuracy, this will raise warning during torch.jit.trace
-    mix_state = core.get_mix_bf16_fp32()
+    mix_state = torch.bfloat16 if core.get_mix_bf16_fp32() else torch.int8 if core.get_mix_int8_fp32() else None
     core.disable_mix_bf16_fp32()
+    core.disable_mix_int8_fp32()
     jit_m = orig_trace(func, example_inputs, *args, **kwargs)
     if core.get_jit_opt() and hasattr(jit_m, '_c'):
         jit_m = wrap_cpp_module(torch._C._jit_pass_fold_convbn(jit_m._c))
-    if mix_state:
+    if mix_state == torch.bfloat16:
         core.enable_mix_bf16_fp32()
+    elif mix_state == torch.int8:
+        core.enable_mix_int8_fp32()
     return jit_m
 
 
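The change follows one pattern in both wrappers: snapshot whichever mixed-precision mode is currently active, disable every mode around the JIT fusion/tracing call, then restore the snapshot afterwards. Below is a minimal sketch of that save/disable/restore logic factored into a context manager, assuming only the `core` getters/setters visible in the diff; the `no_mix_precision` name and the import path are placeholders, not part of the project.

    import contextlib

    import torch
    from intel_pytorch_extension import core  # placeholder import path

    @contextlib.contextmanager
    def no_mix_precision():
        # Snapshot the active mode; bf16 wins over int8, matching the
        # ternary introduced in this commit.
        mix_state = (torch.bfloat16 if core.get_mix_bf16_fp32()
                     else torch.int8 if core.get_mix_int8_fp32()
                     else None)
        core.disable_mix_bf16_fp32()
        core.disable_mix_int8_fp32()
        try:
            yield
        finally:
            # Re-enable exactly the mode that was active on entry.
            if mix_state == torch.bfloat16:
                core.enable_mix_bf16_fp32()
            elif mix_state == torch.int8:
                core.enable_mix_int8_fp32()

Wrapping the bodies of `script_` and `trace_` in `with no_mix_precision():` would deduplicate the snapshot/restore code, and the `finally` block would also restore the mode when `orig_script` or `orig_trace` raises, which the straight-line sequence in the diff does not.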