@@ -29,15 +29,19 @@ namespace patterns {
29
29
static PDNode* create_emb_vars (PDPattern* pattern, const std::string& name,
30
30
const std::string& arg,
31
31
bool is_persist = false ) {
32
+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
33
+ " lookup_table_v2" };
32
34
PDNode* node =
33
- pattern->NewNode (name)->assert_is_op_input ( " lookup_table " , arg);
35
+ pattern->NewNode (name)->assert_is_ops_input (embedding_ops , arg);
34
36
if (is_persist) return node->assert_is_persistable_var ();
35
37
return node;
36
38
}
37
39
static PDNode* create_emb_out_vars (PDPattern* pattern, const std::string& name,
38
40
const std::string& arg) {
41
+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
42
+ " lookup_table_v2" };
39
43
PDNode* node = pattern->NewNode (name)
40
- ->assert_is_only_output_of_op ( " lookup_table " )
44
+ ->assert_is_only_output_of_ops (embedding_ops )
41
45
->assert_is_op_input (" elementwise_add" , arg)
42
46
->AsIntermediate ();
43
47
return node;
@@ -51,10 +55,12 @@ void Embedding2Eltwise1Pattern::operator()() {
51
55
create_emb_vars (pattern, lookup_table1_w_repr (), " W" , true );
52
56
auto * lookup_table2_w =
53
57
create_emb_vars (pattern, lookup_table2_w_repr (), " W" , true );
58
+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
59
+ " lookup_table_v2" };
54
60
auto * lookup_table1 =
55
- pattern->NewNode (lookup_table1_repr ())->assert_is_op ( " lookup_table " );
61
+ pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops );
56
62
auto * lookup_table2 =
57
- pattern->NewNode (lookup_table2_repr ())->assert_is_op ( " lookup_table " );
63
+ pattern->NewNode (lookup_table2_repr ())->assert_is_ops (embedding_ops );
58
64
auto * lookup_table1_out =
59
65
create_emb_out_vars (pattern, lookup_table1_out_repr (), " X" );
60
66
auto * lookup_table2_out =
@@ -75,8 +81,10 @@ void Embedding1Eltwise1Pattern::operator()() {
75
81
create_emb_vars (pattern, lookup_table1_x_repr (), " Ids" );
76
82
auto * lookup_table1_w =
77
83
create_emb_vars (pattern, lookup_table1_w_repr (), " W" , true );
84
+ std::unordered_set<std::string> embedding_ops{" lookup_table" ,
85
+ " lookup_table_v2" };
78
86
auto * lookup_table1 =
79
- pattern->NewNode (lookup_table1_repr ())->assert_is_op ( " lookup_table " );
87
+ pattern->NewNode (lookup_table1_repr ())->assert_is_ops (embedding_ops );
80
88
auto * lookup_table1_out =
81
89
create_emb_out_vars (pattern, lookup_table1_out_repr (), " Y" );
82
90
auto * eltwise_add =
@@ -342,4 +350,5 @@ REGISTER_PASS_CAPABILITY(embedding_eltwise_layernorm_fuse_pass)
342
350
.AddCombination(
343
351
paddle::framework::compatible::OpVersionComparatorCombination ()
344
352
.EQ(" lookup_table" , 0 )
353
+ .LE(" lookup_table_v2" , 1 )
345
354
.EQ(" elementweise_add" , 0 ));
0 commit comments