
Commit 954b732

fix quant_dequant pass, refine code (#9652) (#9660)
1 parent 9a344f6 commit 954b732

7 files changed, +101 -128 lines changed
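The bulk of this change moves one-time XPU device probing (the deleted GetXPUDeviceType helper) and quant/precision configuration out of XPUStaticKernelPickPass::Apply and into the pass constructor, visible in the __xpu__static_kernel_pick_pass.h hunk below, caching the results in the members xpu_disable_flag_, local_quant_, and encode_precision_. A minimal standalone sketch of that shape, using hypothetical stand-in names (StaticPickPass, probe_device_model) rather than PaddleLite's real classes:

#include <cstdlib>
#include <string>

// Hypothetical stand-in for querying the accelerator model once.
static std::string probe_device_model() { return "XPU2"; }

class StaticPickPass {
 public:
  // Probe the device and read environment config once, at construction,
  // rather than on every Apply() call.
  StaticPickPass() {
    device_model_ = probe_device_model();
    const char* p = std::getenv("XPU_ENCODER_PRECISION");
    encode_precision_ = p ? p : "int16";  // same default as the diff below
  }

  void Apply() {
    // Apply() now only consumes the cached configuration;
    // the kernel-scoring logic would live here.
  }

 private:
  std::string device_model_;
  std::string encode_precision_;
};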

lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.cc

Lines changed: 6 additions & 38 deletions
@@ -11,6 +11,7 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #include "lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h"
 #include <algorithm>
 #include <list>
@@ -19,11 +20,9 @@
 #include <string>
 #include <utility>
 #include <vector>
-#ifdef LITE_WITH_XPU
-#include "lite/backends/xpu/target_wrapper.h"
-#endif
 #include "lite/core/optimizer/mir/graph_visualize_pass.h"
 #include "lite/core/optimizer/mir/pass_registry.h"
+
 namespace paddle {
 namespace lite {
 namespace mir {
@@ -41,11 +40,9 @@ void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       << "kernel_pick_factors should be specified first";
   CHECK(graph) << "graph not valid";
 
-  // Collect input data precision for each node in the graph
-  // Collect XPU op type,which used in fp16/in8;
-#ifdef LITE_WITH_XPU
+  // Collect input data precision for each node in the graph
+  // Collect XPU op type,which used in fp16/in8;
   DataPrecisionDicide(graph);
-  GetXPUDeviceType();
   if (xpu_use_fp16_optimizer_ || xpu_use_int8_optimizer_) {
     CollectXPUSpecialOPType(graph);
     for (auto& node : graph->StmtTopologicalOrder()) {
@@ -75,9 +72,7 @@ void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
         InplaceNodeInputPrecision(node);
       }
     }
-#endif
 
-#ifdef LITE_WITH_XPU
   // sort kernels by the factors.
   VLOG(2) << "graph block_idx: " << graph->blockIdx();
   VLOG(2) << "graph->mutable_nodes().size(): " << graph->mutable_nodes().size();
@@ -155,10 +150,8 @@ void XPUStaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
     instruct.mutable_op_info()->SetAttr<std::string>(
         "kernel_summary", instruct.kernels().front()->summary());
   }
-#endif
 }
 
-#ifdef LITE_WITH_XPU
 void XPUStaticKernelPickPass::DataPrecisionDicide(
     const std::unique_ptr<SSAGraph>& graph) {
   if (GetStringFromEnv("XPUForceUseFP16", "false") == "true") {
@@ -198,8 +191,6 @@ bool XPUStaticKernelPickPass::ForceUsePrecision(
       op_info->GetAttr<bool>("enable_int16");
   CHECK(!(int8_quant && int16_quant))
       << "You can only specify one quant type for an OP!";
-  bool xpu_local_quant =
-      GetBoolFromEnv("XPU_LOCAL_QUANT") || lite::TargetWrapperXPU::local_quant;
 
   if (instruct.op_type() == "__xpu__fc") {
     if (int8_quant && kernel.alias() == "XPU_Int8_FP32_FP32") {
@@ -210,12 +201,11 @@ bool XPUStaticKernelPickPass::ForceUsePrecision(
       *score *= 4;
       VLOG(6) << "__xpu__fc: force use PRECISON INT16: *4";
       return true;
-    } else if (xpu_local_quant && kernel.alias() == "XPU_FP32_LOCAL_QUANT") {
+    } else if (local_quant_ && kernel.alias() == "XPU_FP32_LOCAL_QUANT") {
       *score *= 4;
       VLOG(6) << "__xpu__fc: force use LOCAL QUANT: *4";
       return true;
-    } else if ((GetStringFromEnv("XPU_ENCODER_PRECISION", "int16") == "int31" ||
-                lite::TargetWrapperXPU::multi_encoder_precision == "int31") &&
+    } else if (encode_precision_ == "int31" &&
                kernel.alias() == "XPU_Real_kFloat") {
       *score *= 4;
       VLOG(6) << "__xpu__fc: force use PRECISON INT31: *4";
@@ -723,27 +713,6 @@ void XPUStaticKernelPickPass::SpecialOpScore(lite::mir::Node* node,
   *score += score_tmp_all;
 }
 
-void XPUStaticKernelPickPass::GetXPUDeviceType() {
-  int cur_dev_idx = 0;
-  uint64_t cur_dev_attr = 0;
-
-  XPU_CALL(xpu_current_device(&cur_dev_idx));
-  XPU_CALL(xpu_device_get_attr(&cur_dev_attr, XPUATTR_MODEL, cur_dev_idx));
-  if (cur_dev_attr <= 1) {
-    VLOG(4) << "Currents XPU device : XPU1";
-    xpu_disable_flag_ = "DISABLE_XPU1";
-  } else if (cur_dev_attr >= 2 && cur_dev_attr <= 299) {
-    VLOG(4) << "Currents XPU device : XPU2";
-    xpu_disable_flag_ = "DISABLE_XPU2";
-  } else if (cur_dev_attr >= 300 && cur_dev_attr <= 599) {
-    VLOG(4) << "Currents XPU device : XPU3";
-    xpu_disable_flag_ = "DISABLE_XPU3";
-  } else {
-    VLOG(4) << "invaid XPU device";
-    xpu_disable_flag_ = "NONE";
-  }
-}
-
 void XPUStaticKernelPickPass::GradeXPUKernelScore(
     lite::mir::Node* node,
     const lite::KernelBase& kernel,
@@ -846,7 +815,6 @@ void XPUStaticKernelPickPass::CollectXPUSpecialOPType(
     return;
   }
 
-#endif
 }  // namespace mir
 }  // namespace lite
 }  // namespace paddle

lite/core/optimizer/mir/__xpu__static_kernel_pick_pass.h

Lines changed: 40 additions & 18 deletions
@@ -11,13 +11,17 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
+
 #pragma once
 #include <limits>
 #include <map>
 #include <memory>
 #include <set>
 #include <string>
 #include <vector>
+#ifdef LITE_WITH_XPU
+#include "lite/backends/xpu/target_wrapper.h"
+#endif
 #include "lite/core/optimizer/mir/pass.h"
 #include "lite/core/types.h"
 
@@ -38,6 +42,36 @@ namespace mir {
  */
 class XPUStaticKernelPickPass : public mir::StmtPass {
  public:
+  XPUStaticKernelPickPass() {
+#ifdef LITE_WITH_XPU
+    // get xpu device type
+    int cur_dev_idx = 0;
+    uint64_t cur_dev_attr = 0;
+    XPU_CALL(xpu_current_device(&cur_dev_idx));
+    XPU_CALL(xpu_device_get_attr(&cur_dev_attr, XPUATTR_MODEL, cur_dev_idx));
+    if (cur_dev_attr <= 1) {
+      VLOG(4) << "Currents XPU device : XPU1";
+      xpu_disable_flag_ = "DISABLE_XPU1";
+    } else if (cur_dev_attr >= 2 && cur_dev_attr <= 299) {
+      VLOG(4) << "Currents XPU device : XPU2";
+      xpu_disable_flag_ = "DISABLE_XPU2";
+    } else if (cur_dev_attr >= 300 && cur_dev_attr <= 599) {
+      VLOG(4) << "Currents XPU device : XPU3";
+      xpu_disable_flag_ = "DISABLE_XPU3";
+    } else {
+      VLOG(4) << "invaid XPU device";
+      xpu_disable_flag_ = "NONE";
+    }
+    // init quant type, encode precision
+    local_quant_ = GetBoolFromEnv("XPU_LOCAL_QUANT") ||
+                   lite::TargetWrapperXPU::local_quant;
+    encode_precision_ = lite::TargetWrapperXPU::multi_encoder_precision;
+    if (encode_precision_.empty()) {
+      encode_precision_ = GetStringFromEnv("XPU_ENCODER_PRECISION", "int16");
+    }
+#endif
+  }
+
   void Apply(const std::unique_ptr<SSAGraph>& graph) override;
 
   const core::KernelPickFactor& kernel_pick_factors() const {
@@ -120,7 +154,6 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
     }
     VLOG(4) << "[score s3]:" << score;
 
-#ifdef LITE_WITH_XPU
     bool type_match = false;
     GradeXPUKernelScore(node,
                         kernel,
@@ -136,10 +169,8 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
       VLOG(4) << "[Input/Output precision compatible]: *2";
     }
     VLOG(4) << "[score s4]:" << score;
-#endif
 
-    // add new rules for datatype: When the input types are consistent
-    // with
+    // add new rules for datatype: When the input types are consistent with
     // kernel's input types, select the kernel of the datatype.
     if (instruct.op_info()->Type() != "conditional_block" &&
         instruct.op_info()->Type() != "while" &&
@@ -151,8 +182,7 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
         std::string argname;
         instruct.op_info()->GetInputArgname(in->AsArg().name, &argname);
         VLOG(5) << "intput var name : " << in->AsArg().name;
-        // only when datatype is LOD_TENSOR, LOD_TENSOR_ARRAY,
-        // STEP_SCOPES,
+        // only when datatype is LOD_TENSOR, LOD_TENSOR_ARRAY, STEP_SCOPES,
         // the type pointer is not null;
         if (in->AsArg().type) {
           VLOG(5) << "input datatype : "
@@ -194,16 +224,11 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
     VLOG(4) << "[score(final)]:" << final_score;
     VLOG(4) << "------------------------------";
 
-    // The data layout is not considered, for the input and output arguments
-    // might have different data layout.
-    // TODO(Superjomn) reconsider the idea of taking the data layout as a
-    // kernel
-    // specification.
     return final_score;
   }
 
   // Compatible for PrecisionType.
-  // For cuda, in the process of choosing kernel, fp16 and fp32 are
+  // In the process of choosing kernel, fp16 and fp32 are
   // compatiable.
   // If kernel's declared type is kAny, it is matched.
   bool PrecTypeCompatible(const PrecisionType& p1, const PrecisionType& p2) {
@@ -216,7 +241,6 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
       return false;
     }
   }
-#ifdef LITE_WITH_XPU
   void DataPrecisionDicide(const std::unique_ptr<SSAGraph>& graph);
   bool ForceUsePrecision(size_t* score,
                          const lite::KernelBase& kernel,
@@ -240,7 +264,6 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
                            const lite::KernelBase& kernel,
                            bool* type_match,
                            size_t* score);
-  void GetXPUDeviceType();
   void InplaceOpScore(lite::mir::Node* node,
                       const lite::KernelBase& kernel,
                       bool* type_match,
@@ -256,13 +279,11 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
                          size_t* score,
                          bool* type_match);
   void CollectXPUSpecialOPType(const std::unique_ptr<SSAGraph>& graph);
-#endif
 
  private:
   core::KernelPickFactor kernel_pick_factors_;
 
   bool xpu_use_fp16_optimizer_{false};
-#ifdef LITE_WITH_XPU
   std::multimap<std::string, std::vector<std::map<std::string, PrecisionType>>>
       xpu_input_type_{};
   std::map<std::string, PrecisionType> xpu_output_type_{};
@@ -277,10 +298,11 @@ class XPUStaticKernelPickPass : public mir::StmtPass {
       "squeeze2",
       "unsqueeze",
       "unsqueeze2"};
-  // int8
   bool xpu_use_int8_optimizer_{false};
   std::set<std::string> xpu_int8_special_op_{"__xpu__fc", "__xpu__conv2d"};
-#endif
+
+  bool local_quant_{false};
+  std::string encode_precision_;
 };
 
 }  // namespace mir

lite/core/optimizer/mir/fusion/quant_dequant_op_fuser.cc

Lines changed: 1 addition & 0 deletions
@@ -735,6 +735,7 @@ void QuantDequantLinearOpFuser::InsertNewNode(SSAGraph* graph,
         break;
       }
     }
+    quantized_node->stmt()->op()->Attach(*op_info, scope);
     IR_NODE_LINK_TO(input_var_node, quantized_node);
   }
   // 3. Delete nodes and edges
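The single added line above is the quant_dequant fix named in the commit title: the fuser rewrites op_info with the collected quant attributes, and re-attaching it lets the quantized op pick up the updated description; without the call, the op would keep whatever info it was attached with earlier. A generic sketch of that bug class, using hypothetical Op/OpInfo types that cache whatever was last attached (not PaddleLite's real API):

#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-ins for an op that caches its attached description.
struct OpInfo {
  std::map<std::string, float> attrs;
};

struct Op {
  OpInfo cached;
  void Attach(const OpInfo& info) { cached = info; }  // stores a copy
};

int main() {
  Op op;
  OpInfo info;
  op.Attach(info);

  // A fuser-style pass mutates its own copy of the description...
  info.attrs["input_scale"] = 0.5f;

  // ...but the op still holds the stale copy until it is re-attached.
  assert(op.cached.attrs.count("input_scale") == 0);
  op.Attach(info);  // the fix: push the updated description back
  assert(op.cached.attrs.count("input_scale") == 1);
  return 0;
}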

lite/core/optimizer/mir/opencl_kernel_place_correct_pass.cc

Lines changed: 1 addition & 2 deletions
@@ -30,5 +30,4 @@ void OpenCLKernelPlaceCorrectPass::Apply(
 }  // namespace paddle
 
 REGISTER_MIR_PASS(opencl_kernel_place_correct_pass,
-                  paddle::lite::mir::OpenCLKernelPlaceCorrectPass)
-    .BindTargets({TARGET(kOpenCL)});
+                  paddle::lite::mir::OpenCLKernelPlaceCorrectPass);

lite/core/optimizer/mir/static_kernel_pick_pass.cc

Lines changed: 2 additions & 6 deletions
@@ -99,9 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
       } else {
         bool out_type_int8 = true;
         // Quantized lstm has fp32 output
-        if (instruct.op_type() == "lstm" || instruct.op_type() == "gru" ||
-            instruct.op_type() == "__xpu__multi_encoder" ||
-            instruct.op_type() == "__xpu__fc") {
+        if (instruct.op_type() == "lstm" || instruct.op_type() == "gru") {
           out_type_int8 = false;
         }
         // Only if all ops linked to this op output has enable_int8 attr,
@@ -114,9 +112,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
           CHECK(tmp_op->IsStmt());
           auto* tmp_op_info = tmp_op->AsStmt().op_info();
           if (!tmp_op_info->HasAttr("enable_int8") ||
-              tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru" ||
-              instruct.op_type() == "__xpu__multi_encoder" ||
-              instruct.op_type() == "__xpu__fc") {
+              tmp_op_info->Type() == "lstm" || tmp_op_info->Type() == "gru") {
             out_type_int8 = false;
             break;
           }
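For context, the rule implemented around these two hunks: an op's output is treated as int8 only when the op itself is not one of the fp32-output exceptions (now just lstm and gru) and every op linked to the output carries the enable_int8 attribute. A compact sketch of that rule, with a hypothetical Consumer type standing in for PaddleLite's graph nodes:

#include <string>
#include <vector>

// Hypothetical stand-in for a consuming op; not PaddleLite's node type.
struct Consumer {
  std::string type;
  bool enable_int8;
};

// Mirrors the check in the hunks above: lstm/gru produce fp32 output,
// and the output stays int8 only if every consumer has enable_int8 set.
bool OutputIsInt8(const std::string& op_type,
                  const std::vector<Consumer>& consumers) {
  if (op_type == "lstm" || op_type == "gru") return false;
  for (const Consumer& c : consumers) {
    if (!c.enable_int8 || c.type == "lstm" || c.type == "gru") return false;
  }
  return true;
}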

lite/core/optimizer/mir/variable_place_inference_pass.h

Lines changed: 26 additions & 36 deletions
@@ -256,47 +256,37 @@ class VariablePlaceInferencePass : public DebugPass {
       // update op's input variables precision from graph nodes info
       // ps. op's input variables are stored in exec_scope, while
       // graph node info is a temporary structure.
-      auto UpdateOpInputsFromNodeInfo = [&]() {
-        for (auto* in : node->inlinks) {
-          if (!(in->AsArg().is_weight) && in->AsArg().type->IsTensor()) {
-            auto in_arg_name = in->AsArg().name;
-            auto* tmp_tensor = node->AsStmt()
-                                   .op()
-                                   ->scope()
-                                   ->Var(in_arg_name)
-                                   ->GetMutable<lite::Tensor>();
-            tmp_tensor->set_precision(in->AsArg().type->precision());
-          }
+      for (auto* in : node->inlinks) {
+        if (!(in->AsArg().is_weight) && in->AsArg().type->IsTensor()) {
+          auto in_arg_name = in->AsArg().name;
+          auto* in_tensor = node->AsStmt()
+                                .op()
+                                ->scope()
+                                ->Var(in_arg_name)
+                                ->GetMutable<lite::Tensor>();
+          in_tensor->set_precision(in->AsArg().type->precision());
         }
-      };
-
-      // update graph nodes precision info from op's output variables
-      // ps. op's output variables are stored in exec_scope, while
-      // graph node info is a temporary structure.
-      auto UpdateNodeInfoFromOpOutputs = [&] {
-        for (auto* out : node->outlinks) {
-          if (!(out->AsArg().is_weight) && out->AsArg().type->IsTensor()) {
-            auto out_arg_name = out->AsArg().name;
-            auto* tmp_tensor = node->AsStmt()
-                                   .op()
-                                   ->scope()
-                                   ->Var(out_arg_name)
-                                   ->GetMutable<lite::Tensor>();
-            out->AsArg().type =
-                LiteType::GetTensorTy(out->AsArg().type->target(),
-                                      tmp_tensor->precision(),
-                                      out->AsArg().type->layout());
-          }
-        }
-      };
-
-      // update op's input variables precision from graph nodes info
-      UpdateOpInputsFromNodeInfo();
+      }
       // update op's output precision from input precision by applying
       // InferType
       inst.op()->InferType();
       // update graph nodes precision info from op's output variables
-      UpdateNodeInfoFromOpOutputs();
+      // ps. op's output variables are stored in exec_scope, while
+      // graph node info is a temporary structure.
+      for (auto* out : node->outlinks) {
+        if (!(out->AsArg().is_weight) && out->AsArg().type->IsTensor()) {
+          auto out_arg_name = out->AsArg().name;
+          auto* out_tensor = node->AsStmt()
+                                 .op()
+                                 ->scope()
+                                 ->Var(out_arg_name)
+                                 ->GetMutable<lite::Tensor>();
+          out->AsArg().type =
+              LiteType::GetTensorTy(out->AsArg().type->target(),
+                                    out_tensor->precision(),
+                                    out->AsArg().type->layout());
+        }
+      }
     }
   }
 }