Commit c9c4bc7

[cherry pick] full api fp16 support and quant_dequant_pass fix (#9654)
* [Pass] fix quant_dequant_op_fuser for input Scale (#9577)
* [Pass] fix memory_optimize_pass for xshape (#9572)
* fix full api for fp16 on ARM (#9553)
* [ARM] fix sharing zero point between quant_linear and dequant_linear op (#9570)
1 parent d3fdf17 commit c9c4bc7

6 files changed: +91 -13 lines changed

lite/api/cxx_api.cc

Lines changed: 56 additions & 0 deletions
@@ -23,6 +23,9 @@
 
 #include "lite/api/paddle_use_passes.h"
 #include "lite/utils/io.h"
+#ifdef ENABLE_ARM_FP16
+#include "lite/backends/arm/math/fp16/type_trans_fp16.h"
+#endif
 
 namespace paddle {
 namespace lite {
@@ -297,6 +300,54 @@ const cpp::ProgramDesc &Predictor::program_desc() const {
 }
 const RuntimeProgram &Predictor::runtime_program() const { return *program_; }
 
+#ifdef ENABLE_ARM_FP16
+typedef __fp16 float16_t;
+void Predictor::WeightFP32ToFP16() {
+  std::shared_ptr<const cpp::ProgramDesc> program_desc = program_desc_;
+  std::vector<std::string> fp16_ops{"conv2d",
+                                    "depthwise_conv2d",
+                                    "conv2d_transpose",
+                                    "fc",
+                                    "mul",
+                                    "matmul",
+                                    "matmul_v2",
+                                    "gru",
+                                    "sequence_conv",
+                                    "elementwise_add",
+                                    "elementwise_sub",
+                                    "elementwise_div",
+                                    "elementwise_mul",
+                                    "prelu"};
+  for (size_t i = 0; i < program_desc->BlocksSize(); i++) {
+    auto *block = program_desc->GetBlock<cpp::BlockDesc>(i);
+    for (size_t k = 0; k < block->OpsSize(); ++k) {
+      auto *op_desc = block->GetOp<cpp::OpDesc>(k);
+      std::string op_type = op_desc->Type();
+      auto iter = std::find(fp16_ops.begin(), fp16_ops.end(), op_type);
+      if (iter != fp16_ops.end()) {
+        auto input_names = op_desc->input_vars();
+        for (auto &input_name : input_names) {
+          std::string input_weight_name = input_name + "_fp16";
+          if (op_desc->HasAttr(input_weight_name)) {  // the input is fp16
+            Tensor tmp_tensor;
+            auto input_tensor =
+                scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
+            if (input_tensor->precision() != PRECISION(kFloat)) continue;
+            tmp_tensor.CopyDataFrom(*input_tensor);
+            input_tensor->clear();
+            input_tensor->set_precision(PRECISION(kFP16));
+            float16_t *fp_data = input_tensor->mutable_data<float16_t>();
+            const float *in_data = tmp_tensor.data<float>();
+            lite::arm::math::fp16::fp32_to_fp16(
+                in_data, fp_data, input_tensor->numel());
+          }
+        }
+      }
+    }
+  }
+}
+#endif  // ENABLE_ARM_FP16
+
 void Predictor::Build(const lite_api::CxxConfig &config,
                       const std::vector<Place> &valid_places,
                       const std::vector<std::string> &passes,
@@ -413,6 +464,11 @@ void Predictor::Build(const std::shared_ptr<cpp::ProgramDesc> &program_desc,
 
   // Update the runtime program to program_desc only once
   program_->SaveRuntimProgramIntoProgramDesc(program_desc_);
+
+#ifdef ENABLE_ARM_FP16
+  // fp16 Weight convert
+  WeightFP32ToFP16();
+#endif
 }
 
 void Predictor::GenRuntimeProgram() {
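
Note (not part of the commit): WeightFP32ToFP16() walks every block and, for ops in the fp16_ops whitelist, rewrites in place any fp32 input tensor that the converter has flagged with a "<input_name>_fp16" attribute, so the runtime kernels can consume fp16 weights directly. The stand-alone sketch below only illustrates the per-element narrowing that lite::arm::math::fp16::fp32_to_fp16 is used for here; it assumes an ARM toolchain where __fp16 exists (the same condition ENABLE_ARM_FP16 guards), and fp32_to_fp16_buffer is a hypothetical name, not a Paddle Lite API.

    #include <cstdint>

    typedef __fp16 float16_t;  // ARM half-precision type, as in the diff above

    // Hypothetical helper: narrow an fp32 buffer to fp16 element by element.
    void fp32_to_fp16_buffer(const float* in, float16_t* out, int64_t numel) {
      for (int64_t i = 0; i < numel; ++i) {
        out[i] = static_cast<float16_t>(in[i]);  // rounds to the nearest fp16 value
      }
    }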

lite/api/cxx_api.h

Lines changed: 3 additions & 0 deletions
@@ -252,6 +252,9 @@ class LITE_API Predictor {
 
   void ClearTensorArray(
       const std::shared_ptr<const cpp::ProgramDesc>& program_desc);
+#ifdef ENABLE_ARM_FP16
+  void WeightFP32ToFP16();
+#endif
 
  private:
   std::shared_ptr<cpp::ProgramDesc> program_desc_;

lite/core/optimizer/mir/fusion/quant_dequant_fuse_pass.cc

Lines changed: 5 additions & 2 deletions
@@ -71,8 +71,11 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
 
   // process new quant op pass: quantize_linear and dequantize_linear
   // pass1: input+quantize_linear+dequantize_linear --> input
-  fusion::QuantDequantLinearOpFuser quant_dequant_linear_fuser;
-  quant_dequant_linear_fuser(graph.get());
+  for (auto share_zero_point : {true, false}) {
+    fusion::QuantDequantLinearOpFuser quant_dequant_linear_fuser(
+        share_zero_point);
+    quant_dequant_linear_fuser(graph.get());
+  }
   // pass2: weight+dequantize_linear --> weight
   fusion::DequantLinearOpFuser dequantize_linear_fuser;
   dequantize_linear_fuser(graph.get());

lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.cc

Lines changed: 14 additions & 6 deletions
@@ -650,9 +650,6 @@ void QuantDequantLinearOpFuser::BuildPattern() {
           ->assert_is_op_input("quantize_linear", "ZeroPoint");
   auto* quant_op_output =
       VarNode("quant_op_output")->assert_is_op_output("quantize_linear", "Y");
-  auto* dequant_op_zero_point =
-      VarNode("dequant_op_zero_point")
-          ->assert_is_op_input("dequantize_linear", "ZeroPoint");
   auto* dequant_op_out =
       VarNode("dequant_op_out")->assert_is_op_output("dequantize_linear", "Y");
 
@@ -663,9 +660,19 @@
 
   quant_op->LinksFrom({quant_op_input, quant_op_scale, quant_op_zero_point})
       .LinksTo({quant_op_output});
-  dequant_op
-      ->LinksFrom({quant_op_output, quant_op_scale, dequant_op_zero_point})
-      .LinksTo({dequant_op_out});
+
+  if (shared_zero_point_) {
+    dequant_op
+        ->LinksFrom({quant_op_output, quant_op_scale, quant_op_zero_point})
+        .LinksTo({dequant_op_out});
+  } else {
+    auto* dequant_op_zero_point =
+        VarNode("dequant_op_zero_point")
+            ->assert_is_op_input("dequantize_linear", "ZeroPoint");
+    dequant_op
+        ->LinksFrom({quant_op_output, quant_op_scale, dequant_op_zero_point})
+        .LinksTo({dequant_op_out});
+  }
   VLOG(4) << "QuantDequantLinearOpFuser";
 }
 
@@ -715,6 +722,7 @@ void QuantDequantLinearOpFuser::InsertNewNode(SSAGraph* graph,
       if (!out_scale_node->IsStmt()) continue;
       auto* out_scale_scope = out_scale_node->stmt()->op()->scope();
       auto* out_scale_op_info = out_scale_node->stmt()->op_info();
+      if (out_scale_op_info->Type() != "quantize_linear") continue;
       if (!out_scale_op_info->HasInput("Scale")) continue;
       std::string out_scale_name = out_scale_op_info->Input("Scale").front();
       auto* out_scale_tensor =
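
Note (not part of the commit): the shared_zero_point_ branch matches exported models where quantize_linear and dequantize_linear reuse a single ZeroPoint variable, while the else branch keeps the original pattern with a separate dequant ZeroPoint node; the added Type() check in InsertNewNode() simply skips output-scale candidates that are not quantize_linear ops before reading their Scale input. A condensed, equivalent way to express the branch, sketched with the same pattern-matcher calls used above:

    // Sketch only: choose the ZeroPoint pattern node once, then link dequant_op.
    auto* zero_point =
        shared_zero_point_
            ? quant_op_zero_point  // reuse the quantize_linear ZeroPoint var
            : VarNode("dequant_op_zero_point")
                  ->assert_is_op_input("dequantize_linear", "ZeroPoint");
    dequant_op->LinksFrom({quant_op_output, quant_op_scale, zero_point})
        .LinksTo({dequant_op_out});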

lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.h

Lines changed: 4 additions & 1 deletion
@@ -140,7 +140,9 @@ class DynamicQuantOpFuser : public FuseBase {
  */
 class QuantDequantLinearOpFuser : public FuseBase {
  public:
-  QuantDequantLinearOpFuser() {}
+  explicit QuantDequantLinearOpFuser(const bool shared_zero_point) {
+    shared_zero_point_ = shared_zero_point;
+  }
   void BuildPattern() override;
   void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
 
@@ -152,6 +154,7 @@ class QuantDequantLinearOpFuser : public FuseBase {
       "mul",
       "matmul",
       "matmul_v2"};
+  bool shared_zero_point_{};
 };
 
 /* The pattern like "dequantize_linear_op + quantized_op "
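
Note (not part of the commit): a member-initializer-list spelling of the new constructor would be behaviorally identical; shown only as a sketch of the idiom, while the commit keeps the assignment form above.

    explicit QuantDequantLinearOpFuser(bool shared_zero_point)
        : shared_zero_point_(shared_zero_point) {}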

lite/core/optimizer/mir/memory_optimize_pass.cc

Lines changed: 9 additions & 4 deletions
@@ -147,12 +147,17 @@ void MemoryOptimizePass::CollectLifeCycleByDevice(
       }
       if (inplace) {
         for (auto& in_param_name : inplace_op_node->second.first) {
-          const auto& in_arg_names = op_info->Input(in_param_name);
-          invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
+          if (op_info->HasInput(in_param_name)) {
+            const auto& in_arg_names = op_info->Input(in_param_name);
+            invalid_var_names.insert(in_arg_names.begin(), in_arg_names.end());
+          }
         }
         for (auto& out_param_name : inplace_op_node->second.second) {
-          const auto& out_arg_names = op_info->Output(out_param_name);
-          invalid_var_names.insert(out_arg_names.begin(), out_arg_names.end());
+          if (op_info->HasOutput(out_param_name)) {
+            const auto& out_arg_names = op_info->Output(out_param_name);
+            invalid_var_names.insert(out_arg_names.begin(),
+                                     out_arg_names.end());
+          }
         }
       }
     }
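
Note (not part of the commit): the guards let the pass tolerate inplace ops that do not declare every parameter in the inplace table (for example, an op whose XShape argument is absent), which is what the memory_optimize_pass xshape fix (#9572) in this cherry-pick addresses. A minimal, self-contained sketch of the same defensive lookup, using a plain map as a stand-in for op_info (CollectVarNames is a hypothetical helper, not a Paddle Lite API):

    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    // Stand-in for the op's parameter -> argument-name table.
    using OpArgs = std::map<std::string, std::vector<std::string>>;

    // Only read the argument list if the op actually declares the parameter,
    // mirroring the HasInput()/HasOutput() guards added in the diff.
    void CollectVarNames(const OpArgs& args,
                         const std::string& param_name,
                         std::set<std::string>* invalid_var_names) {
      auto it = args.find(param_name);
      if (it == args.end()) return;  // e.g. "XShape" missing on this op
      invalid_var_names->insert(it->second.begin(), it->second.end());
    }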
