Skip to content

Commit 0e78cb6

Browse files
committed
Clean OpProtoAndCheckerMaker
Do not use ctor * Reduce line of codes. * We can use virtual function for Maker now. * The implementation does not care what maker holds, it is easier to refactor later.
1 parent 2a22da6 commit 0e78cb6

File tree

196 files changed

+332
-711
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

196 files changed

+332
-711
lines changed

paddle/fluid/framework/data_device_transform_test.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ struct AddFunctor {
3232

3333
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
3434
public:
35-
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
36-
: OpProtoAndCheckerMaker(proto, op_checker) {
35+
void Make() {
3736
AddInput("input", "input1 of test op");
3837
AddOutput("output", "output of test op");
3938
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);

paddle/fluid/framework/details/op_registry.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
9595
void operator()(const char* op_type, OpInfo* info) const {
9696
info->proto_ = new proto::OpProto;
9797
info->checker_ = new OpAttrChecker();
98-
auto maker = T(info->proto_, info->checker_);
98+
T maker;
99+
maker.SetProto(info->proto_);
100+
maker.SetChecker(info->checker_);
101+
maker.Make();
99102
maker.Validate();
100103
info->proto_->set_type(op_type);
101104
PADDLE_ENFORCE(

paddle/fluid/framework/op_proto_maker.h

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,21 @@ namespace framework {
2323
// this class not only make proto but also init attribute checkers.
2424
class OpProtoAndCheckerMaker {
2525
public:
26-
using OpProto = proto::OpProto;
27-
using OpAttrChecker = framework::OpAttrChecker;
28-
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
29-
: proto_(proto), op_checker_(op_checker) {}
26+
virtual void Make() = 0;
3027

3128
virtual ~OpProtoAndCheckerMaker() {
3229
PADDLE_ENFORCE(validated_, "should call Validate after build");
3330
}
3431

32+
void SetProto(proto::OpProto* proto) { proto_ = proto; }
33+
34+
void SetChecker(OpAttrChecker* attr_checker) { op_checker_ = attr_checker; }
35+
3536
void Validate();
3637

3738
protected:
3839
struct VariableBuilder {
39-
OpProto::Var* var_;
40+
proto::OpProto::Var* var_;
4041

4142
VariableBuilder& AsDuplicable() {
4243
var_->set_duplicable(true);
@@ -76,16 +77,9 @@ class OpProtoAndCheckerMaker {
7677
private:
7778
void CheckNoDuplicatedInOutAttrs();
7879

79-
OpProto* proto_;
80+
proto::OpProto* proto_;
8081
OpAttrChecker* op_checker_;
8182
bool validated_{false};
8283
};
83-
84-
class NOPMaker : public OpProtoAndCheckerMaker {
85-
public:
86-
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
87-
: OpProtoAndCheckerMaker(proto, op_checker) {}
88-
};
89-
9084
} // namespace framework
9185
} // namespace paddle

paddle/fluid/framework/op_proto_maker_test.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ limitations under the License. */
1818

1919
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
2020
public:
21-
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
22-
paddle::framework::OpAttrChecker* op_checker)
23-
: OpProtoAndCheckerMaker(proto, op_checker) {
21+
void Make() {
2422
AddAttr<float>("scale", "scale of test op");
2523
AddAttr<float>("scale", "scale of test op");
2624
}
@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
2927
TEST(ProtoMaker, DuplicatedAttr) {
3028
paddle::framework::proto::OpProto op_proto;
3129
paddle::framework::OpAttrChecker op_checker;
32-
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
30+
TestAttrProtoMaker proto_maker;
31+
proto_maker.SetProto(&op_proto);
32+
proto_maker.SetChecker(&op_checker);
33+
proto_maker.Make();
3334
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
3435
}
3536

3637
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
3738
public:
38-
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
39-
paddle::framework::OpAttrChecker* op_checker)
40-
: OpProtoAndCheckerMaker(proto, op_checker) {
39+
void Make() {
4140
AddInput("input", "input of test op");
4241
AddInput("input", "input of test op");
4342
}
@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
4645
TEST(ProtoMaker, DuplicatedInOut) {
4746
paddle::framework::proto::OpProto op_proto;
4847
paddle::framework::OpAttrChecker op_checker;
49-
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
48+
TestAttrProtoMaker proto_maker;
49+
proto_maker.SetProto(&op_proto);
50+
proto_maker.SetChecker(&op_checker);
51+
proto_maker.Make();
5052
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
5153
}

paddle/fluid/framework/op_registry_test.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
3333

3434
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
3535
public:
36-
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
37-
: OpProtoAndCheckerMaker(proto, op_checker) {
36+
void Make() {
3837
AddInput("input", "input of cosine op");
3938
AddOutput("output", "output of cosine op");
4039
AddAttr<float>("scale", "scale of cosine op")
@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
5554

5655
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
5756
public:
58-
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
59-
: OpProtoAndCheckerMaker(proto, op_checker) {
57+
void Make() {
6058
AddInput("input", "input of cosine op").AsDuplicable();
6159
AddOutput("output", "output of cosine op").AsIntermediate();
6260
auto my_checker = [](int i) {
@@ -212,10 +210,7 @@ namespace framework {
212210

213211
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
214212
public:
215-
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
216-
: OpProtoAndCheckerMaker(proto, op_checker) {
217-
AddComment("NoGradOp, same input output. no Grad");
218-
}
213+
void Make() { AddComment("NoGradOp, same input output. no Grad"); }
219214
};
220215

221216
class OpWithKernelTest : public OperatorWithKernel {
@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
275270

276271
static int op_test_value = 0;
277272

278-
using paddle::platform::DeviceContext;
279273
using paddle::platform::CPUDeviceContext;
280274
using paddle::platform::CUDADeviceContext;
275+
using paddle::platform::DeviceContext;
281276

282277
namespace paddle {
283278
namespace framework {

paddle/fluid/framework/operator_test.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
4646

4747
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
4848
public:
49-
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
50-
: OpProtoAndCheckerMaker(proto, op_checker) {
49+
void Make() {
5150
AddInput("input", "input of test op");
5251
AddOutput("output", "output of test op");
5352
AddAttr<float>("scale", "scale of cosine op");
@@ -98,8 +97,7 @@ namespace framework {
9897

9998
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
10099
public:
101-
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
102-
: OpProtoAndCheckerMaker(proto, op_checker) {
100+
void Make() {
103101
AddInput("x", "input of test op");
104102
AddOutput("y", "output of test op");
105103
AddAttr<float>("scale", "scale of cosine op")
@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
137135
class OpKernelTestMultiInputsProtoAndCheckerMaker
138136
: public OpProtoAndCheckerMaker {
139137
public:
140-
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
141-
OpAttrChecker* op_checker)
142-
: OpProtoAndCheckerMaker(proto, op_checker) {
138+
void Make() {
143139
AddInput("xs", "inputs of test op").AsDuplicable();
144140
AddInput("k", "input of test op");
145141
AddOutput("ys", "outputs of test op").AsDuplicable();

paddle/fluid/framework/var_type_inference_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ namespace framework {
2424

2525
class SumOpMaker : public OpProtoAndCheckerMaker {
2626
public:
27-
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
28-
: OpProtoAndCheckerMaker(proto, op_checker) {
27+
void Make() {
2928
AddInput("X", "").AsDuplicable();
3029
AddOutput("Out", "");
3130
AddComment("");

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ function(op_library TARGET)
166166
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
167167
if(${TARGET} STREQUAL "activation")
168168
file(APPEND ${pybind_file} "USE_OP(relu);\n")
169+
elseif(${TARGET} STREQUAL "reduce")
170+
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
169171
else()
170172
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
171173
endif()

paddle/fluid/operators/accuracy_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
6363

6464
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
6565
public:
66-
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
67-
: OpProtoAndCheckerMaker(proto, op_checker) {
66+
void Make() override {
6867
// TODO(typhoonzero): support both inference value and indices.
6968
AddInput("Out", "The network output of topk (inferences)");
7069
AddInput("Indices", "The the network output of topk (indices)");

paddle/fluid/operators/activation_op.cc

Lines changed: 24 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,18 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace operators {
2121

22-
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
23-
class OP_NAME##OpMaker \
24-
: public ::paddle::framework::OpProtoAndCheckerMaker { \
25-
public: \
26-
OP_NAME##OpMaker(OpProto *proto, OpAttrChecker *op_checker) \
27-
: ::paddle::framework::OpProtoAndCheckerMaker(proto, op_checker) { \
28-
AddInput("X", "Input of " #OP_NAME "operator"); \
29-
AddOutput("Out", "Output of" #OP_NAME "operator"); \
30-
AddAttr<bool>("use_mkldnn", \
31-
"(bool, default false) Only used in mkldnn kernel") \
32-
.SetDefault(false); \
33-
AddComment(#OP_COMMENT); \
34-
} \
22+
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
23+
class OP_NAME##OpMaker \
24+
: public ::paddle::framework::OpProtoAndCheckerMaker { \
25+
public: \
26+
void Make() override { \
27+
AddInput("X", "Input of " #OP_NAME "operator"); \
28+
AddOutput("Out", "Output of" #OP_NAME "operator"); \
29+
AddAttr<bool>("use_mkldnn", \
30+
"(bool, default false) Only used in mkldnn kernel") \
31+
.SetDefault(false); \
32+
AddComment(#OP_COMMENT); \
33+
} \
3534
}
3635

3736
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
@@ -204,8 +203,7 @@ Softsign Activation Operator.
204203

205204
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
206205
public:
207-
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
208-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
206+
void Make() override {
209207
AddInput("X", "Input of LeakyRelu operator");
210208
AddOutput("Out", "Output of LeakyRelu operator");
211209
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
@@ -220,8 +218,7 @@ LeakyRelu Activation Operator.
220218

221219
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
222220
public:
223-
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
224-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
221+
void Make() override {
225222
AddInput("X", "Input of Softshrink operator");
226223
AddOutput("Out", "Output of Softshrink operator");
227224
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
@@ -242,8 +239,7 @@ out = \begin{cases}
242239

243240
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
244241
public:
245-
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
246-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
242+
void Make() override {
247243
AddInput("X", "Input of HardShrink operator");
248244
AddOutput("Out", "Output of HardShrink operator");
249245
AddAttr<float>("threshold", "The value of threshold for HardShrink")
@@ -265,8 +261,7 @@ out = \begin{cases}
265261

266262
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
267263
public:
268-
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
269-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
264+
void Make() override {
270265
AddInput("X", "Input of BRelu operator");
271266
AddOutput("Out", "Output of BRelu operator");
272267
AddAttr<float>("t_min", "The min marginal value of BRelu")
@@ -284,8 +279,7 @@ BRelu Activation Operator.
284279

285280
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
286281
public:
287-
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
288-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
282+
void Make() override {
289283
AddInput("X", "Input of SoftRelu operator");
290284
AddOutput("Out", "Output of SoftRelu operator");
291285
AddAttr<float>("threshold", "The threshold value of SoftRelu")
@@ -301,8 +295,7 @@ SoftRelu Activation Operator.
301295

302296
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
303297
public:
304-
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
305-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
298+
void Make() override {
306299
AddInput("X", "Input of ELU operator");
307300
AddOutput("Out", "Output of ELU operator");
308301
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
@@ -320,8 +313,7 @@ Applies the following element-wise computation on the input according to
320313

321314
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
322315
public:
323-
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
324-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
316+
void Make() override {
325317
AddInput("X", "Input of Relu6 operator");
326318
AddOutput("Out", "Output of Relu6 operator");
327319
AddAttr<float>("threshold", "The threshold value of Relu6")
@@ -337,8 +329,7 @@ Relu6 Activation Operator.
337329

338330
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
339331
public:
340-
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
341-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
332+
void Make() override {
342333
AddInput("X", "Input of Pow operator");
343334
AddOutput("Out", "Output of Pow operator");
344335
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
@@ -353,8 +344,7 @@ Pow Activation Operator.
353344

354345
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
355346
public:
356-
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
357-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
347+
void Make() override {
358348
AddInput("X", "Input of STanh operator");
359349
AddOutput("Out", "Output of STanh operator");
360350
AddAttr<float>("scale_a", "The scale parameter of a for the input")
@@ -372,8 +362,7 @@ STanh Activation Operator.
372362

373363
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
374364
public:
375-
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
376-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
365+
void Make() override {
377366
AddInput("X", "Input of ThresholdedRelu operator");
378367
AddOutput("Out", "Output of ThresholdedRelu operator");
379368
AddAttr<float>("threshold", "The threshold location of activation")
@@ -394,8 +383,7 @@ out = \begin{cases}
394383

395384
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
396385
public:
397-
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
398-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
386+
void Make() override {
399387
AddInput("X", "Input of HardSigmoid operator");
400388
AddOutput("Out", "Output of HardSigmoid operator");
401389
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
@@ -420,8 +408,7 @@ It is recommended to use the defaults for this activation.
420408

421409
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
422410
public:
423-
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
424-
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
411+
void Make() override {
425412
AddInput("X", "Input of Swish operator");
426413
AddOutput("Out", "Output of Swish operator");
427414
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);

0 commit comments

Comments
 (0)