Skip to content

Commit b61029b

Browse files
committed
Merge remote-tracking branch 'origin/master' into fix_pack_padded_sequence
2 parents 8408ee0 + 726d23e commit b61029b

File tree

17 files changed

+699
-217
lines changed

17 files changed

+699
-217
lines changed

intel_pytorch_extension_py/__init__.py

Lines changed: 92 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -9,71 +9,27 @@
99
import _torch_ipex as core
1010

1111
DEVICE = 'dpcpp'
12-
def enable_auto_optimization(mixed_dtype = None, train = False, configure_file = None):
13-
r""" Enable auto-mixed-precision to improve performance.
1412

15-
The auto-mixed-precision auto reorders the tensor to the specified low precision data type.
16-
You don't need to convert the input tensors and the model to the specified data type manually,
17-
the extension will do it automatically and then dispatch the extension backend to accelerate
18-
computation
19-
20-
Args:
21-
mixed_dtype(torch.dtype): Auto reorder the input tensors to the specified low precision data type
22-
and dispatch to oneDNN backend for computation
23-
24-
"""
25-
if mixed_dtype != None:
26-
core.enable_auto_dnnl()
27-
enable_auto_mix_precision(mixed_dtype, train, configure_file)
28-
29-
def get_auto_optimization():
30-
return get_auto_mix_precision
31-
32-
def get_train():
33-
return core.get_train()
34-
35-
def enable_auto_mix_precision(mixed_dtype = torch.bfloat16, train = False, configure_file = None):
36-
if mixed_dtype == torch.bfloat16:
37-
core.enable_mix_bf16_fp32()
38-
core.disable_mix_int8_fp32()
39-
elif mixed_dtype == torch.int8 or mixed_dtype == torch.uint8:
40-
core.enable_mix_int8_fp32()
41-
core.disable_mix_bf16_fp32()
42-
if configure_file != None:
43-
core.disable_int8_calibration()
44-
f = open(configure_file)
45-
configures = json.load(f)
46-
core.load_indicators_file(configures)
47-
else:
48-
warnings.warn("please not forget do calibration before doing validation step")
49-
else:
50-
core.disable_mix_int8_fp32()
51-
core.disable_mix_bf16_fp32()
52-
core.set_execution_mode(train=train)
53-
54-
def get_auto_mix_precision():
55-
if core.get_mix_bf16_fp32():
56-
return torch.bfloat16
57-
elif core.get_mix_int8_fp32():
58-
return torch.int8
59-
else:
60-
return None
61-
62-
'''
63-
def quarry_int8_configure(model, inputs_shape):
64-
dummy_input = torch.randn(input_shapes).to(DEVICE)
65-
core.enable_mix_int8_fp32()
66-
with torch.no_grad():
67-
y = model(dummy_input)
68-
observer_configures = core.get_int8_observer_configures()
69-
return observer_configures
70-
'''
71-
72-
def calibration_reset():
73-
if core.get_int8_calibration():
74-
core.calibration_reset()
75-
else:
76-
raise ValueError("please first run enable_calibration before calibration reset")
13+
class AmpConf(object):
14+
def __init__(self, mixed_dtype = torch.bfloat16, configure_file = None):
15+
self.dtype = mixed_dtype
16+
self.configure_file = configure_file
17+
18+
# for the int8 path, if the user gives an existing configure file, load it.
19+
if self.configure_file != None and self.dtype != torch.bfloat16:
20+
if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
21+
with open(self.configure_file, 'r') as f:
22+
configures = json.load(f)
23+
core.load_indicators_file(configures)
24+
else:
25+
assert False, 'Can not load a empty file or none existed file, plese first do calibartion step'
26+
27+
# for int8 quantization, this will save the data after the calibration step is done.
28+
def save(self, configure_file):
29+
core.add_indicators()
30+
configures = core.get_int8_configures()
31+
with open(configure_file, 'w') as fp:
32+
json.dump(configures, fp, indent = 4)
7733

7834
class _DecoratorContextManager:
7935
"""Allow a context manager to be used as a decorator, copy form pytorch FW"""
@@ -102,22 +58,80 @@ def generator_context(*args, **kwargs):
10258
break
10359
return generator_context
10460

