Commit 2449a8d

int8: conv_relu, conv_sum, conv_sum_relu
1 parent 94bd248 commit 2449a8d

File tree: 2 files changed (+101, -31 lines)

torch_ipex/csrc/cpu/FusionOPs.cpp

Lines changed: 89 additions & 27 deletions
@@ -31,24 +31,46 @@ at::Tensor dil_convolution_outplace_fusion(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const dil::attr_t& op_attr) {
+    const dil::attr_t& op_attr,
+    const std::string& op_name = "Convolution_Relu") {
   dil::tensor dil_input;
   dil::tensor dil_weight;
   c10::optional<dil::tensor> dil_bias{c10::nullopt};
+  // for the int8 path, the input is always in acbd (non-contiguous) format; .contiguous() would reorder it to fp32
+  auto src_dil_type = dbl::comm::try_gen_dil_tensor(input).get_data_type();
+  auto input_contiguous = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8
+      || input.is_contiguous()) ? input : input.contiguous();
+  auto weight_dil_type = dbl::comm::try_gen_dil_tensor(weight).get_data_type();
+  auto weight_contiguous = (weight_dil_type == dil::data_type::s8 || weight.is_contiguous()) ? weight : weight.contiguous();
+
+  std::vector<float> output_scale = {};
+  if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
+    std::vector<std::vector<float>> scales;
+    bool quantized;
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output */ false);
+    //quantized = false;
+    if (quantized) {
+      output_scale.push_back(scales[1][0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(input_contiguous, scales[0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(weight_contiguous, {});
+    } else {
+      dbl::comm::reorder_to_dtype(input, at::kFloat);
+      dbl::comm::reorder_to_dtype(weight, at::kFloat);
+    }
+  } else {
+    dbl::comm::reorder_to_bf16_for_mix_prec(input_contiguous);
+    dbl::comm::reorder_to_bf16_for_mix_prec(weight_contiguous);
+  }
 
-  auto input_contiguous = input.is_contiguous() ? input : input.contiguous();
-  auto weight_contiguous = weight.is_contiguous() ? weight : weight.contiguous();
-
-  reorder_to_bf16_for_mix_prec(input_contiguous);
   dil_input = try_gen_dil_tensor(input_contiguous);
-
   if (bias.defined()) {
-    auto bias_contiguous = bias.is_contiguous() ? bias : bias.contiguous();
-    reorder_to_bf16_for_mix_prec(bias_contiguous);
-    dil_bias = try_gen_dil_tensor(bias_contiguous);
+    auto bias_contiguous = bias.is_contiguous() ? bias : bias.contiguous();
+    if (!check_auto_mix_int8_fp32()) {
+      dbl::comm::reorder_to_bf16_for_mix_prec(bias_contiguous);
+    }
+    dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
   }
 
-  reorder_to_bf16_for_mix_prec(weight_contiguous);
   dbl::conv::prepack_conv_weights(
     input_contiguous,
     dil_input,
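
The contiguity guard at the top of this hunk is the subtle part: on the int8 path the underlying DIL buffer is kept in a blocked acbd layout that ATen reports as non-contiguous, and calling .contiguous() on it would reorder the data back to fp32, undoing the quantization. Below is a minimal sketch of the same predicate, using local stand-in types rather than the real at::Tensor and dil::tensor:

#include <iostream>

// Stand-ins for the real dtype and tensor types; illustration only.
enum class data_type { f32, bf16, s8, u8 };
struct Tensor {
  data_type dil_dtype;  // dtype of the underlying DIL buffer
  bool is_contiguous;   // what ATen would report for this layout
};

// Mirrors the guard in the diff: int8 tensors keep their blocked layout,
// everything else is made contiguous only if it is not already.
bool needs_contiguous_call(const Tensor& t) {
  bool is_int8 = (t.dil_dtype == data_type::u8 || t.dil_dtype == data_type::s8);
  return !is_int8 && !t.is_contiguous;
}

int main() {
  Tensor int8_acbd{data_type::u8, false};
  Tensor fp32_strided{data_type::f32, false};
  std::cout << needs_contiguous_call(int8_acbd) << "\n";    // 0: keep blocked layout
  std::cout << needs_contiguous_call(fp32_strided) << "\n"; // 1: reorder to contiguous
}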
@@ -67,9 +89,14 @@ at::Tensor dil_convolution_outplace_fusion(
     stride,
     dilation,
     groups,
-    op_attr);
+    op_attr,
+    output_scale);
 
-  return gen_aten_tensor_by(std::move(dil_output));
+  auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(dil_output));
+  if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
+    insert_or_updata_observer({input_contiguous}, {aten_output}, op_name);
+  }
+  return aten_output;
 }
 
 static at::Tensor& dil_convolution_inplace_fusion(
@@ -81,28 +108,55 @@ static at::Tensor& dil_convolution_inplace_fusion(
     at::IntArrayRef padding,
     at::IntArrayRef dilation,
     int64_t groups,
-    const dil::attr_t& attr) {
+    const dil::attr_t& attr,
+    const std::string& op_name) {
   dil::tensor dil_input;
   dil::tensor dil_weight;
   dil::tensor dil_output;
   c10::optional<dil::tensor> dil_bias{c10::nullopt};
 
-  auto input_contiguous = input.is_contiguous() ? input : input.contiguous();
-  auto weight_contiguous = weight.is_contiguous() ? weight : weight.contiguous();
-  auto output_contiguous = accumu.is_contiguous() ? accumu : accumu.contiguous();
+  // for the int8 path, the input is always in acbd (non-contiguous) format; .contiguous() would reorder it to fp32
+  auto src_dil_type = dbl::comm::try_gen_dil_tensor(input).get_data_type();
+  auto input_contiguous = (src_dil_type == dil::data_type::u8 || src_dil_type == dil::data_type::s8
+      || input.is_contiguous()) ? input : input.contiguous();
+  auto weight_dil_type = dbl::comm::try_gen_dil_tensor(weight).get_data_type();
+  auto weight_contiguous = (weight_dil_type == dil::data_type::s8 || weight.is_contiguous()) ? weight : weight.contiguous();
+  auto ouput_dil_type = dbl::comm::try_gen_dil_tensor(accumu).get_data_type();
+  auto output_contiguous = (ouput_dil_type == dil::data_type::s8 || accumu.is_contiguous()) ? accumu : accumu.contiguous();
+
+  std::vector<float> output_scale = {};
+  if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
+    std::vector<std::vector<float>> scales;
+    bool quantized;
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output */ false);
+    //quantized = false;
+    if (quantized) {
+      output_scale.push_back(scales[1][0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(input_contiguous, scales[0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(weight_contiguous, {});
+    } else {
+      dbl::comm::reorder_to_dtype(input_contiguous, at::kFloat);
+      dbl::comm::reorder_to_dtype(weight_contiguous, at::kFloat);
+      // the output may be an int8 tensor; reorder it to fp32
+      dbl::comm::reorder_to_dtype(output_contiguous, at::kFloat);
+    }
+  } else {
+    dbl::comm::reorder_to_bf16_for_mix_prec(input_contiguous);
+    dbl::comm::reorder_to_bf16_for_mix_prec(weight_contiguous);
+    dbl::comm::reorder_to_bf16_for_mix_prec(output_contiguous);
+  }
 
-  reorder_to_bf16_for_mix_prec(input_contiguous);
-  reorder_to_bf16_for_mix_prec(output_contiguous);
   dil_input = try_gen_dil_tensor(input_contiguous);
   dil_output = try_gen_dil_tensor(output_contiguous);
 
   if (bias.defined()) {
-    auto bias_contiguous = bias.is_contiguous() ? bias : bias.contiguous();
-    reorder_to_bf16_for_mix_prec(bias_contiguous);
-    dil_bias = try_gen_dil_tensor(bias_contiguous);
+    auto bias_contiguous = bias.is_contiguous() ? bias : bias.contiguous();
+    if (!check_auto_mix_int8_fp32()) {
+      dbl::comm::reorder_to_bf16_for_mix_prec(bias_contiguous);
+    }
+    dil_bias = dbl::comm::try_gen_dil_tensor(bias_contiguous);
   }
 
-  reorder_to_bf16_for_mix_prec(weight_contiguous);
   dbl::conv::prepack_conv_weights(
     input_contiguous,
     dil_input,
@@ -122,9 +176,14 @@ static at::Tensor& dil_convolution_inplace_fusion(
     stride,
     dilation,
     groups,
-    attr);
+    attr,
+    output_scale);
 
-  sync_shape_from_dil_to_aten(accumu, dil_output);
+
+  dbl::comm::equip_dil_buffer(accumu, dil_output);
+  if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
+    insert_or_updata_observer({input_contiguous}, {accumu}, op_name);
+  }
   return accumu;
 }
 
@@ -203,7 +262,8 @@ at::Tensor AtenIpexJITDev::dil_convolution_relu(
       padding,
       dilation,
      groups,
-      dil::attr_t::fuse_relu());
+      dil::attr_t::fuse_relu(),
+      "Convolution_Relu");
 }
 
 at::Tensor AtenIpexJITDev::dil_convolution_elu(
@@ -250,7 +310,8 @@ at::Tensor& AtenIpexJITDev::dil_convolution_sum(
       padding,
       dilation,
       groups,
-      dil::attr_t::fuse_sum(scale));
+      dil::attr_t::fuse_sum(scale),
+      "Convolution_Sum");
 }
 
 at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu(
@@ -273,7 +334,8 @@ at::Tensor& AtenIpexJITDev::dil_convolution_sum_relu(
       padding,
       dilation,
       groups,
-      dil::attr_t::residual(scale));
+      dil::attr_t::residual(scale),
+      "Convolution_Sum_Relu");
 }
 
 at::Tensor AtenIpexJITDev::dil_linear_fuse_relu(
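
Zooming out from the hunks above: both fusion entry points now take an op_name tag and share the same three-way precision dispatch, with the tag keying calibration observations so scales can be looked up at inference time. The sketch below restates that control flow in a self-contained form. The helper names come from the diff, but the simplified signatures and bodies here are hypothetical stand-ins, not the real IPEX implementations.

#include <iostream>
#include <string>
#include <tuple>
#include <vector>

bool check_auto_mix_int8_fp32() { return true; }   // assumed config toggle
bool check_int8_calibration()   { return false; }  // assumed calibration flag

// Assumed shape of the scales result: {input scales, output scales} plus a
// flag saying whether calibration marked this op as quantizable.
std::tuple<std::vector<std::vector<float>>, bool> get_int8_scales() {
  std::vector<std::vector<float>> scales = {{50.8f}, {31.75f}};
  return std::make_tuple(scales, true);
}

void insert_or_updata_observer(const std::string& op_name) {
  std::cout << "observer updated for " << op_name << "\n";
}

void fused_conv(const std::string& op_name) {
  std::vector<float> output_scale;
  if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
    // int8 inference: reuse the scales recorded during calibration.
    std::vector<std::vector<float>> scales;
    bool quantized;
    std::tie(scales, quantized) = get_int8_scales();
    if (quantized) {
      // Pass the output scale through to the convolution primitive and
      // reorder input/weight to u8/s8 before compute.
      output_scale.push_back(scales[1][0]);
      std::cout << op_name << ": int8, dst scale " << output_scale[0] << "\n";
    } else {
      // This op was not quantized during calibration: fall back to fp32.
      std::cout << op_name << ": fp32 fallback\n";
    }
  } else {
    // bf16 mixed-precision path (and, during calibration, plain fp32).
    std::cout << op_name << ": bf16/fp32 path\n";
  }
  // The convolution would run here; afterwards, calibration runs record
  // value ranges under the op_name tag for later scale lookup.
  if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
    insert_or_updata_observer(op_name);
  }
}

int main() {
  fused_conv("Convolution_Relu");
  fused_conv("Convolution_Sum");
  fused_conv("Convolution_Sum_Relu");
}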

torch_ipex/csrc/cpu/dbl/Conv.cpp

Lines changed: 12 additions & 4 deletions
@@ -56,6 +56,10 @@ dil::tensor convolution_impl(
   if (dil::data_type::s8 == x.get_data_type()) {
     alowp_kind = dil::s8s8;
   }
+  dil::prop_kind aprop_kind = dil::prop_kind::forward;
+  if (dil::data_type::s8 == x.get_data_type() || dil::data_type::u8 == x.get_data_type()) {
+    aprop_kind = dil::prop_kind::forward_inference;
+  }
 
   dil::tensor y;
   if (b.has_value()) {
@@ -75,7 +79,7 @@ dil::tensor convolution_impl(
         dst_scales,
         attr,
         dil::algorithm::convolution_direct,
-        dil::prop_kind::forward,
+        aprop_kind,
         alowp_kind);
   } else {
     dil::convolution_forward::compute(
@@ -93,7 +97,7 @@ dil::tensor convolution_impl(
         dst_scales,
         attr,
         dil::algorithm::convolution_direct,
-        dil::prop_kind::forward,
+        aprop_kind,
         alowp_kind);
   }
   return y;
@@ -131,6 +135,10 @@ void convolution_inplace_impl(
   if (dil::data_type::s8 == x.get_data_type()) {
     alowp_kind = dil::s8s8;
   }
+  dil::prop_kind aprop_kind = dil::prop_kind::forward;
+  if (dil::data_type::s8 == x.get_data_type() || dil::data_type::u8 == x.get_data_type()) {
+    aprop_kind = dil::prop_kind::forward_inference;
+  }
 
   if (b.has_value()) {
     dil::convolution_forward::compute(
@@ -149,7 +157,7 @@ void convolution_inplace_impl(
         dst_scales,
         attr,
         dil::algorithm::convolution_direct,
-        dil::prop_kind::forward,
+        aprop_kind,
         alowp_kind);
   } else {
     dil::convolution_forward::compute(
@@ -167,7 +175,7 @@ void convolution_inplace_impl(
         dst_scales,
         attr,
         dil::algorithm::convolution_direct,
-        dil::prop_kind::forward,
+        aprop_kind,
         alowp_kind);
   }
 }
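
The Conv.cpp side is a small but deliberate change: when the source tensor is s8 or u8, the convolution primitive is now created with forward_inference instead of forward, which lets oneDNN select implementations that need not support a backward pass. Below is a sketch of that selection factored into a helper, with a local stand-in enum in place of the real dil::prop_kind (which mirrors oneDNN's dnnl::prop_kind):

#include <cassert>

// Local stand-ins so the example is self-contained; illustration only.
enum class data_type { f32, bf16, s8, u8 };
enum class prop_kind { forward, forward_inference };

// int8 tensors only ever run inference, so request the inference-specialized
// primitive; keep the training-capable 'forward' kind otherwise.
prop_kind pick_prop_kind(data_type x) {
  return (x == data_type::s8 || x == data_type::u8)
      ? prop_kind::forward_inference
      : prop_kind::forward;
}

int main() {
  assert(pick_prop_kind(data_type::u8) == prop_kind::forward_inference);
  assert(pick_prop_kind(data_type::s8) == prop_kind::forward_inference);
  assert(pick_prop_kind(data_type::f32) == prop_kind::forward);
  return 0;
}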
