Skip to content

Commit 8a1abe5

Browse files
committed
clean fusion infershape code
1 parent 916f42b commit 8a1abe5

File tree

4 files changed

+66
-99
lines changed

4 files changed

+66
-99
lines changed

paddle/fluid/operators/attention_lstm_op.cc

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/attention_lstm_op.h"
1616
#include <string>
17-
#include "paddle/fluid/framework/shape_runtime_infer.h"
17+
#include "paddle/fluid/operators/fusion_infershape_define.h"
1818
#include "paddle/fluid/operators/math/blas.h"
1919
#include "paddle/fluid/operators/math/cpu_vec.h"
2020
#include "paddle/fluid/operators/math/fc_compute.h"
@@ -24,38 +24,7 @@ namespace paddle {
2424
namespace operators {
2525

2626
void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
27-
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
28-
if (runtime_ctx == nullptr) {
29-
LOG(FATAL) << "Should have runtime infer context";
30-
}
31-
const auto& ins = runtime_ctx->OpBase().Inputs();
32-
const auto& outs = runtime_ctx->OpBase().Outputs();
33-
const auto& scope = runtime_ctx->InferScope();
34-
const auto ins_end = ins.end();
35-
const auto outs_end = outs.end();
36-
auto fair_input = [&](const std::string& name) -> bool {
37-
auto it = ins.find(name);
38-
if (it == ins_end) {
39-
return false;
40-
}
41-
const auto& in = it->second;
42-
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
43-
return false;
44-
}
45-
return scope.FindVar(in[0]) != nullptr;
46-
};
47-
auto fair_output = [&](const std::string& name) -> bool {
48-
auto it = outs.find(name);
49-
if (it == outs_end) {
50-
return false;
51-
}
52-
const auto& out = it->second;
53-
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
54-
return false;
55-
}
56-
return scope.FindVar(out[0]) != nullptr;
57-
};
58-
27+
FUSION_INFERSHAPE_INIT;
5928
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of AttentionLSTM.");
6029
PADDLE_ENFORCE(fair_input("C0"),
6130
"Assert only one Input(C0) of AttentionLSTM.");

