Skip to content

Commit 967f4c2

Browse files
author
Pei Yang
authored
Cherry pick bert transformer 2.0 support (#31959)
* [Paddle-TRT] TRT inference support for BERT/Transformer in paddle 2.0 api (#31744) * support multihead_matmul_fuse_pass_v3 * fix compile problems * embedding_eltwise_ln pass support lookup_table_v2 * support matmul and matmul_v2 in qkv matmul * map_matmul_to_mul_pass support 3dim
1 parent b655bee commit 967f4c2

7 files changed

+585
-8
lines changed

paddle/fluid/framework/ir/embedding_eltwise_layernorm_fuse_pass.cc

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,19 @@ namespace patterns {
2929
// Creates a pattern node called `name` for a variable that is consumed as
// input argument `arg` of an embedding op. When `is_persist` is true the
// node is further constrained via assert_is_persistable_var().
//
// NOTE: this block is shown in diff form. The line following the `33-`
// marker is the removed single-op ("lookup_table") assertion; the line
// following `35+` is its replacement, which accepts either op type listed
// in `embedding_ops`.
static PDNode* create_emb_vars(PDPattern* pattern, const std::string& name,
3030
const std::string& arg,
3131
bool is_persist = false) {
32+
// Op types treated as embedding lookups by this pattern.
std::unordered_set<std::string> embedding_ops{"lookup_table",
33+
"lookup_table_v2"};
3234
PDNode* node =
33-
pattern->NewNode(name)->assert_is_op_input("lookup_table", arg);
35+
pattern->NewNode(name)->assert_is_ops_input(embedding_ops, arg);
3436
// Callers pass is_persist = true for the "W" argument (see the
// create_emb_vars(..., "W", true) call sites in this file).
if (is_persist) return node->assert_is_persistable_var();
3537
return node;
3638
}
3739
static PDNode* create_emb_out_vars(PDPattern* pattern, const std::string& name,
3840
const std::string& arg) {
41+
std::unordered_set<std::string> embedding_ops{"lookup_table",
42+
"lookup_table_v2"};
3943
PDNode* node = pattern->NewNode(name)
40-
->assert_is_only_output_of_op("lookup_table")
44+
->assert_is_only_output_of_ops(embedding_ops)
4145
->assert_is_op_input("elementwise_add", arg)
4246
->AsIntermediate();
4347
return node;
@@ -51,10 +55,12 @@ void Embedding2Eltwise1Pattern::operator()() {
5155
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
5256
auto* lookup_table2_w =
5357
create_emb_vars(pattern, lookup_table2_w_repr(), "W", true);
58+
std::unordered_set<std::string> embedding_ops{"lookup_table",
59+
"lookup_table_v2"};
5460
auto* lookup_table1 =
55-
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
61+
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
5662
auto* lookup_table2 =
57-
pattern->NewNode(lookup_table2_repr())->assert_is_op("lookup_table");
63+
pattern->NewNode(lookup_table2_repr())->assert_is_ops(embedding_ops);
5864
auto* lookup_table1_out =
5965
create_emb_out_vars(pattern, lookup_table1_out_repr(), "X");
6066
auto* lookup_table2_out =
@@ -75,8 +81,10 @@ void Embedding1Eltwise1Pattern::operator()() {
7581
create_emb_vars(pattern, lookup_table1_x_repr(), "Ids");
7682
auto* lookup_table1_w =
7783
create_emb_vars(pattern, lookup_table1_w_repr(), "W", true);
84+
std::unordered_set<std::string> embedding_ops{"lookup_table",
85+
"lookup_table_v2"};
7886
auto* lookup_table1 =
79-
pattern->NewNode(lookup_table1_repr())->assert_is_op("lookup_table");
87+
pattern->NewNode(lookup_table1_repr())->assert_is_ops(embedding_ops);
8088
auto* lookup_table1_out =
8189
create_emb_out_vars(pattern, lookup_table1_out_repr(), "Y");
8290
auto* eltwise_add =
@@ -342,4 +350,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
342350
.AddCombination(
343351
paddle::framework::compatible::OpVersionComparatorCombination()
344352
.EQ("lookup_table", 0)
353+
.LE("lookup_table_v2", 1)
345354
.EQ("elementweise_add", 0));

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,36 @@ PDNode *PDNode::assert_is_ops_input(
662662
return this;
663663
}
664664

665+
// Adds an assertion that this node is a variable with at least one consumer
// (an entry of x->outputs) that is an op whose type is in `op_types` and
// that has exactly one input. Returns `this` so that assertions can be
// chained.
PDNode *PDNode::assert_is_only_input_of_ops(
666+
const std::unordered_set<std::string> &op_types) {
667+
assert_is_var();
668+
// [=] copies `op_types` into the closure, so the stored predicate does not
// refer back to the caller's set.
asserts_.emplace_back([=](Node *x) {
669+
for (auto *op : x->outputs) {
670+
// Match on the first consumer of an accepted type whose only input is
// this variable's slot (inputs.size() == 1).
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
671+
op->inputs.size() == 1) {
672+
return true;
673+
}
674+
}
675+
// No consumer satisfied the condition.
return false;
676+
});
677+
return this;
678+
}
679+
680+
// Adds an assertion that this node is a variable with at least one producer
// (an entry of x->inputs) that is an op whose type is in `op_types` and
// that has exactly one output. Returns `this` so that assertions can be
// chained.
PDNode *PDNode::assert_is_only_output_of_ops(
681+
const std::unordered_set<std::string> &op_types) {
682+
assert_is_var();
683+
// [=] copies `op_types` into the closure, so the stored predicate does not
// refer back to the caller's set.
asserts_.emplace_back([=](Node *x) {
684+
for (auto *op : x->inputs) {
685+
// Match on the first producer of an accepted type that emits a single
// output (outputs.size() == 1).
if (op && op->IsOp() && op->Op() && op_types.count(op->Op()->Type()) &&
686+
op->outputs.size() == 1) {
687+
return true;
688+
}
689+
}
690+
// No producer satisfied the condition.
return false;
691+
});
692+
return this;
693+
}
694+
665695
bool VarLinksToOp(Node *node, const std::string &op_type) {
666696
for (auto *out : node->outputs) {
667697
if (out->IsOp() && out->Op()->Type() == op_type) {

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ struct PDNode {
146146
const std::unordered_set<std::string>& op_types,
147147
const std::string& argument, int nth);
148148

149+
PDNode* assert_is_only_input_of_ops(
150+
const std::unordered_set<std::string>& op_types);
151+
PDNode* assert_is_only_output_of_ops(
152+
const std::unordered_set<std::string>& op_types);
153+
149154
PDNode* assert_has_n_inputs(size_t n);
150155
PDNode* assert_has_n_outputs(size_t n);
151156

paddle/fluid/framework/ir/map_matmul_to_mul_pass.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
5757
std::vector<int64_t> y_shape = matmul_in_y->Var()->GetShape();
5858
size_t x_rank = x_shape.size();
5959
size_t y_rank = y_shape.size();
60-
flag = flag && x_rank == 2 && y_rank == 2;
60+
flag = flag && (x_rank == 2 || x_rank == 3) && y_rank == 2;
6161

6262
std::vector<Node*>& next_ops = matmul_out->outputs;
6363
flag = flag && next_ops.size() == 1 &&
@@ -69,7 +69,7 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
6969
desc.SetInput("X", {matmul_in_x->Name()});
7070
desc.SetInput("Y", {matmul_in_y->Name()});
7171
desc.SetOutput("Out", {matmul_out->Name()});
72-
desc.SetAttr("x_num_col_dims", 1);
72+
desc.SetAttr("x_num_col_dims", static_cast<int>(x_rank - 1));
7373
desc.SetAttr("y_num_col_dims", 1);
7474
if (matmul_op->Op()->HasAttr("enable_int8")) {
7575
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));

0 commit comments

Comments
 (0)