Skip to content

Commit 4e9f229

Browse files
author
Pei Yang
authored
solve conflicts. test=release/1.8 (#25263)
1 parent 05163e1 commit 4e9f229

14 files changed

+417
-276
lines changed

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 46 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,99 +1983,58 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
19831983
return concat_out;
19841984
}
19851985

1986-
void patterns::QuantDequantOpFuse::operator()(PDNode *quant_op_input,
1987-
const std::string &op_type,
1988-
const std::string &weight_name,
1989-
int times,
1990-
const std::string &quant_type,
1991-
const std::string &dequant_type) {
1992-
int kNumFields = 5;
1993-
const int kQuantizedWeightOffset = 0;
1994-
const int kQuantizedOpOffset = 1;
1995-
const int kQuantizedOpOutOffset = 2;
1996-
const int kDequantOpOffset = 3;
1997-
const int kDequantOpOutOffset = 4;
1998-
const int kDequantOpWeightScaleOffset = 5;
1999-
2000-
// the quant op always be one.
2001-
auto quant_op_in_scale = pattern->NewNode(GetNodeName("quant_op_in_scale"))
1986+
void patterns::DeleteQuantOpFuse::operator()(PDNode *input_act_node,
1987+
const std::string &quant_type) {
1988+
auto *input_scale_node = pattern->NewNode(GetNodeName("input_scale_node"))
20021989
->assert_is_op_input(quant_type, "InScale")
20031990
->AsInput();
2004-
auto quant_op =
2005-
pattern->NewNode(GetNodeName("quant_op"))->assert_is_op(quant_type);
2006-
2007-
PDNode *quant_op_out_scale = nullptr;
1991+
auto *quant_node =
1992+
pattern->NewNode(GetNodeName("quant_node"))->assert_is_op(quant_type);
1993+
auto *output_scale_node = pattern->NewNode(GetNodeName("output_scale_node"))
1994+
->assert_is_op_output(quant_type, "OutScale")
1995+
->AsOutput();
1996+
auto *output_act_node = pattern->NewNode(GetNodeName("output_act_node"))
1997+
->assert_is_op_output(quant_type, "Out")
1998+
->AsOutput();
1999+
quant_node->LinksFrom({input_scale_node, input_act_node});
2000+
output_scale_node->LinksFrom({quant_node});
2001+
output_act_node->LinksFrom({quant_node});
2002+
}
2003+
2004+
void patterns::DequantOpFuse::operator()(PDNode *quantized_op_input,
2005+
const std::string &quantized_op_type,
2006+
const std::string &dequant_type,
2007+
const std::string &weight_name) {
2008+
auto *quantized_op_weight =
2009+
pattern->NewNode(GetNodeName("quantized_op_weight"))
2010+
->assert_is_op_input(quantized_op_type, weight_name)
2011+
->AsInput();
2012+
auto *quantized_op = pattern->NewNode(GetNodeName("quantized_op"))
2013+
->assert_is_op(quantized_op_type);
2014+
auto *quantized_op_out = pattern->NewNode(GetNodeName("quantized_op_out"))
2015+
->assert_is_op_output(quantized_op_type)
2016+
->assert_is_op_input(dequant_type, "X");
2017+
auto *dequant_op =
2018+
pattern->NewNode(GetNodeName("dequant_op"))->assert_is_op(dequant_type);
2019+
auto *dequant_op_out = pattern->NewNode(GetNodeName("dequant_op_out"))
2020+
->assert_is_op_output(dequant_type, "Out")
2021+
->AsOutput();
2022+
PDNode *dequant_channel_scale = nullptr;
20082023
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
2009-
kNumFields += 1;
2010-
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
2011-
->assert_is_op_output(quant_type, "OutScale")
2012-
->assert_is_op_nth_input(dequant_type, "Scales", 1)
2013-
->AsIntermediate();
2014-
} else {
2015-
quant_op_out_scale = pattern->NewNode(GetNodeName("quant_op_out_scale"))
2016-
->assert_is_op_output(quant_type, "OutScale")
2017-
->assert_is_op_input(dequant_type, "Scale")
2018-
->AsIntermediate();
2024+
dequant_channel_scale =
2025+
pattern->NewNode(GetNodeName("dequant_channel_scale"))
2026+
->assert_is_op_nth_input(dequant_type, "Scales", 0)
2027+
->AsInput();
20192028
}
2029+
quantized_op->LinksFrom({quantized_op_input, quantized_op_weight});
2030+
quantized_op_out->LinksFrom({quantized_op});
20202031

2021-
auto quant_op_out = pattern->NewNode(GetNodeName("quant_op_out"))
2022-
->assert_is_op_output(quant_type, "Out")
2023-
->assert_is_op_input(op_type)
2024-
->AsIntermediate();
2025-
2026-
// there are 'times' quantized and dequant op
2027-
std::vector<PDNode *> nodes;
2028-
for (int i = 0; i < times; i++) {
2029-
nodes.push_back(
2030-
pattern->NewNode(GetNodeName("quantized_op_weight") + std::to_string(i))
2031-
->assert_is_op_input(op_type, weight_name)
2032-
->AsInput());
2033-
nodes.push_back(
2034-
pattern->NewNode(GetNodeName("quantized_op") + std::to_string(i))
2035-
->assert_is_op(op_type));
2036-
2037-
nodes.push_back(
2038-
pattern->NewNode(GetNodeName("quantized_op_out") + std::to_string(i))
2039-
->assert_is_op_output(op_type)
2040-
->assert_is_op_input(dequant_type, "X")
2041-
->AsIntermediate());
2042-
2043-
nodes.push_back(
2044-
pattern->NewNode(GetNodeName("dequant_op") + std::to_string(i))
2045-
->assert_is_op(dequant_type));
2046-
2047-
nodes.push_back(
2048-
pattern->NewNode(GetNodeName("dequant_op_out") + std::to_string(i))
2049-
->assert_is_op_output(dequant_type, "Out")
2050-
->AsOutput());
2051-
2052-
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
2053-
nodes.push_back(pattern
2054-
->NewNode(GetNodeName("dequant_channel_scale") +
2055-
std::to_string(i))
2056-
->assert_is_op_nth_input(dequant_type, "Scales", 0)
2057-
->AsInput());
2058-
}
2059-
}
2060-
2061-
quant_op->LinksFrom({quant_op_input, quant_op_in_scale});
2062-
quant_op_out->LinksFrom({quant_op});
2063-
for (int i = 0; i < times; i++) {
2064-
nodes[i * kNumFields + kQuantizedOpOffset]->LinksFrom(
2065-
{quant_op_out, nodes[i * kNumFields + kQuantizedWeightOffset]});
2066-
nodes[i * kNumFields + kQuantizedOpOutOffset]->LinksFrom(
2067-
{nodes[i * kNumFields + kQuantizedOpOffset]});
2068-
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
2069-
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
2070-
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale,
2071-
nodes[i * kNumFields + kDequantOpWeightScaleOffset]});
2072-
} else {
2073-
nodes[i * kNumFields + kDequantOpOffset]->LinksFrom(
2074-
{nodes[i * kNumFields + kQuantizedOpOutOffset], quant_op_out_scale});
2075-
}
2076-
nodes[i * kNumFields + kDequantOpOutOffset]->LinksFrom(
2077-
{nodes[i * kNumFields + kDequantOpOffset]});
2032+
if (dequant_type == "fake_channel_wise_dequantize_max_abs") {
2033+
dequant_op->LinksFrom({quantized_op_out, dequant_channel_scale});
2034+
} else {
2035+
dequant_op->LinksFrom({quantized_op_out});
20782036
}
2037+
dequant_op_out->LinksFrom({dequant_op});
20792038
}
20802039

