
Commit 37d49b3

Merge pull request #14409 from luotao1/dam_fc
Enhance fc_op for 3-D shape tensor
2 parents 7a64d48 + 1d86780

14 files changed: 60 additions, 49 deletions

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 1 addition & 0 deletions
@@ -57,6 +57,7 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
   desc.SetInput("W", std::vector<std::string>({fc_Y_in}));
   desc.SetInput("Bias", std::vector<std::string>({fc_bias_in}));
   desc.SetOutput("Out", std::vector<std::string>({fc_out_out}));
+  desc.SetAttr("in_num_col_dims", mul->Op()->GetAttr("x_num_col_dims"));
   desc.SetType("fc");
   auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
   GraphSafeRemoveNodes(graph.get(), {mul, elementwise_add, mul_out});
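Note on the added SetAttr: the fused fc op must reproduce mul's flattening semantics, where x_num_col_dims = k collapses the first k input axes into matrix rows before multiplying by W. A minimal standalone sketch of the output-shape rule the copied attribute preserves (hypothetical helper, not framework code):

    #include <cstdint>
    #include <vector>

    // Output shape of the fused fc: keep the first num_col_dims input axes,
    // then append W's second dimension (n).
    std::vector<int64_t> FusedFcOutShape(const std::vector<int64_t>& in_dims,
                                         int num_col_dims, int64_t n) {
      std::vector<int64_t> out(in_dims.begin(), in_dims.begin() + num_col_dims);
      out.push_back(n);  // in_dims={2, 3, 4}, num_col_dims=2, n=5 -> {2, 3, 5}
      return out;
    }

Without the attribute, the fused fc would fall back to in_num_col_dims = 1 and reject the example above, since flattening {2, 3, 4} at axis 1 yields 12 columns while W expects 4 rows.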

paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
   if (type == "mul") {
     op->SetInput("X", {inputs[0]});
     op->SetInput("Y", {inputs[1]});
+    op->SetAttr("x_num_col_dims", {1});
   } else if (type == "elementwise_add") {
     op->SetInput("X", inputs);
   }

paddle/fluid/inference/analysis/ir_passes/subgraph_detector.cc

Lines changed: 1 addition & 1 deletion
@@ -412,7 +412,7 @@ void DetachDeletedNodes(framework::ir::Graph *graph) {
 void SubGraphFuser::ReplaceNodesWithSubGraphs() {
   auto subgraphs = SubgraphDetector(graph_, node_inside_subgraph_teller_)();
   for (auto &subgraph : subgraphs) {
-    if (subgraph.size() <= min_subgraph_size_) continue;
+    if (subgraph.size() <= (size_t)min_subgraph_size_) continue;
     LOG(INFO) << "detect a subgraph size " << subgraph.size();
     std::unordered_set<Node *> subgraph_uniq(subgraph.begin(), subgraph.end());
     // replace this sub-graph with the first node. Two steps: 1. Create a Block
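The cast here (and the matching loop-index changes in other files below) makes the signed-to-unsigned conversion explicit, which silences -Wsign-compare without changing behavior. The warning exists because mixed comparisons convert the signed operand to unsigned, so a negative value compares as huge. A minimal illustration, assuming min_subgraph_size_ is a plain int as the cast suggests:

    #include <cstdio>

    int main() {
      int n = -1;       // imagine a sentinel or miscomputed size
      size_t len = 3;
      // n is converted to size_t and wraps to SIZE_MAX, so this holds:
      if (len <= (size_t)n) std::printf("negative n compares as huge\n");
      return 0;
    }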

paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc

Lines changed: 1 addition & 1 deletion
@@ -114,7 +114,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(framework::ir::Node *node,
   // it is either an OP's input or an OP's output.

   auto &subgraph_nodes = *Agent(node).subgraph();
-  for (int index = 0; index < block_desc.OpSize(); index++) {
+  for (size_t index = 0; index < block_desc.OpSize(); index++) {
     framework::proto::OpDesc *op = block_desc.Op(index)->Proto();
     auto correspond_node = subgraph_nodes[index];
     PADDLE_ENFORCE_EQ(correspond_node->Name(), op->type());

paddle/fluid/inference/tests/api/CMakeLists.txt

Lines changed: 1 addition & 5 deletions
@@ -45,11 +45,7 @@ inference_analysis_api_test(test_analyzer_rnn2 ${RNN2_INSTALL_DIR} analyzer_rnn2
 # DAM
 set(DAM_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/dam")
 download_model_and_data(${DAM_INSTALL_DIR} "DAM_model.tar.gz" "DAM_data.txt.tar.gz")
-inference_analysis_test(test_analyzer_dam SRCS analyzer_dam_tester.cc
-                        EXTRA_DEPS ${INFERENCE_EXTRA_DEPS} ARGS
-                        --infer_model=${DAM_INSTALL_DIR}/model
-                        --infer_data=${DAM_INSTALL_DIR}/data.txt
-                        --use_analysis=0)
+inference_analysis_api_test(test_analyzer_dam ${DAM_INSTALL_DIR} analyzer_dam_tester.cc)

 # chinese_ner
 set(CHINESE_NER_INSTALL_DIR "${INFERENCE_DEMO_INSTALL_DIR}/chinese_ner")

paddle/fluid/inference/tests/api/analyzer_dam_tester.cc

Lines changed: 10 additions & 15 deletions
@@ -69,7 +69,7 @@ struct DataRecord {
       num_lines++;
       std::vector<std::string> data;
       split(line, ',', &data);
-      CHECK_EQ(data.size(), 2 * MAX_TURN_NUM + 3);
+      CHECK_EQ(data.size(), (size_t)(2 * MAX_TURN_NUM + 3));
       // load turn data
       std::vector<int64_t> turns_tmp[MAX_TURN_NUM];
       for (int i = 0; i < MAX_TURN_NUM; ++i) {
@@ -197,15 +197,13 @@ TEST(Analyzer_dam, fuse_statis) {
   contrib::AnalysisConfig cfg;
   SetConfig(&cfg);

-  if (FLAGS_use_analysis) {
-    int num_ops;
-    auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
-    auto fuse_statis = GetFuseStatis(
-        static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
-    ASSERT_TRUE(fuse_statis.count("fc_fuse"));
-    EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
-    EXPECT_EQ(num_ops, 2020);
-  }
+  int num_ops;
+  auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
+  auto fuse_statis = GetFuseStatis(
+      static_cast<AnalysisPredictor *>(predictor.get()), &num_ops);
+  ASSERT_TRUE(fuse_statis.count("fc_fuse"));
+  EXPECT_EQ(fuse_statis.at("fc_fuse"), 317);
+  EXPECT_EQ(num_ops, 2020);
 }

 // Compare result of NativeConfig and AnalysisConfig
@@ -216,11 +214,8 @@ TEST(Analyzer_dam, compare) {
   std::vector<std::vector<PaddleTensor>> input_slots_all;
   SetInput(&input_slots_all);

-  if (FLAGS_use_analysis) {
-    CompareNativeAndAnalysis(
-        reinterpret_cast<const PaddlePredictor::Config *>(&cfg),
-        input_slots_all);
-  }
+  CompareNativeAndAnalysis(
+      reinterpret_cast<const PaddlePredictor::Config *>(&cfg), input_slots_all);
 }

 }  // namespace inference

paddle/fluid/inference/tests/api/analyzer_vis_tester.cc

Lines changed: 0 additions & 3 deletions
@@ -59,9 +59,6 @@ void SetConfig(AnalysisConfig *cfg) {
   cfg->specify_input_name = true;
   // TODO(TJ): fix fusion gru
   cfg->pass_builder()->DeletePass("fc_gru_fuse_pass");
-#ifdef PADDLE_WITH_MKLDNN
-  cfg->EnableMKLDNN();
-#endif
 }

 void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {

paddle/fluid/operators/fc_op.cc

Lines changed: 33 additions & 12 deletions
@@ -27,11 +27,9 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                  "Out(Output) of Fully Connected should not be null.");
   PADDLE_ENFORCE(ctx->HasInput("W"),
                  "W(Input) of Fully Connected should not be null.");
-  // NCHW
+
   auto in_dims = ctx->GetInputDim("Input");
-  // IO, I=C*H*W
   auto w_dims = ctx->GetInputDim("W");
-  std::vector<int64_t> output_shape({in_dims[0], w_dims[1]});

   if (ctx->HasInput("Bias")) {
     auto bias_dims = ctx->GetInputDim("Bias");
@@ -44,14 +42,32 @@ void FCOp::InferShape(framework::InferShapeContext* ctx) const {
                       "The shape of Bias must be [1, dim].");
     }
   }
-  PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
-                 "Fully Connected input should be 2-D or 4-D tensor.");
+
+  if (ctx->Attrs().Get<bool>("use_mkldnn")) {
+    PADDLE_ENFORCE(in_dims.size() == 2 || in_dims.size() == 4,
+                   "Fully Connected input should be 2-D or 4-D tensor.");
+  }
   PADDLE_ENFORCE_EQ(w_dims.size(), 2UL,
                     "Fully Connected input should be 2-D tensor.");
-  PADDLE_ENFORCE_EQ(framework::product(in_dims) / in_dims[0], w_dims[0],
-                    "Fully Connected input and weigth size do not match.");
+  int in_num_col_dims = ctx->Attrs().Get<int>("in_num_col_dims");
+  PADDLE_ENFORCE_GT(
+      in_dims.size(), in_num_col_dims,
+      "The input tensor Input's rank of FCOp should be larger than "
+      "in_num_col_dims.");
+
+  auto in_mat_dims = framework::flatten_to_2d(in_dims, in_num_col_dims);
+  PADDLE_ENFORCE_EQ(
+      in_mat_dims[1], w_dims[0],
+      "Fully Connected input and weigth size do not match. %s, %s");
+
+  std::vector<int64_t> output_dims;
+  output_dims.reserve(static_cast<size_t>(in_num_col_dims + 1));
+  for (int i = 0; i < in_num_col_dims; ++i) {
+    output_dims.push_back(in_dims[i]);
+  }
+  output_dims.push_back(w_dims[1]);

-  ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
   ctx->ShareLoD("Input", "Out");
 }
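For intuition about the InferShape change above: framework::flatten_to_2d(in_dims, k) views the input as a 2-D matrix whose row count is the product of the first k axes and whose column count is the product of the rest; that column count must equal w_dims[0]. A standalone sketch of the shape arithmetic, assuming this reading of the helper (hypothetical reimplementation, not the framework function):

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <utility>
    #include <vector>

    // rows = product of dims[0, k), cols = product of dims[k, rank).
    std::pair<int64_t, int64_t> FlattenTo2d(const std::vector<int64_t>& dims,
                                            int k) {
      auto prod = [&](int lo, int hi) {
        return std::accumulate(dims.begin() + lo, dims.begin() + hi,
                               static_cast<int64_t>(1),
                               std::multiplies<int64_t>());
      };
      return {prod(0, k), prod(k, static_cast<int>(dims.size()))};
    }

    int main() {
      // 3-D input [batch=2, turns=3, feature=4] with in_num_col_dims = 2:
      auto mat = FlattenTo2d({2, 3, 4}, 2);
      assert(mat.first == 6 && mat.second == 4);  // cols match w_dims[0] = 4
      // With W of shape [4, 5], InferShape above yields Out = [2, 3, 5].
      return 0;
    }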

@@ -101,12 +117,15 @@ framework::OpKernelType FCOpGrad::GetExpectedKernelType(
 }

 void FCOpMaker::Make() {
-  AddInput("Input",
-           "(Tensor), The input tensor of fully connected operator with format "
-           "(NCHW). ");
+  AddInput("Input", "(Tensor), The input tensor of fully connected operator.");
   AddInput("W", "(Tensor), The weight fc op with shape (I, O).");
   AddInput("Bias", "(Tensor, optional) Bias vector with shape (1 x O")
       .AsDispensable();
+  AddAttr<int>("in_num_col_dims",
+               "(int, default 1), The fc op can take tensors with more than "
+               "two dimensions as its inputs.")
+      .SetDefault(1)
+      .EqualGreaterThan(1);
   AddOutput("Out", "(Tensor) The output tensor of fully connected operator. ");
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
@@ -131,13 +150,15 @@ class FCOpKernel : public framework::OpKernel<T> {
     auto output = ctx.Output<Tensor>("Out");
     auto in_dims = input->dims();
     auto w_dims = w->dims();
+    auto out_dims = output->dims();
+    int M = framework::product(out_dims) / out_dims[out_dims.size() - 1];

     const T* input_data = input->data<T>();
     const T* w_data = w->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
     math::FCCompute<platform::CPUDeviceContext, T>(
-        blas, in_dims[0], w_dims[1], w_dims[0], input_data, w_data, output_data,
+        blas, M, w_dims[1], w_dims[0], input_data, w_data, output_data,
         bias ? bias->data<T>() : NULL);

     // TODO(TJ): fuse act
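The kernel change applies the same rule at compute time: the GEMM row count M is now the product of all output dimensions except the last, rather than in_dims[0], so a 3-D input is multiplied as one flattened 2-D matrix. A small sketch of the dimension bookkeeping (plain C++; math::FCCompute itself is not reimplemented here):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      // Out = [2, 3, 5] for a [2, 3, 4] input and a [4, 5] weight.
      std::vector<int64_t> out_dims = {2, 3, 5};
      int64_t M = 1;
      for (size_t i = 0; i + 1 < out_dims.size(); ++i) M *= out_dims[i];
      assert(M == 6);  // the old code passed in_dims[0] = 2, which would
                       // compute only 2 of the 6 output rows for 3-D input
      return 0;
    }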

paddle/fluid/operators/hash_op.cc

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class HashOp : public framework::OperatorWithKernel {
     std::vector<int64_t> out_dims;
     out_dims.reserve(dims.size() + 1);
     // copy all dims except the last one
-    for (size_t i = 0u; i != dims.size() - 1; ++i) {
+    for (int i = 0u; i != dims.size() - 1; ++i) {
       out_dims.emplace_back(dims[i]);
     }
     int num_hash = ctx->Attrs().Get<int>("num_hash");

paddle/fluid/operators/math/selected_rows_functor.cc

Lines changed: 1 addition & 1 deletion
@@ -244,7 +244,7 @@ typename std::enable_if<
     std::is_same<DeviceContext, platform::CPUDeviceContext>::value>::type
 elementwise_add_to(const DeviceContext& ctx, BlasT<DeviceContext, T>* blas,
                    size_t data_len, const T* in, T* out) {
-  for (int64_t i = 0; i < data_len; i++) {
+  for (size_t i = 0; i < data_len; i++) {
     out[i] += in[i];
   }
 }
