|
1 |
| -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. |
2 |
| -
|
3 |
| -Licensed under the Apache License, Version 2.0 (the "License"); |
4 |
| -you may not use this file except in compliance with the License. |
5 |
| -You may obtain a copy of the License at |
6 |
| -
|
7 |
| - http://www.apache.org/licenses/LICENSE-2.0 |
8 |
| -
|
9 |
| -Unless required by applicable law or agreed to in writing, software |
10 |
| -distributed under the License is distributed on an "AS IS" BASIS, |
11 |
| -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 |
| -See the License for the specific language governing permissions and |
13 |
| -limitations under the License. */ |
14 |
| - |
15 |
| -#include "paddle/fluid/framework/op_registry.h" |
| 1 | +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/fluid/operators/increment_op.h" |
16 | 16 |
|
17 | 17 | namespace paddle {
|
18 | 18 | namespace operators {
|
19 | 19 |
|
20 |
| -class IncrementInferShape : public framework::InferShapeBase { |
| 20 | +class IncrementOp : public framework::OperatorWithKernel { |
21 | 21 | public:
|
22 |
| - void operator()(framework::InferShapeContext *ctx) const override { |
| 22 | + IncrementOp(const std::string &type, const framework::VariableNameMap &inputs, |
| 23 | + const framework::VariableNameMap &outputs, |
| 24 | + const framework::AttributeMap &attrs) |
| 25 | + : OperatorWithKernel(type, inputs, outputs, attrs) {} |
| 26 | + |
| 27 | + void InferShape(framework::InferShapeContext *ctx) const override { |
23 | 28 | PADDLE_ENFORCE(ctx->HasInput("X"),
|
24 | 29 | "Input(X) of IncrementOp should not be null.");
|
25 | 30 | PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
26 | 31 | "Output(Out) of IncrementOp should not be null.");
|
27 | 32 | PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
|
28 | 33 | ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
| 34 | + ctx->ShareLoD("X", "Out"); |
29 | 35 | }
|
30 |
| -}; |
31 |
| - |
32 |
| -struct IncrementFunctor { |
33 |
| - IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out, |
34 |
| - float value) |
35 |
| - : x_(x), out_(out), value_(value) {} |
36 |
| - |
37 |
| - template <typename T> |
38 |
| - void operator()() const { |
39 |
| - *out_->data<T>() = *x_.data<T>() + static_cast<T>(value_); |
40 |
| - } |
41 |
| - |
42 |
| - const framework::LoDTensor &x_; |
43 |
| - framework::LoDTensor *out_; |
44 |
| - float value_; |
45 |
| -}; |
46 |
| - |
47 |
| -class IncrementOp : public framework::OperatorBase { |
48 |
| - public: |
49 |
| - IncrementOp(const std::string &type, const framework::VariableNameMap &inputs, |
50 |
| - const framework::VariableNameMap &outputs, |
51 |
| - const framework::AttributeMap &attrs) |
52 |
| - : OperatorBase(type, inputs, outputs, attrs) {} |
53 |
| - |
54 |
| - private: |
55 |
| - void RunImpl(const framework::Scope &scope, |
56 |
| - const platform::Place &place) const override { |
57 |
| - auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>(); |
58 |
| - auto &out = |
59 |
| - *scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>(); |
60 | 36 |
|
61 |
| - PADDLE_ENFORCE(platform::is_cpu_place(x.place())); |
62 |
| - out.Resize(x.dims()); |
63 |
| - out.mutable_data(x.place(), x.type()); |
64 |
| - float value = Attr<float>("step"); |
65 |
| - VLOG(10) << Output("Out") << " increase " << Input("X") << " with " |
66 |
| - << value; |
67 |
| - framework::VisitDataType(framework::ToDataType(out.type()), |
68 |
| - IncrementFunctor(x, &out, value)); |
| 37 | + protected: |
| 38 | + framework::OpKernelType GetExpectedKernelType( |
| 39 | + const framework::ExecutionContext &ctx) const override { |
| 40 | + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); |
| 41 | + // IncrementOp kernel's device type is decided by input tensor place |
| 42 | + kt.place_ = ctx.Input<framework::LoDTensor>("X")->place(); |
| 43 | + return kt; |
69 | 44 | }
|
70 | 45 | };
|
71 | 46 |
|
@@ -108,5 +83,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
|
108 | 83 | } // namespace paddle
|
109 | 84 |
|
110 | 85 | namespace ops = paddle::operators;
|
111 |
| -REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape, |
112 |
| - ops::IncrementOpMaker, ops::IncrementGradOpMaker); |
| 86 | +REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker, |
| 87 | + ops::IncrementGradOpMaker); |
| 88 | +REGISTER_OP_CPU_KERNEL( |
| 89 | + increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>, |
| 90 | + ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>, |
| 91 | + ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>, |
| 92 | + ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>) |
0 commit comments