
Commit 94bd248

Make observer able to access multiple inputs and outputs
1 parent 6a8c6ae commit 94bd248

File tree: 8 files changed, +178 −143 lines

torch_ipex/csrc/auto_opt_config.h

Lines changed: 65 additions & 38 deletions
@@ -61,27 +61,35 @@ class AutoOptConfig {
   }

   inline void insert_or_updata_observer(std::string op_name,
-      std::vector<float> input_min_max_values, std::vector<float> output_min_max_values) {
+      std::vector<std::vector<float>> i_min_max_values, std::vector<std::vector<float>> o_min_max_values) {
     num_ops_id++;
     if (observers_.size() < num_ops_id) {
       // this path is that user not set int8 op's configure, using default configures
-      Observer new_observer = {num_ops_id - 1, op_name, input_min_max_values, output_min_max_values};
+      Observer new_observer = {num_ops_id - 1, op_name, i_min_max_values, o_min_max_values};
       observers_.push_back(new_observer);
     } else {
       // user has set configure or have run one interation
-      auto input_pre = observers_[num_ops_id - 1].Input_min_max_values;
-      auto output_pre = observers_[num_ops_id - 1].Output_min_max_values;
-      if (observers_[num_ops_id - 1].Algorithm == "min_max") {
-        observers_[num_ops_id - 1].Input_min_max_values[0] = std::min(input_pre[0], input_min_max_values[0]);
-        observers_[num_ops_id - 1].Input_min_max_values[1] = std::max(input_pre[1], input_min_max_values[1]);
-        observers_[num_ops_id - 1].Output_min_max_values[0] = std::min(output_pre[0], output_min_max_values[0]);
-        observers_[num_ops_id - 1].Output_min_max_values[1] = std::max(output_pre[1], output_min_max_values[1]);
-      } else if(observers_[num_ops_id -1].Algorithm == "moving_averager_min_max"){
-        auto c = observers_[num_ops_id - 1].Averaging_constant;
-        observers_[num_ops_id - 1].Input_min_max_values[0] = (1 - c) * input_pre[0] + c * input_min_max_values[0];
-        observers_[num_ops_id - 1].Input_min_max_values[1] = (1 - c) * input_pre[1] + c * input_min_max_values[1];
-        observers_[num_ops_id - 1].Output_min_max_values[0] = (1 - c) * output_pre[0] + c * output_min_max_values[0];
-        observers_[num_ops_id - 1].Output_min_max_values[1] = (1 - c) * output_pre[1] + c * output_min_max_values[1];
+      auto inputs_pre = observers_[num_ops_id - 1].inputs_min_max_values;
+      auto outputs_pre = observers_[num_ops_id - 1].outputs_min_max_values;
+      if (observers_[num_ops_id - 1].algorithm == "min_max") {
+        for (auto i = 0; i < i_min_max_values.size(); i++) {
+          observers_[num_ops_id - 1].inputs_min_max_values[i][0] = std::min(inputs_pre[i][0], i_min_max_values[i][0]);
+          observers_[num_ops_id - 1].inputs_min_max_values[i][1] = std::max(inputs_pre[i][1], i_min_max_values[i][1]);
+        }
+        for (auto j = 0; j < o_min_max_values.size(); j++) {
+          observers_[num_ops_id - 1].outputs_min_max_values[j][0] = std::min(outputs_pre[j][0], o_min_max_values[j][0]);
+          observers_[num_ops_id - 1].outputs_min_max_values[j][1] = std::max(outputs_pre[j][1], o_min_max_values[j][1]);
+        }
+      } else if (observers_[num_ops_id - 1].algorithm == "moving_averager_min_max") {
+        auto c = observers_[num_ops_id - 1].averaging_constant;
+        for (auto i = 0; i < i_min_max_values.size(); i++) {
+          observers_[num_ops_id - 1].inputs_min_max_values[i][0] = (1 - c) * inputs_pre[i][0] + c * i_min_max_values[i][0];
+          observers_[num_ops_id - 1].inputs_min_max_values[i][1] = (1 - c) * inputs_pre[i][1] + c * i_min_max_values[i][1];
+        }
+        for (auto j = 0; j < o_min_max_values.size(); j++) {
+          observers_[num_ops_id - 1].outputs_min_max_values[j][0] = (1 - c) * outputs_pre[j][0] + c * o_min_max_values[j][0];
+          observers_[num_ops_id - 1].outputs_min_max_values[j][1] = (1 - c) * outputs_pre[j][1] + c * o_min_max_values[j][1];
+        }
       }
     }
   }
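Aside (not part of the diff): the two calibration rules above are easiest to read on a single (min, max) pair. A minimal standalone C++ sketch with hypothetical helper names, where c plays the role of averaging_constant:

#include <algorithm>
#include <utility>

// "min_max": keep the widest range seen over all calibration batches.
std::pair<float, float> min_max_update(std::pair<float, float> prev,
                                       std::pair<float, float> cur) {
  return {std::min(prev.first, cur.first), std::max(prev.second, cur.second)};
}

// "moving_averager_min_max": exponential smoothing; each new observation is
// blended in with weight c in (0, 1], so old extremes decay instead of
// sticking forever.
std::pair<float, float> moving_avg_update(std::pair<float, float> prev,
                                          std::pair<float, float> cur,
                                          float c) {
  return {(1 - c) * prev.first + c * cur.first,
          (1 - c) * prev.second + c * cur.second};
}

The loops added in this hunk apply one of these rules once per input tensor and once per output tensor, which is exactly what the move from std::vector<float> to std::vector<std::vector<float>> buys.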
@@ -93,63 +101,82 @@ class AutoOptConfig {
         std::cout<<observers_[i].max_values[j]<<std::endl;
       }
     }
-  */
   inline void print_indicator() {
     for (auto i = 0; i< indicators_.size(); i++) {
       auto scales = indicators_[i].get_indicator_scales();
       for (auto j = 0; j< scales.size(); j++)
         std::cout<<scales[j]<<std::endl;
     }
   }
+  */

   inline void add_indicators() {
     num_ops_id = 0;
     // default used is s8
     for (auto i = 0; i < observers_.size(); i++) {
-      std::vector<float> scales;
-      std::vector<float> input_values = observers_[i].Input_min_max_values;
-      std::vector<float> output_values = observers_[i].Output_min_max_values;
+      std::vector<float> inputs_scale, outputs_scale;
+      std::vector<std::vector<float>> inputs_values = observers_[i].inputs_min_max_values;
+      std::vector<std::vector<float>> outputs_values = observers_[i].outputs_min_max_values;

-      scales.push_back(127.5 / std::max(std::abs(input_values[0]), input_values[1]));
-      scales.push_back(127.5 / std::max(std::abs(output_values[0]), output_values[1]));
+      for (auto i = 0; i < inputs_values.size(); i++) {
+        inputs_scale.push_back(127.5 / std::max(std::abs(inputs_values[i][0]), inputs_values[i][1]));
+      }
+      for (auto j = 0; j < outputs_values.size(); j++) {
+        outputs_scale.push_back(127.5 / std::max(std::abs(outputs_values[j][0]), outputs_values[j][1]));
+      }
       // zero_points not used now, zero_points = 0 for u8 and 128 for s8.
       //zero_point = 128;
-      Indicator new_indicator(observers_[i].Id, observers_[i].Name, observers_[i].Algorithm,
-        observers_[i].Weight_granularity, scales, {observers_[i].Input_dtype_uint8, observers_[i].Output_dtype_uint8},
-        observers_[i].Quantized);
+      Indicator new_indicator(observers_[i].id, observers_[i].name, observers_[i].algorithm,
+        observers_[i].weight_granularity, inputs_scale, outputs_scale, observers_[i].inputs_dtype_uint8,
+        observers_[i].outputs_dtype_uint8, observers_[i].quantized);
       indicators_.push_back(new_indicator);
     }
     observers_.clear();
   }

-  inline std::tuple<std::vector<float>, bool> get_indicator_scales(std::vector<bool> uint8_used) {
+  inline std::tuple<std::vector<std::vector<float>>, bool> get_indicator_scales(std::vector<bool> i_uint8_used, std::vector<bool> o_uint8_used) {
     if (num_ops_id > indicators_.size() - 1) num_ops_id = 0;

-    auto indicator_uint8_used = indicators_[num_ops_id].get_indicator_uint8_status();
-    std::vector<float> indicator_scales;
+    std::vector<float> inputs_scale, outputs_scale;
+    std::vector<bool> inputs_uint8_used, outputs_uint8_used;
     bool quantized_status;
-    indicator_scales = indicators_[num_ops_id].get_indicator_scales();
+    std::tie(inputs_uint8_used, outputs_uint8_used) = indicators_[num_ops_id].get_indicator_uint8_status();
+    std::tie(inputs_scale, outputs_scale) = indicators_[num_ops_id].get_indicator_scales();
     quantized_status = indicators_[num_ops_id].get_indicator_quantized_status();
     bool scale_update = false;
-    for (auto i = 0; i < uint8_used.size(); i++) {
-      if (!indicator_uint8_used[i] && uint8_used[i]) {
+    for (auto i = 0; i < i_uint8_used.size(); i++) {
+      if (!inputs_uint8_used[i] && i_uint8_used[i]) {
+        // update zero_point and scales
+        inputs_scale[i] /= 127.5;
+        inputs_scale[i] *= 255.5;
+        scale_update = true;
+      } else if (inputs_uint8_used[i] && !i_uint8_used[i]) {
+        // update zero_point and scales
+        inputs_scale[i] /= 255.5;
+        inputs_scale[i] *= 127.5;
+        scale_update = true;
+      }
+    }
+    for (auto j = 0; j < o_uint8_used.size(); j++) {
+      if (!outputs_uint8_used[j] && o_uint8_used[j]) {
         // update zero_point and scales
-        indicator_scales[i] /= 127.5;
-        indicator_scales[i] *= 255.5;
+        outputs_scale[j] /= 127.5;
+        outputs_scale[j] *= 255.5;
         scale_update = true;
-      } else if (indicator_uint8_used[i] && !uint8_used[i]) {
+      } else if (outputs_uint8_used[j] && !o_uint8_used[j]) {
         // update zero_point and scales
-        indicator_scales[i] /= 255.5;
-        indicator_scales[i] *= 127.5;
+        outputs_scale[j] /= 255.5;
+        outputs_scale[j] *= 127.5;
         scale_update = true;
       }
     }
     if (scale_update) {
-      indicators_[num_ops_id].set_indicator_scales(indicator_scales);
-      indicators_[num_ops_id].set_indicator_uint8_status(uint8_used);
+      indicators_[num_ops_id].set_indicator_scales(inputs_scale, outputs_scale);
+      indicators_[num_ops_id].set_indicator_uint8_status(inputs_uint8_used, outputs_uint8_used);
     }
     num_ops_id++;
-    return std::make_tuple(indicator_scales, quantized_status);
+    std::vector<std::vector<float>> input_output_scale = {inputs_scale, outputs_scale};
+    return std::make_tuple(input_output_scale, quantized_status);
   }

   void set_indicators(std::vector<Indicator> indicators) {
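Aside (not part of the diff): the scale arithmetic in add_indicators and get_indicator_scales above is symmetric-range quantization. A standalone sketch with hypothetical helper names:

#include <algorithm>
#include <cmath>

// As in add_indicators: map the observed absolute maximum onto the signed
// 8-bit half-range (127.5 in this code).
float s8_scale(float min_v, float max_v) {
  return 127.5f / std::max(std::abs(min_v), max_v);
}

// As in get_indicator_scales: when an op flips between s8 and u8, divide out
// one range and apply the other (255.5 for the unsigned full range).
float s8_to_u8(float scale) { return scale / 127.5f * 255.5f; }
float u8_to_s8(float scale) { return scale / 255.5f * 127.5f; }

Note also the interface change: get_indicator_scales now takes separate u8/s8 flags per input and per output, and returns {inputs_scale, outputs_scale} packed into one nested vector, matching the new per-tensor plumbing in insert_or_updata_observer.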

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 35 additions & 36 deletions
@@ -54,13 +54,13 @@ at::Tensor AtenIpexCPUDev::dil_convolution(

   std::vector<float> output_scale = {};
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/false);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      output_scale.push_back(scales[1]);
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      output_scale.push_back(scales[1][0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
       dbl::comm::reorder_to_int8_for_mix_prec(weight, {});
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
@@ -103,7 +103,7 @@ at::Tensor AtenIpexCPUDev::dil_convolution(
   auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(dil_output));

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, aten_output, "Convolution");
+    insert_or_updata_observer({input}, {aten_output}, "Convolution");
   }

   return aten_output;
@@ -761,13 +761,13 @@ at::Tensor AtenIpexCPUDev::dil_linear(

   std::vector<float> output_scale = {};
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(self, /* uint8_used for output*/false);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({self}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      output_scale.push_back(scales[1]);
-      dbl::comm::reorder_to_int8_for_mix_prec(self, {scales[0]});
+      output_scale.push_back(scales[1][0]);
+      dbl::comm::reorder_to_int8_for_mix_prec(self, scales[0]);
       dbl::comm::reorder_to_int8_for_mix_prec(weight, {});
     } else {
       dbl::comm::reorder_to_dtype(self, at::kFloat);
@@ -797,7 +797,7 @@ at::Tensor AtenIpexCPUDev::dil_linear(
   auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(y));

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(self, aten_output, "Linear");
+    insert_or_updata_observer({self}, {aten_output}, "Linear");
   }

   if (self.dim() > 2) {
@@ -955,12 +955,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_
   std::vector<float> output_scales = {};
   bool quantized = false;
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/false);
+    std::vector<std::vector<float>> scales;
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      input_scales.push_back(scales[0]);
-      output_scales.push_back(scales[1]);
+      input_scales = scales[0];
+      output_scales = scales[1];
       dbl::comm::reorder_to_int8_for_mix_prec(input, input_scales);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
@@ -1005,9 +1005,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_batch_

   auto aten_output = dbl::comm::gen_aten_tensor_by(std::move(y));

-  //dbl::comm::reorder_to_dtype(aten_output, at::kFloat);
   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, aten_output, "BatchNorm");
+    insert_or_updata_observer({input}, {aten_output}, "BatchNorm");
   }

   return std::make_tuple(aten_output, at::Tensor(), at::Tensor());
@@ -1060,12 +1059,12 @@ at::Tensor AtenIpexCPUDev::dil_max_pooling(
   DEBUG("AtenIpexCPUDev::dil_max_pooling\n");
   CHECK_DNNL_OP_PRE_COND(input);
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/false);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
     }
@@ -1074,7 +1073,7 @@ at::Tensor AtenIpexCPUDev::dil_max_pooling(
   }

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, at::Tensor(), "MaxPooling");
+    insert_or_updata_observer({input}, {input}, "MaxPooling");
   }
   return dbl::pool::_dil_pooling(
     input,
@@ -1100,12 +1099,12 @@ at::Tensor AtenIpexCPUDev::dil_avg_pool2d(
     "dil_avg_pooling operator does not support divisor");

   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/false);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
     }
@@ -1114,7 +1113,7 @@ at::Tensor AtenIpexCPUDev::dil_avg_pool2d(
   }

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, at::Tensor(), "AvgPool2d");
+    insert_or_updata_observer({input}, {input}, "AvgPool2d");
   }

   return dbl::pool::_dil_pooling(
@@ -1161,12 +1160,12 @@ at::Tensor AtenIpexCPUDev::dil_adaptive_avg_pool2d(
   CHECK_DNNL_OP_PRE_COND(input);

   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/false);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
     //quantized = false;
     if (quantized) {
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
     }
@@ -1195,7 +1194,7 @@ at::Tensor AtenIpexCPUDev::dil_adaptive_avg_pool2d(
   }

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, at::Tensor(), "AdaptiveAvgPool2d");
+    insert_or_updata_observer({input}, {input}, "AdaptiveAvgPool2d");
   }
   return dbl::pool::_dil_pooling(
     input,
@@ -1343,12 +1342,12 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
   DEBUG("AtenIpexCPUDev::dil_relu\n");
   CHECK_DNNL_OP_PRE_COND(input);
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized)= dbl::comm::get_int8_scales(input, /* uint8_used for output*/true);
+    std::tie(scales, quantized)= dbl::comm::get_int8_scales({input}, /* uint8_used for output*/true);
     //quantized = false;
     if (quantized) {
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
     }
@@ -1362,7 +1361,7 @@ at::Tensor AtenIpexCPUDev::dil_relu(const at::Tensor& input) {
     x, y, dil::algorithm::eltwise_relu, dil::prop_kind::forward_training, /*alpha*/ 0.0);

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, at::Tensor(), "Relu");
+    insert_or_updata_observer({input}, {input}, "Relu");
   }

   return dbl::comm::gen_aten_tensor_by(std::move(y));
@@ -1373,12 +1372,12 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
   CHECK_DNNL_OP_PRE_COND(input);

   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    std::vector<float> scales;
+    std::vector<std::vector<float>> scales;
     bool quantized;
-    std::tie(scales, quantized) = dbl::comm::get_int8_scales(input, /* uint8_used for output*/true);
+    std::tie(scales, quantized) = dbl::comm::get_int8_scales({input}, /* uint8_used for output*/true);
     //quantized = false;
     if (quantized) {
-      dbl::comm::reorder_to_int8_for_mix_prec(input, {scales[0]});
+      dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
     } else {
       dbl::comm::reorder_to_dtype(input, at::kFloat);
     }
@@ -1387,7 +1386,7 @@ at::Tensor& AtenIpexCPUDev::dil_relu_(at::Tensor& input) {
   }

   if (check_auto_mix_int8_fp32() && check_int8_calibration()) {
-    insert_or_updata_observer(input, at::Tensor(), "Relu_");
+    insert_or_updata_observer({input}, {input}, "Relu_");
   }

   auto dil_self = dbl::comm::try_gen_dil_tensor(input);
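Aside (not part of the diff): every call site above follows one pattern. A hypothetical caller sketch, using only names from this diff (it presupposes the project's own headers, so it is an illustration rather than a compilable unit):

// A braced tensor at the call site, {input}, materializes a temporary
// one-element at::TensorList, so single-input ops keep working against the
// new multi-input signature. scales[0] holds one scale per input tensor,
// scales[1] one scale per output tensor.
std::vector<std::vector<float>> scales;
bool quantized;
std::tie(scales, quantized) =
    dbl::comm::get_int8_scales({input}, /* uint8_used for output*/false);
if (quantized) {
  dbl::comm::reorder_to_int8_for_mix_prec(input, scales[0]);
  float output_scale = scales[1][0];  // single-output op: take element 0
}

Also worth noting: the pooling and relu ops now record {input} as both the observer's inputs and outputs (previously at::Tensor() for the output), i.e. the output range is taken to be the input range for these ops.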

torch_ipex/csrc/cpu/dbl/Common.cpp

Lines changed: 8 additions & 5 deletions
@@ -110,13 +110,16 @@ dil::tensor reorder_dil_tensor_to_dtype(const dil::tensor &dil_tensor, dil::data
   return dst;
 }

-std::tuple<std::vector<float>, bool> get_int8_scales(const at::Tensor& input, bool uint8_used) {
+std::tuple<std::vector<std::vector<float>>, bool> get_int8_scales(const at::TensorList& inputs, bool uint8_used) {
   if (check_auto_mix_int8_fp32() && !check_int8_calibration()) {
-    auto src_dil_type = try_gen_dil_tensor(input).get_data_type();
-    bool input_uint8_used = (src_dil_type == dil::data_type::u8);
-    return get_indicator_scales({input_uint8_used, uint8_used});
+    std::vector<bool> inputs_uint8_used;
+    for (auto i = 0; i < inputs.size(); i++) {
+      auto src_dil_type = try_gen_dil_tensor(inputs[i]).get_data_type();
+      inputs_uint8_used.push_back(src_dil_type == dil::data_type::u8);
+    }
+    return get_indicator_scales(inputs_uint8_used, {uint8_used});
   } else {
-    return std::make_tuple(std::vector<float>(), false);
+    return std::make_tuple(std::vector<std::vector<float>>(), false);
   }
 }
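Aside (not part of the diff): this function is where the commit title's "multi inputs" lands. A hypothetical two-input call, say for a binary op, with self and other as placeholder tensor names:

// Each tensor's u8/s8 status is collected per element of the TensorList
// before the per-tensor scales are looked up:
std::vector<std::vector<float>> scales;
bool quantized;
std::tie(scales, quantized) =
    dbl::comm::get_int8_scales({self, other}, /* uint8_used for output*/false);
// scales[0] then holds one scale per input: scales[0][0] for self,
// scales[0][1] for other.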

torch_ipex/csrc/cpu/dbl/Common.h

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ namespace comm {
 */
 void reorder_to_bf16_for_mix_prec(const at::Tensor& tensor, bool not_reorder_for_training = false);

-std::tuple<std::vector<float>, bool> get_int8_scales(const at::Tensor& tensor, bool uint8_used);
+std::tuple<std::vector<std::vector<float>>, bool> get_int8_scales(const at::TensorList& tensor, bool uint8_used);

 void reorder_to_int8_for_mix_prec(const at::Tensor& tensor, std::vector<float> scales, bool uint8_used = false);
