Skip to content

Commit 046405e

Browse files
authored
Merge pull request #10486 from reyoung/feature/clean_op_maker
Clean OpProtoAndCheckerMaker
2 parents ba57348 + 613d3ef commit 046405e

File tree

198 files changed

+354
-729
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

198 files changed

+354
-729
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ cc_library(data_transform SRCS data_transform.cc DEPS math_function tensor
5757
cc_library(attribute SRCS attribute.cc DEPS framework_proto boost)
5858
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
5959
device_context)
60-
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
60+
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute glog)
6161
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
6262
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
6363
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute device_context)

paddle/fluid/framework/data_device_transform_test.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@ struct AddFunctor {
3232

3333
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
3434
public:
35-
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
36-
: OpProtoAndCheckerMaker(proto, op_checker) {
35+
void Make() {
3736
AddInput("input", "input1 of test op");
3837
AddOutput("output", "output of test op");
3938
AddAttr<bool>("use_gpu", "force to use gpu kernel").SetDefault(false);

paddle/fluid/framework/details/op_registry.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,10 @@ struct OpInfoFiller<T, kOpProtoAndCheckerMaker> {
9595
void operator()(const char* op_type, OpInfo* info) const {
9696
info->proto_ = new proto::OpProto;
9797
info->checker_ = new OpAttrChecker();
98-
auto maker = T(info->proto_, info->checker_);
98+
T maker;
99+
maker.SetProto(info->proto_);
100+
maker.SetChecker(info->checker_);
101+
maker.Make();
99102
maker.Validate();
100103
info->proto_->set_type(op_type);
101104
PADDLE_ENFORCE(

paddle/fluid/framework/op_proto_maker.h

Lines changed: 20 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,78 +14,72 @@ limitations under the License. */
1414
#pragma once
1515

1616
#include <string>
17+
#include "glog/logging.h"
1718
#include "paddle/fluid/framework/attribute.h"
1819
#include "paddle/fluid/framework/framework.pb.h"
19-
2020
namespace paddle {
2121
namespace framework {
2222

2323
// this class not only make proto but also init attribute checkers.
2424
class OpProtoAndCheckerMaker {
2525
public:
26-
using OpProto = proto::OpProto;
27-
using OpAttrChecker = framework::OpAttrChecker;
28-
OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
29-
: proto_(proto), op_checker_(op_checker) {}
26+
virtual void Make() = 0;
3027

3128
virtual ~OpProtoAndCheckerMaker() {
32-
PADDLE_ENFORCE(validated_, "should call Validate after build");
29+
CHECK(validated_) << "should call Validate after build";
3330
}
3431

32+
void SetProto(proto::OpProto *proto) { proto_ = proto; }
33+
34+
void SetChecker(OpAttrChecker *attr_checker) { op_checker_ = attr_checker; }
35+
3536
void Validate();
3637

3738
protected:
3839
struct VariableBuilder {
39-
OpProto::Var* var_;
40+
proto::OpProto::Var *var_;
4041

41-
VariableBuilder& AsDuplicable() {
42+
VariableBuilder &AsDuplicable() {
4243
var_->set_duplicable(true);
4344
return *this;
4445
}
4546

46-
VariableBuilder& AsIntermediate() {
47+
VariableBuilder &AsIntermediate() {
4748
var_->set_intermediate(true);
4849
return *this;
4950
}
5051

51-
VariableBuilder& AsDispensable() {
52+
VariableBuilder &AsDispensable() {
5253
var_->set_dispensable(true);
5354
return *this;
5455
}
5556
};
5657

57-
VariableBuilder AddInput(const std::string& name, const std::string& comment);
58+
VariableBuilder AddInput(const std::string &name, const std::string &comment);
5859

59-
VariableBuilder AddOutput(const std::string& name,
60-
const std::string& comment);
60+
VariableBuilder AddOutput(const std::string &name,
61+
const std::string &comment);
6162

6263
template <typename T>
63-
TypedAttrChecker<T>& AddAttr(const std::string& name,
64-
const std::string& comment,
64+
TypedAttrChecker<T> &AddAttr(const std::string &name,
65+
const std::string &comment,
6566
bool generated = false) {
66-
auto* attr = proto_->add_attrs();
67+
auto *attr = proto_->add_attrs();
6768
attr->set_name(name);
6869
attr->set_comment(comment);
6970
attr->set_generated(generated);
7071
attr->set_type(AttrTypeID<T>());
7172
return op_checker_->AddAttrChecker<T>(name);
7273
}
7374

74-
void AddComment(const std::string& comment) { proto_->set_comment(comment); }
75+
void AddComment(const std::string &comment) { proto_->set_comment(comment); }
7576

7677
private:
7778
void CheckNoDuplicatedInOutAttrs();
7879

79-
OpProto* proto_;
80-
OpAttrChecker* op_checker_;
80+
proto::OpProto *proto_;
81+
OpAttrChecker *op_checker_;
8182
bool validated_{false};
8283
};
83-
84-
class NOPMaker : public OpProtoAndCheckerMaker {
85-
public:
86-
NOPMaker(OpProto* proto, framework::OpAttrChecker* op_checker)
87-
: OpProtoAndCheckerMaker(proto, op_checker) {}
88-
};
89-
9084
} // namespace framework
9185
} // namespace paddle

paddle/fluid/framework/op_proto_maker_test.cc

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,7 @@ limitations under the License. */
1818

1919
class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
2020
public:
21-
TestAttrProtoMaker(paddle::framework::proto::OpProto* proto,
22-
paddle::framework::OpAttrChecker* op_checker)
23-
: OpProtoAndCheckerMaker(proto, op_checker) {
21+
void Make() {
2422
AddAttr<float>("scale", "scale of test op");
2523
AddAttr<float>("scale", "scale of test op");
2624
}
@@ -29,15 +27,16 @@ class TestAttrProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
2927
TEST(ProtoMaker, DuplicatedAttr) {
3028
paddle::framework::proto::OpProto op_proto;
3129
paddle::framework::OpAttrChecker op_checker;
32-
auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
30+
TestAttrProtoMaker proto_maker;
31+
proto_maker.SetProto(&op_proto);
32+
proto_maker.SetChecker(&op_checker);
33+
proto_maker.Make();
3334
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
3435
}
3536

3637
class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
3738
public:
38-
TestInOutProtoMaker(paddle::framework::proto::OpProto* proto,
39-
paddle::framework::OpAttrChecker* op_checker)
40-
: OpProtoAndCheckerMaker(proto, op_checker) {
39+
void Make() {
4140
AddInput("input", "input of test op");
4241
AddInput("input", "input of test op");
4342
}
@@ -46,6 +45,9 @@ class TestInOutProtoMaker : public paddle::framework::OpProtoAndCheckerMaker {
4645
TEST(ProtoMaker, DuplicatedInOut) {
4746
paddle::framework::proto::OpProto op_proto;
4847
paddle::framework::OpAttrChecker op_checker;
49-
auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
48+
TestAttrProtoMaker proto_maker;
49+
proto_maker.SetProto(&op_proto);
50+
proto_maker.SetChecker(&op_checker);
51+
proto_maker.Make();
5052
ASSERT_THROW(proto_maker.Validate(), paddle::platform::EnforceNotMet);
5153
}

paddle/fluid/framework/op_registry_test.cc

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,7 @@ class CosineOp : public OperatorBase {
3333

3434
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
3535
public:
36-
CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
37-
: OpProtoAndCheckerMaker(proto, op_checker) {
36+
void Make() {
3837
AddInput("input", "input of cosine op");
3938
AddOutput("output", "output of cosine op");
4039
AddAttr<float>("scale", "scale of cosine op")
@@ -55,8 +54,7 @@ class MyTestOp : public OperatorBase {
5554

5655
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
5756
public:
58-
MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
59-
: OpProtoAndCheckerMaker(proto, op_checker) {
57+
void Make() {
6058
AddInput("input", "input of cosine op").AsDuplicable();
6159
AddOutput("output", "output of cosine op").AsIntermediate();
6260
auto my_checker = [](int i) {
@@ -212,10 +210,7 @@ namespace framework {
212210

213211
class OpKernelTestMaker : public OpProtoAndCheckerMaker {
214212
public:
215-
OpKernelTestMaker(OpProto* proto, OpAttrChecker* op_checker)
216-
: OpProtoAndCheckerMaker(proto, op_checker) {
217-
AddComment("NoGradOp, same input output. no Grad");
218-
}
213+
void Make() { AddComment("NoGradOp, same input output. no Grad"); }
219214
};
220215

221216
class OpWithKernelTest : public OperatorWithKernel {
@@ -275,9 +270,9 @@ TEST(OperatorRegistrar, CUDA) {
275270

276271
static int op_test_value = 0;
277272

278-
using paddle::platform::DeviceContext;
279273
using paddle::platform::CPUDeviceContext;
280274
using paddle::platform::CUDADeviceContext;
275+
using paddle::platform::DeviceContext;
281276

282277
namespace paddle {
283278
namespace framework {

paddle/fluid/framework/operator_test.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,7 @@ class OpWithoutKernelTest : public OperatorBase {
4646

4747
class OpWithoutKernelCheckerMaker : public OpProtoAndCheckerMaker {
4848
public:
49-
OpWithoutKernelCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
50-
: OpProtoAndCheckerMaker(proto, op_checker) {
49+
void Make() {
5150
AddInput("input", "input of test op");
5251
AddOutput("output", "output of test op");
5352
AddAttr<float>("scale", "scale of cosine op");
@@ -98,8 +97,7 @@ namespace framework {
9897

9998
class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
10099
public:
101-
OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
102-
: OpProtoAndCheckerMaker(proto, op_checker) {
100+
void Make() {
103101
AddInput("x", "input of test op");
104102
AddOutput("y", "output of test op");
105103
AddAttr<float>("scale", "scale of cosine op")
@@ -137,9 +135,7 @@ class CPUKernelTest : public OpKernel<float> {
137135
class OpKernelTestMultiInputsProtoAndCheckerMaker
138136
: public OpProtoAndCheckerMaker {
139137
public:
140-
OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
141-
OpAttrChecker* op_checker)
142-
: OpProtoAndCheckerMaker(proto, op_checker) {
138+
void Make() {
143139
AddInput("xs", "inputs of test op").AsDuplicable();
144140
AddInput("k", "input of test op");
145141
AddOutput("ys", "outputs of test op").AsDuplicable();

paddle/fluid/framework/var_type_inference_test.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ namespace framework {
2424

2525
class SumOpMaker : public OpProtoAndCheckerMaker {
2626
public:
27-
SumOpMaker(OpProto *proto, OpAttrChecker *op_checker)
28-
: OpProtoAndCheckerMaker(proto, op_checker) {
27+
void Make() {
2928
AddInput("X", "").AsDuplicable();
3029
AddOutput("Out", "");
3130
AddComment("");

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ function(op_library TARGET)
166166
# NOTE(*): activation use macro to regist the kernels, set use_op manually.
167167
if(${TARGET} STREQUAL "activation")
168168
file(APPEND ${pybind_file} "USE_OP(relu);\n")
169+
elseif(${TARGET} STREQUAL "reduce")
170+
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
169171
else()
170172
file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
171173
endif()

paddle/fluid/operators/accuracy_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
6363

6464
class AccuracyOpMaker : public framework::OpProtoAndCheckerMaker {
6565
public:
66-
AccuracyOpMaker(OpProto *proto, OpAttrChecker *op_checker)
67-
: OpProtoAndCheckerMaker(proto, op_checker) {
66+
void Make() override {
6867
// TODO(typhoonzero): support both inference value and indices.
6968
AddInput("Out", "The network output of topk (inferences)");
7069
AddInput("Indices", "The the network output of topk (indices)");

0 commit comments

Comments
 (0)