Skip to content

Commit 146e942

Browse files
authored
Merge pull request #15250 from tensor-tang/refine/seqpool/feed
Refine/seqpool/feed with infer zerocopytensor
2 parents 8f17c71 + 96786d3 commit 146e942

File tree

9 files changed

+378
-48
lines changed

9 files changed

+378
-48
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ cc_test(graph_helper_test SRCS graph_helper_test.cc DEPS graph graph_helper op_r
6969
cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph_to_program_pass)
7070
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
7171
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
72+
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
7273
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
7374
if (WITH_MKLDNN)
7475
cc_test(test_depthwise_conv_mkldnn_pass SRCS depthwise_conv_mkldnn_pass_tester.cc DEPS depthwise_conv_mkldnn_pass)

paddle/fluid/framework/ir/seqpool_concat_fuse_pass.cc

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,21 +39,25 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
3939

4040
auto is_seqpool_op_with_pootype_of_nth_input_of_concat = [=](
4141
Node* x, const std::string& type, int idx) -> bool {
42-
bool ok = x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
43-
x->Op()->HasAttr("pooltype") &&
44-
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
45-
x->outputs.size() == 2; // seqpool should only have 2 outputs
46-
if (ok) {
47-
// only one output of seqpool_op is nth_input_var of concat
48-
// the other one should be unused empty var
42+
bool this_is_seqpool_op =
43+
x && x->IsOp() && x->Op()->Type() == "sequence_pool" &&
44+
x->Op()->HasAttr("pooltype") &&
45+
boost::get<std::string>(x->Op()->GetAttr("pooltype")) == type &&
46+
x->outputs.size() == 2; // seqpool should only have 2 outputs
47+
bool satisfied_all = this_is_seqpool_op;
48+
if (this_is_seqpool_op) {
49+
// Only one output of seqpool_op is nth_input_var of concat,
50+
// the other one should be unused empty var.
4951
if (is_nth_input_var_of_concat(x->outputs[0], idx)) {
50-
ok = ok && x->outputs[1]->IsVar() && x->outputs[1]->outputs.size() == 0;
52+
satisfied_all = satisfied_all && x->outputs[1]->IsVar() &&
53+
x->outputs[1]->outputs.size() == 0;
5154
} else {
52-
ok = ok && is_nth_input_var_of_concat(x->outputs[1], idx) &&
53-
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
55+
satisfied_all =
56+
satisfied_all && is_nth_input_var_of_concat(x->outputs[1], idx) &&
57+
x->outputs[0]->IsVar() && x->outputs[0]->outputs.size() == 0;
5458
}
5559
}
56-
return ok;
60+
return satisfied_all;
5761
};
5862

5963
auto* concat_op = pattern->NewNode(
@@ -72,6 +76,7 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
7276

7377
std::vector<PDNode*> seqpool_ops_input_var(num_inputs);
7478
std::vector<PDNode*> seqpool_ops_output_var(num_inputs);
79+
std::vector<PDNode*> seqpool_ops_output_unused_var(num_inputs);
7580
std::vector<PDNode*> seqpool_ops(num_inputs);
7681

7782
for (int i = 0; i < num_inputs; ++i) {
@@ -84,6 +89,15 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
8489
},
8590
name_scope + "/sequence_pool_out_" + std::to_string(i));
8691

92+
seqpool_ops_output_unused_var[i] = pattern->NewNode(
93+
[=](Node* x) {
94+
return x && x->IsVar() && x->inputs.size() == 1 &&
95+
x->outputs.size() == 0 &&
96+
is_seqpool_op_with_pootype_of_nth_input_of_concat(x->inputs[0],
97+
"SUM", i);
98+
},
99+
name_scope + "/sequence_pool_unused_out_" + std::to_string(i));
100+
87101
seqpool_ops[i] = pattern->NewNode(
88102
[=](Node* x) {
89103
return x && x->IsOp() &&
@@ -93,23 +107,29 @@ PDNode* BuildSeqPoolConcatPattern(PDPattern* pattern,
93107

94108
seqpool_ops_input_var[i] = pattern->NewNode(
95109
[=](Node* x) {
96-
return x && x->IsVar() && x->outputs.size() >= 1 &&
97-
is_seqpool_op_with_pootype_of_nth_input_of_concat(
98-
x->outputs[0], "SUM", i);
110+
bool basic = x && x->IsVar() && x->outputs.size() >= 1;
111+
bool next_is_fine = false;
112+
for (auto* o : x->outputs) {
113+
if (is_seqpool_op_with_pootype_of_nth_input_of_concat(o, "SUM",
114+
i)) {
115+
next_is_fine = true;
116+
break;
117+
}
118+
}
119+
return basic && next_is_fine;
99120
},
100121
name_scope + "/sequence_pool_in_" + std::to_string(i));
101122

102123
// Links
103124
seqpool_ops[i]
104125
->LinksFrom({seqpool_ops_input_var[i]})
105-
.LinksTo({seqpool_ops_output_var[i]});
126+
.LinksTo({seqpool_ops_output_var[i], seqpool_ops_output_unused_var[i]});
106127
}
107128
concat_op->LinksFrom(seqpool_ops_output_var).LinksTo({concat_out_var});
108129
return concat_out_var;
109130
}
110131

111-
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
112-
int num_inputs) {
132+
int BuildFusion(Graph* graph, const std::string& name_scope, int num_inputs) {
113133
GraphPatternDetector gpd;
114134
auto* pattern = gpd.mutable_pattern();
115135
BuildSeqPoolConcatPattern(pattern, name_scope, num_inputs);
@@ -178,8 +198,8 @@ std::unique_ptr<ir::Graph> SeqPoolConcatFusePass::ApplyImpl(
178198
FusePassBase::Init(name_scope_, graph.get());
179199
int fusion_count = 0;
180200
for (int i = MAX_CONCAT_INPUTS; i > 0; --i) {
181-
fusion_count += BuildFusion(
182-
graph.get(), name_scope_ + "/" + std::to_string(i), param_scope(), i);
201+
fusion_count +=
202+
BuildFusion(graph.get(), name_scope_ + "/" + std::to_string(i), i);
183203
}
184204
AddStatis(fusion_count);
185205

paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@ namespace paddle {
2323
namespace framework {
2424
namespace ir {
2525

26+
/**
27+
 * Fuse SequencePool (currently only the SUM pooltype is supported) and Concat;
28+
*
29+
* Before fuse:
30+
* | | |
31+
* seq_pool, seq_pool, ... seq_pool
32+
* \ | ... /
33+
* concat
34+
* |
35+
* After fuse:
36+
* \ | /
37+
* FusionSeqPoolConcat
38+
* |
39+
*/
2640
class SeqPoolConcatFusePass : public FusePassBase {
2741
public:
2842
virtual ~SeqPoolConcatFusePass() {}
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/seqpool_concat_fuse_pass.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
// Append an op of |type| to block 0 of |prog| and wire up the given
// input/output variable names. "sequence_pool" and "concat" receive the
// op-specific slots/attributes the fuse pass pattern matches on; any
// other type is wired as a generic X -> Out op.
void SetOp(ProgramDesc* prog, const std::string& type,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs) {
  auto* desc = prog->MutableBlock(0)->AppendOp();
  desc->SetType(type);
  if (type == "sequence_pool") {
    // One input, two outputs: outputs[0] fills the MaxIndex slot,
    // outputs[1] fills Out. Pooltype is fixed to "SUM" — the only
    // pooltype the fuse pass currently matches.
    desc->SetInput("X", {inputs[0]});
    std::string pooltype = "SUM";
    desc->SetAttr("pooltype", pooltype);
    desc->SetOutput("MaxIndex", {outputs[0]});
    desc->SetOutput("Out", {outputs[1]});
  } else if (type == "concat") {
    desc->SetInput("X", inputs);
    desc->SetAttr("axis", 1);
    desc->SetOutput("Out", {outputs[0]});
  } else {
    desc->SetInput("X", inputs);
    desc->SetOutput("Out", outputs);
  }
  desc->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                static_cast<int>(OpRole::kForward));
}
45+
46+
// Count the ops in |graph| whose type equals |op_type|
// (defaults to the fused op this pass is expected to produce).
int CountOpType(const ir::Graph* graph,
                const std::string& op_type = "fusion_seqpool_concat") {
  int num_found = 0;
  for (auto* n : graph->Nodes()) {
    if (!n->IsOp()) continue;
    if (n->Op()->Type() == op_type) ++num_found;
  }
  return num_found;
}
56+
57+
std::unique_ptr<ir::Graph> GetNumNodesOfBeforeAfter(
58+
std::unique_ptr<ir::Graph> graph, int* before, int* after,
59+
const std::string& pass_type = "seqpool_concat_fuse_pass") {
60+
auto pass = PassRegistry::Instance().Get(pass_type);
61+
*before = graph->Nodes().size();
62+
graph = pass->Apply(std::move(graph));
63+
*after = graph->Nodes().size();
64+
return graph;
65+
}
66+
67+
/*
68+
* Before fuse:
69+
* a b c
70+
* | | |
71+
* op1 op2 op3
72+
* / \ / \ / \
73+
* d e f g h i
74+
* \ | /
75+
* concat
76+
* |
77+
* j
78+
* Type of op1, op2 and op3 are sequence_pool, with "SUM" pooltype attr
79+
*
80+
* After fuse:
81+
* a b c
82+
* \ | /
83+
* fusion_seqpool_concat
84+
* |
85+
* j
86+
*/
87+
TEST(SeqPoolConcatFusePass, basic) {
  ProgramDesc prog;
  const std::vector<std::string> var_names(
      {"a", "b", "c", "d", "e", "f", "g", "h", "i", "j"});
  for (const auto& name : var_names) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  }

  // Three SUM sequence_pools; the first listed output goes to the unused
  // MaxIndex slot, the second (e, g, i) feeds the concat.
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
        std::vector<std::string>({"d", "e"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
        std::vector<std::string>({"f", "g"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"c"}),
        std::vector<std::string>({"h", "i"}));
  SetOp(&prog, "concat", std::vector<std::string>({"e", "g", "i"}),
        std::vector<std::string>({"j"}));

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int num_nodes_before = 0;
  int num_nodes_after = 0;
  graph = GetNumNodesOfBeforeAfter(std::move(graph), &num_nodes_before,
                                   &num_nodes_after);
  // Removes 10 nodes (op1, op2, op3, d, e, f, g, h, i, concat_op) and
  // adds 1 (fusion_seqpool_concat): net -9.
  EXPECT_EQ(num_nodes_after, num_nodes_before - 9);
  EXPECT_EQ(CountOpType(graph.get()), 1);
}
112+
113+
/*
114+
* Before fuse:
115+
* a b
116+
* | / \
117+
* op1 op2 op3
118+
* / \ / \ \
119+
* c d e f g
120+
* \ /
121+
* concat
122+
* |
123+
* h
124+
* Type of op1 and op2 are sequence_pool, with "SUM" pooltype attr
125+
*
126+
* After fuse:
127+
* a b
128+
* \ / \
129+
* fusion_seqpool_concat op3
130+
* | |
131+
* h g
132+
*/
133+
TEST(SeqPoolConcatFusePass, advanced) {
  ProgramDesc prog;
  const std::vector<std::string> var_names(
      {"a", "b", "c", "d", "e", "f", "g", "h"});
  for (const auto& name : var_names) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  }

  SetOp(&prog, "sequence_pool", std::vector<std::string>({"a"}),
        std::vector<std::string>({"c", "d"}));
  SetOp(&prog, "sequence_pool", std::vector<std::string>({"b"}),
        std::vector<std::string>({"e", "f"}));
  // op3 is not a sequence_pool, so it and its output g must survive
  // the fusion untouched even though it shares input b with op2.
  SetOp(&prog, "op3", std::vector<std::string>({"b"}),
        std::vector<std::string>({"g"}));
  SetOp(&prog, "concat", std::vector<std::string>({"d", "f"}),
        std::vector<std::string>({"h"}));

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
  int num_nodes_before = 0;
  int num_nodes_after = 0;
  graph = GetNumNodesOfBeforeAfter(std::move(graph), &num_nodes_before,
                                   &num_nodes_after);
  // Removes 7 nodes (op1, op2, c, d, e, f, concat_op) and adds 1: net -6.
  EXPECT_EQ(num_nodes_after, num_nodes_before - 6);
  EXPECT_EQ(CountOpType(graph.get()), 1);
}
158+
159+
// Build a program where |num_inputs_of_concat| SUM sequence_pool ops all
// feed a single concat op, each seqpool declaring one used ("out") and
// one unused ("out_unused") output, matching the pattern the pass fuses.
ProgramDesc BuildProgramDesc(int num_inputs_of_concat) {
  ProgramDesc prog;
  auto new_var = [&](const std::string& name) {
    auto* var = prog.MutableBlock(0)->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
  };
  std::vector<std::string> concat_inputs;
  for (int i = 0; i < num_inputs_of_concat; ++i) {
    // BUG FIX: the original `"seqpool_op_" + i` did pointer arithmetic on
    // the string literal (offsetting i chars into "seqpool_op_", UB once
    // i exceeds the literal's length) instead of appending the index.
    std::string prefix = "seqpool_op_" + std::to_string(i);
    new_var(prefix + "in");
    new_var(prefix + "out");
    new_var(prefix + "out_unused");
    SetOp(&prog, "sequence_pool", std::vector<std::string>({prefix + "in"}),
          std::vector<std::string>({prefix + "out", prefix + "out_unused"}));
    concat_inputs.push_back(prefix + "out");
  }
  SetOp(&prog, "concat", concat_inputs,
        std::vector<std::string>({"concat_out"}));
  return prog;
}
179+
180+
// test more inputs of concat
181+
// Exercise the pass with several concat fan-ins, including the single-input
// edge case.
TEST(SeqPoolConcatFusePass, more_inputs) {
  for (int num : {1, 2, 10}) {
    ProgramDesc prog = BuildProgramDesc(num);
    std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
    int num_nodes_before = 0;
    int num_nodes_after = 0;
    graph = GetNumNodesOfBeforeAfter(std::move(graph), &num_nodes_before,
                                     &num_nodes_after);
    // Each seqpool contributes 3 removable nodes (op, out, out_unused);
    // the concat op is removed and one fused op is added: net -3*num.
    EXPECT_EQ(num_nodes_after, num_nodes_before - num * 3);
    EXPECT_EQ(CountOpType(graph.get()), 1);
  }
}
193+
194+
} // namespace ir
195+
} // namespace framework
196+
} // namespace paddle
197+
198+
USE_PASS(seqpool_concat_fuse_pass);

paddle/fluid/inference/api/helper.h

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,11 +204,14 @@ static std::string DescribeTensor(const PaddleTensor &tensor) {
204204
os << to_string(l) << "; ";
205205
}
206206
os << "\n";
207-
os << " - data: ";
207+
os << " - memory length: " << tensor.data.length();
208+
os << "\n";
208209

210+
os << " - data: ";
209211
int dim = VecReduceToInt(tensor.shape);
212+
float *pdata = static_cast<float *>(tensor.data.data());
210213
for (int i = 0; i < dim; i++) {
211-
os << static_cast<float *>(tensor.data.data())[i] << " ";
214+
os << pdata[i] << " ";
212215
}
213216
os << '\n';
214217
return os.str();
@@ -224,10 +227,12 @@ static std::string DescribeZeroCopyTensor(const ZeroCopyTensor &tensor) {
224227
os << to_string(l) << "; ";
225228
}
226229
os << "\n";
227-
os << " - data: ";
228230
PaddlePlace place;
229231
int size;
230232
const auto *data = tensor.data<float>(&place, &size);
233+
os << " - numel: " << size;
234+
os << "\n";
235+
os << " - data: ";
231236
for (int i = 0; i < size; i++) {
232237
os << data[i] << " ";
233238
}

paddle/fluid/inference/api/paddle_api.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ class ZeroCopyTensor {
123123
*/
124124
template <typename T>
125125
T* mutable_data(PaddlePlace place);
126-
/** Get the memory directly, will return the place and memory size by pointer.
126+
/** Get the memory directly, will return the place and element size by
127+
* pointer.
127128
* This is for reading the output tensor.
128129
*/
129130
template <typename T>

paddle/fluid/inference/tests/api/analyzer_rnn1_tester.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -351,10 +351,10 @@ TEST(Analyzer_rnn1, ZeroCopy) {
351351
ASSERT_TRUE(native_predictor->Run(native_inputs.front(), &native_outputs));
352352
LOG(INFO) << "native output " << DescribeTensor(native_outputs.front());
353353

354-
int output_size{0};
354+
int output_size{0}; // this is the number of elements not memory size
355355
auto *zero_copy_data = output_tensor->data<float>(&place, &output_size);
356356
auto *native_data = static_cast<float *>(native_outputs.front().data.data());
357-
for (size_t i = 0; i < output_size / sizeof(float); i++) {
357+
for (int i = 0; i < output_size; i++) {
358358
EXPECT_NEAR(zero_copy_data[i], native_data[i], 1e-3);
359359
}
360360
}

0 commit comments

Comments
 (0)