20812040
void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1163,14 +1163,28 @@ struct TransposeFlattenConcat : public PatternBase {
11631163
}
11641164
};
11651165

1166-
struct QuantDequantOpFuse : public PatternBase {
1167-
QuantDequantOpFuse(PDPattern* pattern, const std::string& name_scope)
1168-
: PatternBase(pattern, name_scope, "quant_dequant_fuse") {}
1169-
1170-
void operator()(PDNode* quant_op_input, const std::string& op_name,
1171-
const std::string& weight_name, int times,
1172-
const std::string& quant_type,
1173-
const std::string& dequant_type);
1166+
struct DeleteQuantOpFuse : public PatternBase {
1167+
DeleteQuantOpFuse(PDPattern* pattern, const std::string& name_scope)
1168+
: PatternBase(pattern, name_scope, "delete_quant_fuse") {}
1169+
1170+
void operator()(PDNode* input_act_node, const std::string& quant_type);
1171+
1172+
std::string GetNodeName(const std::string& op_type) {
1173+
return PDNodeName(name_scope_, repr_, id_, op_type);
1174+
}
1175+
1176+
PDNode* GetPDNode(const std::string& op_type) {
1177+
return pattern->RetrieveNode(GetNodeName(op_type));
1178+
}
1179+
};
1180+
1181+
struct DequantOpFuse : public PatternBase {
1182+
DequantOpFuse(PDPattern* pattern, const std::string& name_scope)
1183+
: PatternBase(pattern, name_scope, "dequant_fuse") {}
1184+
1185+
void operator()(PDNode* quant_op_input, const std::string& quantized_op_type,
1186+
const std::string& dequant_type,
1187+
const std::string& weight_name);
11741188

11751189
std::string GetNodeName(const std::string& op_type) {
11761190
return PDNodeName(name_scope_, repr_, id_, op_type);

0 commit comments

Comments
 (0)