Skip to content

Commit 9565876

Browse files
authored
Merge pull request #9428 from JiayiFeng/kernel_of_increment_op
kernels of IncrementOp
2 parents b7b0342 + 1e4f442 commit 9565876

File tree

6 files changed

+133
-64
lines changed

6 files changed

+133
-64
lines changed

paddle/fluid/operators/compare_op.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
2929
AddInput("Y", string::Sprintf(
3030
"(LoDTensor) the right hand operand of %s operator",
3131
comment.type));
32+
AddAttr<bool>("force_cpu",
33+
"(bool, default false) Force fill output variable to cpu "
34+
"memory. Otherwise, fill output variable to the running "
35+
"device")
36+
.SetDefault(false);
3237
AddOutput("Out", string::Sprintf(
3338
"(LoDTensor) n-dim bool tensor. Each element is %s",
3439
comment.equation));
@@ -75,7 +80,9 @@ class CompareOp : public framework::OperatorWithKernel {
7580
const framework::ExecutionContext &ctx) const override {
7681
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
7782
// CompareOp kernel's device type is CPU when force_cpu is set; otherwise it is decided by the place of input X
78-
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
83+
bool force_cpu = ctx.Attr<bool>("force_cpu");
84+
kt.place_ = force_cpu ? platform::CPUPlace()
85+
: ctx.Input<framework::LoDTensor>("X")->place();
7986
return kt;
8087
}
8188
};

paddle/fluid/operators/conditional_block_op.cc

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,18 @@ class ConditionalOp : public framework::OperatorBase {
5454
"numel should be 1, actual numel is %d",
5555
ips[0]->numel());
5656
}
57-
return ips[0]->data<bool>()[0];
57+
bool res = false;
58+
if (platform::is_gpu_place(ips[0]->place())) {
59+
#ifdef PADDLE_WITH_CUDA
60+
framework::LoDTensor cpu_tensor;
61+
framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor);
62+
platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait();
63+
res = cpu_tensor.data<bool>()[0];
64+
#endif
65+
} else {
66+
res = ips[0]->data<bool>()[0];
67+
}
68+
return res;
5869
}
5970
};
6071

Lines changed: 37 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,71 +1,46 @@
1-
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2-
3-
Licensed under the Apache License, Version 2.0 (the "License");
4-
you may not use this file except in compliance with the License.
5-
You may obtain a copy of the License at
6-
7-
http://www.apache.org/licenses/LICENSE-2.0
8-
9-
Unless required by applicable law or agreed to in writing, software
10-
distributed under the License is distributed on an "AS IS" BASIS,
11-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
See the License for the specific language governing permissions and
13-
limitations under the License. */
14-
15-
#include "paddle/fluid/framework/op_registry.h"
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/increment_op.h"
1616

