Commit 98fa4e9

change int8 quantization api
1 parent 29e53b0 commit 98fa4e9
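
In short, this commit moves the int8 quantization API from module-level functions to an AmpConf configuration object plus a context manager. A minimal before/after sketch based on the removed and added code below; the import alias ipex and the names model and x are assumed user-side names, not part of this diff:

    # Before: function-based toggles plus a separate calibration context
    ipex.enable_auto_mix_precision(mixed_dtype=torch.int8)
    with ipex.int8_calibration('conf.json'):
        model(x)   # indicators were written to conf.json on exit

    # After: an AmpConf object carries the dtype and configure file,
    # and enable_auto_mix_precision is itself the context manager
    conf = ipex.AmpConf(torch.int8)
    with ipex.enable_auto_mix_precision(conf, running_mode='calibration'):
        model(x)
    conf.save('conf.json')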

File tree

6 files changed: +98 -76 lines changed


intel_pytorch_extension_py/__init__.py

Lines changed: 92 additions & 70 deletions

@@ -9,61 +9,27 @@
 import _torch_ipex as core
 
 DEVICE = 'dpcpp'
-def enable_auto_optimization(mixed_dtype = None, train = False, configure_file = None):
-    r""" Enable auto-mixed-precision to improve performance.
 
-    Auto-mixed-precision automatically reorders tensors to the specified low-precision data type.
-    You don't need to convert the input tensors and the model to the specified data type manually;
-    the extension does it automatically and then dispatches to the extension backend to accelerate
-    the computation.
-
-    Args:
-        mixed_dtype(torch.dtype): Automatically reorder the input tensors to the specified
-            low-precision data type and dispatch to the oneDNN backend for computation.
-
-    """
-    if mixed_dtype != None:
-        core.enable_auto_dnnl()
-    enable_auto_mix_precision(mixed_dtype, train, configure_file)
-
-def get_auto_optimization():
-    return get_auto_mix_precision
-
-def get_train():
-    return core.get_train()
-
-def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train = False, configure_file = None):
-    if mixed_dtype == torch.bfloat16:
-        core.enable_mix_bf16_fp32()
-        core.disable_mix_int8_fp32()
-    elif mixed_dtype == torch.int8 or mixed_dtype == torch.uint8:
-        core.enable_mix_int8_fp32()
-        core.disable_mix_bf16_fp32()
-        if configure_file != None:
-            core.disable_int8_calibration()
-            f = open(configure_file)
-            configures = json.load(f)
-            core.load_indicators_file(configures)
-        else:
-            warnings.warn("please do not forget to run calibration before the validation step")
-    else:
-        core.disable_mix_int8_fp32()
-        core.disable_mix_bf16_fp32()
-    core.set_execution_mode(train=train)
-
-def get_auto_mix_precision():
-    if core.get_mix_bf16_fp32():
-        return torch.bfloat16
-    elif core.get_mix_int8_fp32():
-        return torch.int8
-    else:
-        return None
-
-def calibration_reset():
-    if core.get_int8_calibration():
-        core.calibration_reset()
-    else:
-        raise ValueError("please first run enable_calibration before calibration_reset")
+class AmpConf(object):
+    def __init__(self, mixed_dtype = torch.bfloat16, configure_file = None):
+        self.dtype = mixed_dtype
+        self.configure_file = configure_file
+
+        # For the int8 path, if the user gives an existing configure file, load it.
+        if self.configure_file != None and self.dtype != torch.bfloat16:
+            if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
+                f = open(self.configure_file)
+                configures = json.load(f)
+                core.load_indicators_file(configures)
+            else:
+                assert False, 'Cannot load an empty or non-existent file, please run the calibration step first'
+
+    # For int8 quantization, save the data gathered during the calibration step.
+    def save(self, configure_file):
+        core.add_indicators()
+        configures = core.get_int8_configures()
+        with open(configure_file, 'w') as fp:
+            json.dump(configures, fp, indent = 4)
 
 class _DecoratorContextManager:
     """Allow a context manager to be used as a decorator, copied from the PyTorch framework"""
@@ -92,24 +58,80 @@ def generator_context(*args, **kwargs):
                     break
         return generator_context
 
-class int8_calibration(_DecoratorContextManager):
-    def __init__(self, file_name, observer_configure=None):
-        self.configure_file = file_name
+def get_auto_mix_precision():
+    if core.get_mix_bf16_fp32():
+        return torch.bfloat16
+    elif core.get_mix_int8_fp32():
+        return torch.int8
+    else:
+        return None
+
+def enable_auto_optimization(mixed_dtype = None, train = False):
+    r""" Enable auto-mixed-precision to improve performance for the global scope.
+
+    Auto-mixed-precision automatically reorders tensors to the specified low-precision data type.
+    You don't need to convert the input tensors and the model to the specified data type manually;
+    the extension does it automatically and then dispatches to the extension backend to accelerate
+    the computation.
+
+    Args:
+        mixed_dtype(torch.dtype): Automatically reorder the input tensors to the specified
+            low-precision data type and dispatch to the oneDNN backend for computation;
+            can be torch.bfloat16 or None.
+    """
+    if mixed_dtype != None:
+        core.enable_auto_dnnl()
+    running_mode = 'training' if train else 'inference'
+    enable_auto_mix_precision(AmpConf(mixed_dtype), running_mode).__enter__()
+
+def get_auto_optimization():
+    return get_auto_mix_precision
+
+def get_train():
+    return core.get_train()
+
+class enable_auto_mix_precision(_DecoratorContextManager):
+    def __init__(self, conf, running_mode = 'inference'):
+        self.pre_mixed_dtype = get_auto_mix_precision()
+        self.pre_running_mode = get_train()
+        self.pre_calibration_state = core.get_int8_calibration()
+        self.mixed_dtype = conf.dtype
+        self.running_mode = running_mode
 
     def __enter__(self):
-        if not core.get_mix_int8_fp32():
-            raise ValueError("please first run enable_auto_mix_precision(torch.int8) before int8 calibration")
-        core.enable_int8_calibration()
-        if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
-            f = open(self.configure_file)
-            configures = json.load(f)
-            core.load_indicators_file(configures)
+        if self.mixed_dtype == torch.bfloat16:
+            core.enable_mix_bf16_fp32()
+            core.disable_mix_int8_fp32()
+        elif self.mixed_dtype == torch.int8:
+            core.enable_mix_int8_fp32()
+            core.disable_mix_bf16_fp32()
+            if self.running_mode == 'inference':
+                core.disable_int8_calibration()
+            elif self.running_mode == 'calibration':
+                core.enable_int8_calibration()
+            else:
+                assert False, 'int8 quantization only supports the inference and calibration running modes'
+        else:
+            core.disable_mix_int8_fp32()
+            core.disable_mix_bf16_fp32()
+        core.set_execution_mode(train = True if self.running_mode == 'training' else False)
 
     def __exit__(self, *args):
-        core.disable_int8_calibration()
-        core.add_indicators()
-        configures = core.get_int8_configures()
-        with open(self.configure_file, 'w') as fp:
-            json.dump(configures, fp, indent=4)
-        return False
+        if self.mixed_dtype == torch.int8:
+            if self.running_mode == 'calibration':
+                core.calibration_reset()
+            # restore the previous calibration state
+            if self.pre_calibration_state:
+                core.enable_int8_calibration()
+            else:
+                core.disable_int8_calibration()
+        if self.pre_mixed_dtype == torch.bfloat16:
+            core.enable_mix_bf16_fp32()
+            core.disable_mix_int8_fp32()
+        elif self.pre_mixed_dtype == torch.int8:
+            core.enable_mix_int8_fp32()
+            core.disable_mix_bf16_fp32()
+        else:
+            core.disable_mix_int8_fp32()
+            core.disable_mix_bf16_fp32()
+        core.set_execution_mode(train = self.pre_running_mode)
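
For reference, a minimal end-to-end sketch of the new API introduced above: run a calibration pass, persist the indicators with AmpConf.save, then reload them for inference. This is a hedged example, not part of the diff; model and calibration_loader are assumed user-side names, and the installed package name may differ from the source directory:

    import torch
    import intel_pytorch_extension_py as ipex  # package path from this diff; installed name may differ

    # 1) Calibration: observe value ranges on representative inputs.
    conf = ipex.AmpConf(torch.int8)
    with ipex.enable_auto_mix_precision(conf, running_mode='calibration'):
        for x in calibration_loader:
            model(x.to(ipex.DEVICE))
    conf.save('int8_configure.json')  # writes the collected indicators as JSON

    # 2) Inference: pass the saved file; AmpConf.__init__ loads the indicators.
    conf = ipex.AmpConf(torch.int8, 'int8_configure.json')
    with ipex.enable_auto_mix_precision(conf, running_mode='inference'):
        out = model(x.to(ipex.DEVICE))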

tests/cpu/common_ipex_conf.py

Lines changed: 6 additions & 6 deletions

@@ -4,21 +4,21 @@
 class AutoMixPrecision(object):
     def __init__(self, enable_or_not = False, train = False):
         self.old_value = ipex.get_auto_mix_precision()
-        self.train_old_value = ipex.get_train()
+        self.pre_running_mode = 'training' if ipex.get_train() else 'inference'
         self.enable_or_not = enable_or_not
-        self.train = train
+        self.running_mode = 'training' if train else 'inference'
 
     def __enter__(self):
         if self.enable_or_not:
-            ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train)
+            ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.running_mode).__enter__()
         else:
-            ipex.enable_auto_mix_precision(mixed_dtype=None)
+            ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
 
     def __exit__(self, *args, **kwargs):
         if self.old_value:
-            ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train_old_value)
+            ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.pre_running_mode).__enter__()
         else:
-            ipex.enable_auto_mix_precision(mixed_dtype=None)
+            ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
 
 class AutoDNNL(object):
     def __init__(self, enable_or_not = False):
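
Note on the updated helper: it applies the new context manager's state change by calling __enter__() directly, and restores the previously recorded mode itself in __exit__ instead of relying on enable_auto_mix_precision's own exit logic. A hedged usage sketch, where model and x are hypothetical test objects:

    # inside a test body
    with AutoMixPrecision(enable_or_not=True, train=False):
        y = model(x)  # runs with bf16 auto-mixed-precision enabled
    # on exit, the helper re-applies whichever mode it recorded at construction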

tests/cpu/model.pth (31.6 KB): Binary file not shown.

tests/cpu/model_dpcpp.pth (16.1 KB): Binary file not shown.

tests/cpu/tensor.pt (6.55 MB): Binary file not shown.

tests/cpu/tensor_dpcpp.pt (6.55 MB): Binary file not shown.
