12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " paddle/operators/increment_op .h"
15
+ #include " paddle/framework/op_registry .h"
16
16
17
17
namespace paddle {
18
18
namespace operators {
19
19
20
- class IncrementOp : public framework ::OperatorWithKernel {
20
+ class IncrementInferShape : public framework ::InferShapeBase {
21
21
public:
22
- using framework::OperatorWithKernel::OperatorWithKernel;
23
-
24
- void InferShape (framework::InferShapeContext *ctx) const override {
22
+ void operator ()(framework::InferShapeContext *ctx) const override {
25
23
PADDLE_ENFORCE (ctx->HasInput (" X" ),
26
24
" Input(X) of IncrementOp should not be null." );
27
25
PADDLE_ENFORCE (ctx->HasOutput (" Out" ),
28
26
" Output(Out) of IncrementOp should not be null." );
27
+ PADDLE_ENFORCE_EQ (1 , framework::product (ctx->GetInputDim (" X" )));
29
28
ctx->SetOutputDim (" Out" , ctx->GetInputDim (" X" ));
30
- ctx->ShareLoD (" X" , /* ->*/ " Out" );
29
+ }
30
+ };
31
+
32
+ struct IncrementFunctor {
33
+ IncrementFunctor (const framework::LoDTensor &x, framework::LoDTensor *out,
34
+ float value)
35
+ : x_(x), out_(out), value_(value) {}
36
+
37
+ template <typename T>
38
+ void operator ()() const {
39
+ *out_->data <T>() = *x_.data <T>() + static_cast <T>(value_);
40
+ }
41
+
42
+ const framework::LoDTensor &x_;
43
+ framework::LoDTensor *out_;
44
+ float value_;
45
+ };
46
+
47
+ class IncrementOp : public framework ::OperatorBase {
48
+ public:
49
+ IncrementOp (const std::string &type, const framework::VariableNameMap &inputs,
50
+ const framework::VariableNameMap &outputs,
51
+ const framework::AttributeMap &attrs)
52
+ : OperatorBase(type, inputs, outputs, attrs) {}
53
+
54
+ void Run (const framework::Scope &scope,
55
+ const platform::DeviceContext &dev_ctx) const override {
56
+ auto &x = scope.FindVar (Input (" X" ))->Get <framework::LoDTensor>();
57
+ auto &out =
58
+ *scope.FindVar (Output (" Out" ))->GetMutable <framework::LoDTensor>();
59
+
60
+ PADDLE_ENFORCE (platform::is_cpu_place (x.place ()));
61
+ out.Resize (x.dims ());
62
+ out.mutable_data (x.place (), x.type ());
63
+ float value = Attr<float >(" step" );
64
+ framework::VisitDataType (framework::ToDataType (out.type ()),
65
+ IncrementFunctor (x, &out, value));
31
66
}
32
67
};
33
68
@@ -59,10 +94,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
59
94
60
95
std::unique_ptr<framework::OpDescBind> Apply () const override {
61
96
auto *grad_op = new framework::OpDescBind ();
62
- grad_op->SetType (" scale " );
63
- grad_op->SetInput (" X" , OutputGrad (" Out" ));
64
- grad_op->SetOutput (" Out" , InputGrad (" X" ));
65
- grad_op->SetAttr (" scale " , 1 . 0f );
97
+ grad_op->SetType (" increment " );
98
+ grad_op->SetInput (" X" , Output (" Out" ));
99
+ grad_op->SetOutput (" Out" , Input (" X" ));
100
+ grad_op->SetAttr (" step " , -boost::get< float >( GetAttr ( " step " )) );
66
101
return std::unique_ptr<framework::OpDescBind>(grad_op);
67
102
}
68
103
};
@@ -71,11 +106,5 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
71
106
} // namespace paddle
72
107
73
108
namespace ops = paddle::operators;
74
-
75
- REGISTER_OPERATOR (increment, ops::IncrementOp, ops::IncrementOpMaker,
76
- ops::IncrementGradOpMaker);
77
- REGISTER_OP_CPU_KERNEL (
78
- increment, ops::IncrementKernel<paddle::platform::CPUPlace, float >,
79
- ops::IncrementKernel<paddle::platform::CPUPlace, double >,
80
- ops::IncrementKernel<paddle::platform::CPUPlace, int >,
81
- ops::IncrementKernel<paddle::platform::CPUPlace, int64_t >);
109
+ REGISTER_OPERATOR (increment, ops::IncrementOp, ops::IncrementInferShape,
110
+ ops::IncrementOpMaker, ops::IncrementGradOpMaker);
0 commit comments