1717
namespace paddle {
1818
namespace operators {
1919

20-
class IncrementInferShape : public framework::InferShapeBase {
20+
class IncrementOp : public framework::OperatorWithKernel {
2121
public:
22-
void operator()(framework::InferShapeContext *ctx) const override {
22+
IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
23+
const framework::VariableNameMap &outputs,
24+
const framework::AttributeMap &attrs)
25+
: OperatorWithKernel(type, inputs, outputs, attrs) {}
26+
27+
void InferShape(framework::InferShapeContext *ctx) const override {
2328
PADDLE_ENFORCE(ctx->HasInput("X"),
2429
"Input(X) of IncrementOp should not be null.");
2530
PADDLE_ENFORCE(ctx->HasOutput("Out"),
2631
"Output(Out) of IncrementOp should not be null.");
2732
PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
2833
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
34+
ctx->ShareLoD("X", "Out");
2935
}
30-
};
31-
32-
struct IncrementFunctor {
33-
IncrementFunctor(const framework::LoDTensor &x, framework::LoDTensor *out,
34-
float value)
35-
: x_(x), out_(out), value_(value) {}
36-
37-
template <typename T>
38-
void operator()() const {
39-
*out_->data<T>() = *x_.data<T>() + static_cast<T>(value_);
40-
}
41-
42-
const framework::LoDTensor &x_;
43-
framework::LoDTensor *out_;
44-
float value_;
45-
};
46-
47-
class IncrementOp : public framework::OperatorBase {
48-
public:
49-
IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
50-
const framework::VariableNameMap &outputs,
51-
const framework::AttributeMap &attrs)
52-
: OperatorBase(type, inputs, outputs, attrs) {}
53-
54-
private:
55-
void RunImpl(const framework::Scope &scope,
56-
const platform::Place &place) const override {
57-
auto &x = scope.FindVar(Input("X"))->Get<framework::LoDTensor>();
58-
auto &out =
59-
*scope.FindVar(Output("Out"))->GetMutable<framework::LoDTensor>();
6036

61-
PADDLE_ENFORCE(platform::is_cpu_place(x.place()));
62-
out.Resize(x.dims());
63-
out.mutable_data(x.place(), x.type());
64-
float value = Attr<float>("step");
65-
VLOG(10) << Output("Out") << " increase " << Input("X") << " with "
66-
<< value;
67-
framework::VisitDataType(framework::ToDataType(out.type()),
68-
IncrementFunctor(x, &out, value));
37+
protected:
38+
framework::OpKernelType GetExpectedKernelType(
39+
const framework::ExecutionContext &ctx) const override {
40+
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
41+
// IncrementOp kernel's device type is decided by input tensor place
42+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
43+
return kt;
6944
}
7045
};
7146

@@ -108,5 +83,10 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
10883
} // namespace paddle
10984

11085
namespace ops = paddle::operators;
111-
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementInferShape,
112-
ops::IncrementOpMaker, ops::IncrementGradOpMaker);
86+
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
87+
ops::IncrementGradOpMaker);
88+
// Register the CPU kernels of the `increment` operator: IncrementKernel is
// instantiated with CPUDeviceContext for float, double, int and int64_t.
REGISTER_OP_CPU_KERNEL(
89+
increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>,
90+
ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>,
91+
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>,
92+
ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/increment_op.h"
16+
17+
namespace ops = paddle::operators;
18+
// Register the CUDA kernels of the `increment` operator: IncrementKernel is
// instantiated with CUDADeviceContext for float, double, int and int64_t.
REGISTER_OP_CUDA_KERNEL(
19+
increment, ops::IncrementKernel<paddle::platform::CUDADeviceContext, float>,
20+
ops::IncrementKernel<paddle::platform::CUDADeviceContext, double>,
21+
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int>,
22+
ops::IncrementKernel<paddle::platform::CUDADeviceContext, int64_t>)

paddle/fluid/operators/increment_op.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/eigen.h"
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
// Kernel of the `increment` operator: computes Out = X + step.
//
// Template parameters:
//   DeviceContext - device context type whose Eigen device evaluates the
//                   expression.
//   T             - element type of X and Out.
template <typename DeviceContext, typename T>
23+
class IncrementKernel : public framework::OpKernel<T> {
24+
public:
// Reads input "X" and attribute "step" from `context` and writes the sum to
// output "Out". Both tensors are viewed through framework::EigenScalar<T>.
25+
void Compute(const framework::ExecutionContext& context) const override {
26+
auto* x_tensor = context.Input<framework::Tensor>("X");
27+
auto* out_tensor = context.Output<framework::Tensor>("Out");
// "step" is always read as float and cast to T below, so for integral T any
// fractional part of the step is discarded by the cast.
28+
float step = context.Attr<float>("step");
29+
30+
// Request Out's buffer with element type T at the context's place before
// writing to it.
out_tensor->mutable_data<T>(context.GetPlace());
31+
// Eigen device obtained from the DeviceContext; the assignment below is
// dispatched through it via .device(dev).
auto& dev =
32+
*context.template device_context<DeviceContext>().eigen_device();
33+
framework::EigenScalar<T>::From(*out_tensor).device(dev) =
34+
framework::EigenScalar<T>::From(*x_tensor) + static_cast<T>(step);
35+
}
36+
};
37+
38+
} // namespace operators
39+
} // namespace paddle

python/paddle/fluid/layers/control_flow.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .. import core
1919
from ..framework import Program, Variable, Operator
2020
from ..layer_helper import LayerHelper, unique_name
21+
from ..initializer import force_init_on_cpu
2122
from ops import logical_and, logical_not, logical_or
2223

2324
__all__ = [
@@ -949,7 +950,7 @@ def create_array(dtype):
949950
dtype=dtype)
950951

951952

952-
def less_than(x, y, cond=None, **ignored):
953+
def less_than(x, y, force_cpu=True, cond=None, **ignored):
953954
"""
954955
**Less than**
955956
@@ -958,6 +959,7 @@ def less_than(x, y, cond=None, **ignored):
958959
Args:
959960
x(Variable): First operand of *less_than*
960961
y(Variable): Second operand of *less_than*
962+
force_cpu(Bool|True): The output data will be on CPU if set true.
961963
cond(Variable|None): Optional output variable to store the result of *less_than*
962964
963965
Returns:
@@ -974,8 +976,11 @@ def less_than(x, y, cond=None, **ignored):
974976
cond.stop_gradient = True
975977

976978
helper.append_op(
977-
type='less_than', inputs={'X': [x],
978-
'Y': [y]}, outputs={'Out': [cond]})
979+
type='less_than',
980+
inputs={'X': [x],
981+
'Y': [y]},
982+
outputs={'Out': [cond]},
983+
attrs={'force_cpu': force_cpu or force_init_on_cpu()})
979984
return cond
980985

981986

@@ -1396,7 +1401,8 @@ def step_input(self, x):
13961401
type='less_than',
13971402
inputs={'X': self.step_idx,
13981403
'Y': self.max_seq_len},
1399-
outputs={'Out': self.cond})
1404+
outputs={'Out': self.cond},
1405+
attrs={'force_cpu': True})
14001406

14011407
input_array = parent_block.create_var(
14021408
name=unique_name.generate('dynamic_rnn_input_array'),
@@ -1445,7 +1451,11 @@ def block(self):
14451451
for new_mem, mem_array in self.mem_link:
14461452
array_write(x=new_mem, i=self.step_idx, array=mem_array)
14471453

1448-
less_than(x=self.step_idx, y=self.max_seq_len, cond=self.cond)
1454+
less_than(
1455+
x=self.step_idx,
1456+
y=self.max_seq_len,
1457+
force_cpu=True,
1458+
cond=self.cond)
14491459

14501460
self.status = DynamicRNN.AFTER_RNN
14511461
for each_array in self.output_array:

0 commit comments

Comments
 (0)