Skip to content

Commit 29e53b0

Browse files
XiaobingSuperEikanWang
authored andcommitted
int8:enable change observer algorithm
1 parent 828959c commit 29e53b0

File tree

3 files changed

+11
-32
lines changed

3 files changed

+11
-32
lines changed

intel_pytorch_extension_py/__init__.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,6 @@ def get_auto_mix_precision():
5959
else:
6060
return None
6161

62-
'''
63-
def quarry_int8_configure(model, inputs_shape):
64-
dummy_input = torch.randn(input_shapes).to(DEVICE)
65-
core.enable_mix_int8_fp32()
66-
with torch.no_grad():
67-
y = model(dummy_input)
68-
observer_configures = core.get_int8_observer_configures()
69-
return observer_configures
70-
'''
71-
7262
def calibration_reset():
7363
if core.get_int8_calibration():
7464
core.calibration_reset()
@@ -104,14 +94,16 @@ def generator_context(*args, **kwargs):
10494

10595
class int8_calibration(_DecoratorContextManager):
10696
def __init__(self, file_name, observer_configure=None):
107-
#self.observer_configure = observer_configure
10897
self.configure_file = file_name
10998

11099
def __enter__(self):
111100
if not core.get_mix_int8_fp32():
112101
raise ValueError("please first run enable_auto_mix_precision(torch.int8) before int8 calibration")
113102
core.enable_int8_calibration()
114-
#core.set_int8_observer_configure(self.observer_configure)
103+
if os.path.exists(self.configure_file) and os.stat(self.configure_file).st_size != 0:
104+
f = open(self.configure_file)
105+
configures = json.load(f)
106+
core.load_indicators_file(configures)
115107

116108
def __exit__(self, *args):
117109
core.disable_int8_calibration()

torch_ipex/csrc/auto_opt_config.h

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,12 @@ class AutoOptConfig {
6464
std::vector<std::vector<float>> i_min_max_values, std::vector<std::vector<float>> o_min_max_values) {
6565
num_ops_id++;
6666
if (observers_.size() < num_ops_id) {
67-
// this path is that user not set int8 op's configure, using default configures
68-
Observer new_observer = {num_ops_id - 1, op_name, i_min_max_values, o_min_max_values};
67+
// this path is that user not set int8 op's configure, using default configures if user not set it.
68+
std::string observer_algorithm = "min_max";
69+
if (!indicators_.empty()) {
70+
observer_algorithm = indicators_[num_ops_id - 1].get_indicator_algorithm();
71+
}
72+
Observer new_observer = {num_ops_id - 1, op_name, i_min_max_values, o_min_max_values, observer_algorithm};
6973
observers_.push_back(new_observer);
7074
} else {
7175
// user has set configure or have run one interation
@@ -94,24 +98,9 @@ class AutoOptConfig {
9498
}
9599
}
96100

97-
/*
98-
inline void print_observer() {
99-
for (auto i = 0; i< observers_.size(); i++) {
100-
for (auto j = 0; j < observers_[i].max_values.size(); j++)
101-
std::cout<<observers_[i].max_values[j]<<std::endl;
102-
}
103-
}
104-
inline void print_indicator() {
105-
for (auto i = 0; i< indicators_.size(); i++) {
106-
auto scales = indicators_[i].get_indicator_scales();
107-
for (auto j = 0; j< scales.size(); j++)
108-
std::cout<<scales[j]<<std::endl;
109-
}
110-
}
111-
*/
112-
113101
inline void add_indicators() {
114102
num_ops_id = 0;
103+
indicators_.clear();
115104
// default used is s8
116105
for (auto i = 0; i < observers_.size(); i++) {
117106
std::vector<float> inputs_scale, outputs_scale;

torch_ipex/csrc/init_python_bindings.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ void InitIpexModuleBindings(py::module m) {
147147
m.def("get_int8_calibration", []() { return AutoOptConfig::singleton().get_int8_calibration(); });
148148
m.def("calibration_reset", []() { AutoOptConfig::singleton().calibration_reset(); });
149149
m.def("add_indicators", []() { AutoOptConfig::singleton().add_indicators(); });
150-
//m.def("print_observer", []() { AutoOptConfig::singleton().print_observer(); });
151-
// m.def("print_indicator", []() { AutoOptConfig::singleton().print_indicator(); });
152150
m.def("get_int8_configures", []() {
153151
py::list output_list;
154152
auto indicators = AutoOptConfig::singleton().get_indicators();

0 commit comments

Comments
 (0)