Skip to content

Commit 559b975

Browse files
authored
Fix ComputePropagateScalesMkldnnPass of MKLDNN (#47574) (#47639)
* add constant_folding_pass pass for mkldnn int8
* update UpdateScaleOpInOutScales
1 parent 75088bb commit 559b975

File tree

3 files changed

+43
-22
lines changed

3 files changed

+43
-22
lines changed

paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.cc

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -336,27 +336,45 @@ void ComputePropagateScalesMkldnnPass::ComputeWeightScales(
336336
ComputeLstmWeightScales(graph, scope, "WeightX", "WeightH", var_quant_scales);
337337
}
338338

339-
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInScale(
339+
void ComputePropagateScalesMkldnnPass::UpdateScaleOpInOutScales(
340340
Node* op_node,
341341
const std::string& input_name,
342342
const std::string& output_name,
343343
StringPairMap* var_quant_scales) const {
344-
auto iter = var_quant_scales->find(output_name);
345-
if (iter != var_quant_scales->end()) {
346-
auto pair = iter->second;
347-
const auto tensor = pair.second;
348-
349-
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
350-
Tensor tmp_tensor;
351-
tmp_tensor.Resize(tensor.dims());
352-
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
353-
for (int i = 0; i < tensor.numel(); i++) {
354-
data[i] = data[i] * scale;
355-
}
344+
auto out_iter = var_quant_scales->find(output_name);
345+
auto input_iter = var_quant_scales->find(input_name);
346+
// All the input and output have scales
347+
if (out_iter != var_quant_scales->end() &&
348+
input_iter != var_quant_scales->end()) {
349+
return;
350+
}
351+
const auto scale = PADDLE_GET_CONST(float, op_node->Op()->GetAttr("scale"));
352+
if (std::abs(scale) < 1e-6 && out_iter != var_quant_scales->end()) {
353+
return;
354+
}
355+
356+
std::string name = input_name;
357+
auto iter = out_iter;
358+
if (input_iter != var_quant_scales->end()) {
359+
iter = input_iter;
360+
name = output_name;
361+
}
356362

357-
auto new_pair = std::make_pair(pair.first, tmp_tensor);
358-
var_quant_scales->insert(std::make_pair(input_name, new_pair));
363+
phi::DenseTensor tmp_tensor;
364+
auto pair = iter->second;
365+
const auto tensor = pair.second;
366+
tmp_tensor.Resize(tensor.dims());
367+
auto* data = tmp_tensor.mutable_data<float>(platform::CPUPlace());
368+
auto* src_data = tensor.data<float>();
369+
for (int i = 0; i < tensor.numel(); i++) {
370+
if (out_iter != var_quant_scales->end()) {
371+
data[i] = src_data[i] / scale;
372+
} else {
373+
data[i] = src_data[i] * scale;
374+
}
359375
}
376+
auto new_pair = std::make_pair(pair.first, tmp_tensor);
377+
var_quant_scales->insert(std::make_pair(name, new_pair));
360378
}
361379

362380
std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
@@ -403,10 +421,12 @@ std::unordered_set<std::string> ComputePropagateScalesMkldnnPass::UpdateScales(
403421
}
404422
} else if (op_name == "scale") {
405423
const std::string output_name = op_node->Op()->Output("Out")[0];
424+
const std::string input_name = op_node->Op()->Input("X")[0];
406425
auto out_iter = var_quant_scales->find(output_name);
407-
if (out_iter != var_quant_scales->end()) {
408-
const std::string input_name = op_node->Op()->Input("X")[0];
409-
UpdateScaleOpInScale(
426+
auto input_iter = var_quant_scales->find(input_name);
427+
if (out_iter != var_quant_scales->end() ||
428+
input_iter != var_quant_scales->end()) {
429+
UpdateScaleOpInOutScales(
410430
op_node, input_name, output_name, var_quant_scales);
411431
}
412432
}

paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,10 @@ class ComputePropagateScalesMkldnnPass : public FusePassBase {
7979
void UpdateReluOutputScales(ir::Graph* graph,
8080
StringPairMap* var_quant_scales) const;
8181

82-
void UpdateScaleOpInScale(Node* op_node,
83-
const std::string& input_name,
84-
const std::string& output_name,
85-
StringPairMap* var_quant_scales) const;
82+
void UpdateScaleOpInOutScales(Node* op_node,
83+
const std::string& input_name,
84+
const std::string& output_name,
85+
StringPairMap* var_quant_scales) const;
8686

8787
std::unordered_set<std::string> UpdateScales(
8888
ir::Graph* graph,

paddle/fluid/inference/api/paddle_pass_builder.cc

File mode changed: 100644 → 100755
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
376376
passes_.push_back("quant_dequant_mkldnn_pass");
377377
passes_.push_back("mkldnn_placement_pass");
378378
passes_.push_back("simplify_with_basic_ops_pass");
379+
passes_.push_back("constant_folding_pass");
379380
passes_.push_back("layer_norm_fuse_pass");
380381
passes_.push_back("attention_lstm_fuse_pass");
381382
passes_.push_back("seqconv_eltadd_relu_fuse_pass");

0 commit comments

Comments
 (0)