paddle/fluid/operators/fusion_gru_op.cc

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ limitations under the License. */
1515
#include "paddle/fluid/operators/fusion_gru_op.h"
1616
#include <cstring> // for memcpy
1717
#include <string>
18-
#include "paddle/fluid/framework/shape_runtime_infer.h"
18+
#include "paddle/fluid/operators/fusion_infershape_define.h"
1919
#include "paddle/fluid/operators/math/blas.h"
2020
#include "paddle/fluid/operators/math/cpu_vec.h"
2121
#include "paddle/fluid/operators/math/fc_compute.h"
@@ -26,38 +26,7 @@ namespace paddle {
2626
namespace operators {
2727

2828
void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
29-
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
30-
if (runtime_ctx == nullptr) {
31-
LOG(FATAL) << "Should have runtime infer context";
32-
}
33-
const auto& ins = runtime_ctx->OpBase().Inputs();
34-
const auto& outs = runtime_ctx->OpBase().Outputs();
35-
const auto& scope = runtime_ctx->InferScope();
36-
const auto ins_end = ins.end();
37-
const auto outs_end = outs.end();
38-
auto fair_input = [&](const std::string& name) -> bool {
39-
auto it = ins.find(name);
40-
if (it == ins_end) {
41-
return false;
42-
}
43-
const auto& in = it->second;
44-
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
45-
return false;
46-
}
47-
return scope.FindVar(in[0]) != nullptr;
48-
};
49-
auto fair_output = [&](const std::string& name) -> bool {
50-
auto it = outs.find(name);
51-
if (it == outs_end) {
52-
return false;
53-
}
54-
const auto& out = it->second;
55-
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
56-
return false;
57-
}
58-
return scope.FindVar(out[0]) != nullptr;
59-
};
60-
29+
FUSION_INFERSHAPE_INIT;
6130
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of GRU.");
6231
PADDLE_ENFORCE(fair_input("WeightX"),
6332
"Assert only one Input(WeightX) of GRU.");
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#ifndef PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_
16+
#define PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_
17+
18+
#include <string>
19+
#include "paddle/fluid/framework/shape_runtime_infer.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
#define FUSION_INFERSHAPE_INIT \
25+
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx); \
26+
if (runtime_ctx == nullptr) { \
27+
LOG(FATAL) << "Should have runtime infer context"; \
28+
} \
29+
const auto& ins = runtime_ctx->OpBase().Inputs(); \
30+
const auto& outs = runtime_ctx->OpBase().Outputs(); \
31+
const auto& scope = runtime_ctx->InferScope(); \
32+
const auto ins_end = ins.end(); \
33+
const auto outs_end = outs.end(); \
34+
auto fair_input = [&](const std::string& name) -> bool { \
35+
auto it = ins.find(name); \
36+
if (it == ins_end) { \
37+
return false; \
38+
} \
39+
const auto& in = it->second; \
40+
if (in.size() != 1 || in[0] == framework::kEmptyVarName) { \
41+
return false; \
42+
} \
43+
return scope.FindVar(in[0]) != nullptr; \
44+
}; \
45+
auto fair_output = [&](const std::string& name) -> bool { \
46+
auto it = outs.find(name); \
47+
if (it == outs_end) { \
48+
return false; \
49+
} \
50+
const auto& out = it->second; \
51+
if (out.size() != 1 || out[0] == framework::kEmptyVarName) { \
52+
return false; \
53+
} \
54+
return scope.FindVar(out[0]) != nullptr; \
55+
}
56+
57+
} // namespace operators
58+
} // namespace paddle
59+
60+
#endif // PADDLE_FLUID_OPERATORS_FUSION_INFERSHAPE_DEFINE_H_

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 2 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/fusion_lstm_op.h"
1616
#include <string>
17-
#include "paddle/fluid/framework/shape_runtime_infer.h"
17+
#include "paddle/fluid/operators/fusion_infershape_define.h"
1818
#include "paddle/fluid/operators/math/blas.h"
1919
#include "paddle/fluid/operators/math/cpu_vec.h"
2020
#include "paddle/fluid/operators/math/fc_compute.h"
@@ -25,38 +25,7 @@ namespace paddle {
2525
namespace operators {
2626

2727
void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
28-
auto* runtime_ctx = dynamic_cast<framework::RuntimeInferShapeContext*>(ctx);
29-
if (runtime_ctx == nullptr) {
30-
LOG(FATAL) << "Should have runtime infer context";
31-
}
32-
const auto& ins = runtime_ctx->OpBase().Inputs();
33-
const auto& outs = runtime_ctx->OpBase().Outputs();
34-
const auto& scope = runtime_ctx->InferScope();
35-
const auto ins_end = ins.end();
36-
const auto outs_end = outs.end();
37-
auto fair_input = [&](const std::string& name) -> bool {
38-
auto it = ins.find(name);
39-
if (it == ins_end) {
40-
return false;
41-
}
42-
const auto& in = it->second;
43-
if (in.size() != 1 || in[0] == framework::kEmptyVarName) {
44-
return false;
45-
}
46-
return scope.FindVar(in[0]) != nullptr;
47-
};
48-
auto fair_output = [&](const std::string& name) -> bool {
49-
auto it = outs.find(name);
50-
if (it == outs_end) {
51-
return false;
52-
}
53-
const auto& out = it->second;
54-
if (out.size() != 1 || out[0] == framework::kEmptyVarName) {
55-
return false;
56-
}
57-
return scope.FindVar(out[0]) != nullptr;
58-
};
59-
28+
FUSION_INFERSHAPE_INIT;
6029
PADDLE_ENFORCE(fair_input("X"), "Assert only one Input(X) of LSTM.");
6130
PADDLE_ENFORCE(fair_input("WeightX"),
6231
"Assert only one Input(WeightX) of LSTM.");

0 commit comments

Comments
 (0)