105-
class int8_calibration(_DecoratorContextManager):
106-
def __init__(self, file_name, observer_configure=None):
107-
#self.observer_configure = observer_configure
108-
self.configure_file = file_name
61+
def get_auto_mix_precision():
62+
if core.get_mix_bf16_fp32():
63+
return torch.bfloat16
64+
elif core.get_mix_int8_fp32():
65+
return torch.int8
66+
else:
67+
return None
68+
69+
def enable_auto_optimization(mixed_dtype = None, train = False):
70+
r""" Enable auto-mixed-precision to improve performance for global scope.
71+
72+
The auto-mixed-precision auto reorders the tensor to the specified low precision data type.
73+
You don't need to convert the input tensors and the model to the specified data type manually,
74+
the extension will do it automatically and then dispatch the extension backend to accelerate
75+
computation
76+
77+
Args:
78+
mixed_dtype(torch.dtype): Auto reorder the input tensors to the specified low precision data type
79+
and dispatch to oneDNN backend for computation, can be torch.bfloat16 or None.
80+
"""
81+
if mixed_dtype != None:
82+
core.enable_auto_dnnl()
83+
running_mode = 'training' if train else 'inference'
84+
enable_auto_mix_precision(AmpConf(mixed_dtype), running_mode).__enter__()
85+
86+
def get_auto_optimization():
87+
return get_auto_mix_precision
88+
89+
def get_train():
90+
return core.get_train()
91+
92+
class enable_auto_mix_precision(_DecoratorContextManager):
93+
def __init__(self, conf, running_mode = 'inference'):
94+
self.pre_mixed_dtype = get_auto_mix_precision()
95+
self.pre_running_mode = get_train()
96+
self.pre_calibration_state = core.get_int8_calibration()
97+
self.mixed_dtype = conf.dtype
98+
self.running_mode = running_mode
10999

110100
def __enter__(self):
111-
if not core.get_mix_int8_fp32():
112-
raise ValueError("please first run enable_auto_mix_precision(torch.int8) before int8 calibration")
113-
core.enable_int8_calibration()
114-
#core.set_int8_observer_configure(self.observer_configure)
101+
if self.mixed_dtype == torch.bfloat16:
102+
core.enable_mix_bf16_fp32()
103+
core.disable_mix_int8_fp32()
104+
elif self.mixed_dtype == torch.int8:
105+
core.enable_mix_int8_fp32()
106+
core.disable_mix_bf16_fp32()
107+
if self.running_mode == 'inference':
108+
core.disable_int8_calibration()
109+
elif self.running_mode == 'calibration':
110+
core.enable_int8_calibration()
111+
else:
112+
assert False, 'int8 quantization only suport inference and calibration running mode'
113+
else:
114+
core.disable_mix_int8_fp32()
115+
core.disable_mix_bf16_fp32()
116+
core.set_execution_mode(train = True if self.running_mode == 'training' else False)
115117

116118
def __exit__(self, *args):
117-
core.disable_int8_calibration()
118-
core.add_indicators()
119-
configures = core.get_int8_configures()
120-
with open(self.configure_file, 'w') as fp:
121-
json.dump(configures, fp, indent=4)
122-
return False
119+
if self.mixed_dtype == torch.int8:
120+
if self.running_mode == 'calibration':
121+
core.calibration_reset()
122+
# restore previous state
123+
if self.pre_calibration_state:
124+
core.enable_int8_calibration()
125+
else:
126+
core.disable_int8_calibration()
127+
if self.pre_mixed_dtype == torch.bfloat16:
128+
core.enable_mix_bf16_fp32()
129+
core.disable_mix_int8_fp32()
130+
elif self.pre_mixed_dtype == torch.int8:
131+
core.enable_mix_int8_fp32()
132+
core.disable_mix_bf16_fp32()
133+
else:
134+
core.disable_mix_int8_fp32()
135+
core.disable_mix_bf16_fp32()
136+
core.set_execution_mode(train = self.pre_running_mode)
123137

intel_pytorch_extension_py/ops/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from .embeddingbag import embeddingbag
33
from .linear import *
44
from .pooling import *
5-
from .reshape import *
65
from .mlp import *
76
from .jit import *
87
from .save import *

intel_pytorch_extension_py/ops/reshape.py

Lines changed: 0 additions & 18 deletions
This file was deleted.

tests/cpu/common_ipex_conf.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,21 @@
44
class AutoMixPrecision(object):
55
def __init__(self, enable_or_not = False, train = False):
66
self.old_value = ipex.get_auto_mix_precision()
7-
self.train_old_value = ipex.get_train()
7+
self.pre_running_mode = 'training' if ipex.get_train() else 'inference'
88
self.enable_or_not = enable_or_not
9-
self.train = train
9+
self.running_mode = 'training' if train else 'inference'
1010

1111
def __enter__(self):
1212
if self.enable_or_not:
13-
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train)
13+
ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.running_mode).__enter__()
1414
else:
15-
ipex.enable_auto_mix_precision(mixed_dtype=None)
15+
ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
1616

1717
def __exit__(self, *args, **kwargs):
1818
if self.old_value:
19-
ipex.enable_auto_mix_precision(mixed_dtype=torch.bfloat16, train=self.train_old_value)
19+
ipex.enable_auto_mix_precision(ipex.AmpConf(torch.bfloat16), self.pre_running_mode).__enter__()
2020
else:
21-
ipex.enable_auto_mix_precision(mixed_dtype=None)
21+
ipex.enable_auto_mix_precision(ipex.AmpConf(None)).__enter__()
2222

2323
class AutoDNNL(object):
2424
def __init__(self, enable_or_not = False):

0 commit comments

Comments
 (0)