Skip to content

Commit 9ff6715

Browse files
author
Qingsheng Li
authored
Enhanced is_empty_op for our seq2seq model (#10704)
* Added kernel to is_empty_op * Added python API * Updated code as required * Updated unittests
1 parent 5828101 commit 9ff6715

File tree

4 files changed

+117
-57
lines changed

4 files changed

+117
-57
lines changed

paddle/fluid/operators/is_empty_op.cc

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,45 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15+
#include "paddle/fluid/operators/is_empty_op.h"
1516
#include "paddle/fluid/framework/op_registry.h"
1617
#include "paddle/fluid/framework/operator.h"
1718

1819
namespace paddle {
1920
namespace operators {
2021

21-
constexpr char kInput[] = "X";
22-
constexpr char kOutput[] = "Out";
23-
24-
class IsEmptyOp : public framework::OperatorBase {
22+
class IsEmptyOp : public framework::OperatorWithKernel {
2523
public:
26-
IsEmptyOp(const std::string &type, const framework::VariableNameMap &inputs,
27-
const framework::VariableNameMap &outputs,
28-
const framework::AttributeMap &attrs)
29-
: OperatorBase(type, inputs, outputs, attrs) {}
24+
using framework::OperatorWithKernel::OperatorWithKernel;
3025

31-
private:
32-
void RunImpl(const framework::Scope &scope,
33-
const platform::Place &place) const override {
34-
// get input
35-
auto *var = scope.FindVar(Input(kInput));
36-
PADDLE_ENFORCE_NOT_NULL(var);
37-
auto &tensor = var->Get<framework::LoDTensor>();
38-
// get output
39-
auto *out = scope.FindVar(Output(kOutput));
40-
PADDLE_ENFORCE_NOT_NULL(out);
41-
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
26+
protected:
27+
void InferShape(framework::InferShapeContext *ctx) const override {
28+
PADDLE_ENFORCE(ctx->HasInput("X"),
29+
"Input(X) of IsEmptyOp should not be null.");
30+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
31+
"Output(Out) of IsEmptyOp should not be null.");
32+
ctx->SetOutputDim("Out", {1});
33+
}
4234

43-
out_tensor->Resize({1});
44-
out_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
45-
framework::product(tensor.dims()) == 0;
35+
framework::OpKernelType GetExpectedKernelType(
36+
const framework::ExecutionContext &ctx) const override {
37+
framework::OpKernelType kt = framework::OpKernelType(
38+
framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
39+
platform::CPUPlace());
40+
return kt;
4641
}
4742
};
4843

49-
class IsEmptyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
44+
class IsEmptyOpMaker : public framework::OpProtoAndCheckerMaker {
5045
public:
5146
void Make() override {
52-
AddInput(kInput, "(Tensor) Tensor which is to be checked.");
53-
AddOutput(kOutput, "(Tensor) a boolean Tensor that indicate empty or not.");
47+
AddInput("X", "(LoDTensor) Tensor which is to be checked.");
48+
AddOutput("Out",
49+
"(LoDTensor) a boolean Tensor that indicate empty or not.");
5450
AddComment(R"DOC(
5551
IsEmpty Operator which checks whether a tensor is empty.
5652
@@ -62,5 +58,12 @@ It will just return product(tensor.ddims()) > 0;
6258
} // namespace operators
6359
} // namespace paddle
6460

65-
REGISTER_OP_WITHOUT_GRADIENT(is_empty, paddle::operators::IsEmptyOp,
66-
paddle::operators::IsEmptyOpProtoMaker);
61+
namespace ops = paddle::operators;
62+
63+
REGISTER_OPERATOR(is_empty, ops::IsEmptyOp, ops::IsEmptyOpMaker,
64+
paddle::framework::EmptyGradOpMaker);
65+
REGISTER_OP_CPU_KERNEL(
66+
is_empty, ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, float>,
67+
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, double>,
68+
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, int>,
69+
ops::IsEmptyOpKernel<paddle::platform::CPUDeviceContext, int64_t>);

paddle/fluid/operators/is_empty_op.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include "paddle/fluid/framework/op_registry.h"
17+
#include "paddle/fluid/framework/operator.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
template <typename DeviceContext, typename T>
23+
class IsEmptyOpKernel : public framework::OpKernel<T> {
24+
public:
25+
void Compute(const framework::ExecutionContext& context) const override {
26+
// get input
27+
auto* input_tensor = context.Input<framework::LoDTensor>("X");
28+
// get output
29+
auto* output_tensor = context.Output<framework::LoDTensor>("Out");
30+
31+
output_tensor->mutable_data<bool>(platform::CPUPlace())[0] =
32+
framework::product(input_tensor->dims()) == 0;
33+
}
34+
};
35+
36+
} // namespace operators
37+
} // namespace paddle

python/paddle/fluid/layers/control_flow.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
'reorder_lod_tensor_by_rank',
5050
'ParallelDo',
5151
'Print',
52+
'is_empty',
5253
]
5354

5455

@@ -1562,3 +1563,40 @@ def reorder_lod_tensor_by_rank(x, rank_table):
15621563
'RankTable': [rank_table]},
15631564
outputs={'Out': [out]})
15641565
return out
1566+
1567+
1568+
def is_empty(x, cond=None, **ignored):
1569+
"""
1570+
**Is Empty**
1571+
1572+
This layer returns the truth value of whether the variable is empty.
1573+
1574+
Args:
1575+
x(Variable): Operand of *is_empty*
1576+
cond(Variable|None): Optional output variable to store the result
1577+
of *is_empty*
1578+
1579+
Returns:
1580+
Variable: The tensor variable storing the output of *is_empty*.
1581+
1582+
Raises:
1583+
TypeError: If input cond is not a variable, or cond's dtype is
1584+
not bool
1585+
1586+
Examples:
1587+
.. code-block:: python
1588+
1589+
less = fluid.layers.is_empty(x=input)
1590+
"""
1591+
helper = LayerHelper("is_empty", **locals())
1592+
if cond is None:
1593+
cond = helper.create_tmp_variable(dtype='bool')
1594+
cond.stop_gradient = True
1595+
elif not isinstance(cond, Variable):
1596+
raise TypeError("cond takes a variable")
1597+
elif cond.dtype != 'bool':
1598+
raise TypeError("The data type of cond must be bool")
1599+
1600+
helper.append_op(
1601+
type='is_empty', inputs={'X': [x]}, outputs={'Out': [cond]})
1602+
return cond

python/paddle/fluid/tests/unittests/test_is_empty_op.py

Lines changed: 12 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,42 +14,24 @@
1414

1515
import unittest
1616
import numpy as np
17-
from paddle.fluid.op import Operator
18-
import paddle.fluid.core as core
17+
from op_test import OpTest
1918

2019

21-
def create_tensor(scope, name, np_data):
22-
tensor = scope.var(name).get_tensor()
23-
tensor.set_dims(np_data.shape)
24-
tensor.set(np_data, core.CPUPlace())
25-
return tensor
26-
27-
28-
class TestIsEmptyOp(unittest.TestCase):
20+
class TestEmpty(OpTest):
2921
def setUp(self):
30-
self.scope = core.Scope()
31-
# create input variables
32-
np_data0 = np.array([0, 1, 2])
33-
create_tensor(self.scope, "X0", np_data0)
34-
35-
np_data1 = np.array([1])
36-
t = create_tensor(self.scope, "X1", np_data1)
37-
t.set_dims([0])
22+
self.op_type = "is_empty"
23+
self.inputs = {'X': np.array([1, 2, 3])}
24+
self.outputs = {'Out': np.array([False])}
3825

39-
# create output variables
40-
self.scope.var("out")
26+
def test_check_output(self):
27+
self.check_output()
4128

42-
def test_no_empty(self):
43-
self.one_case("X0", False)
4429

45-
def test_empty(self):
46-
self.one_case("X1", True)
47-
48-
def one_case(self, input, target):
49-
op = Operator(type="is_empty", X=input, Out="out")
50-
op.run(self.scope, core.CPUPlace())
51-
out = self.scope.var("out").get_tensor()
52-
self.assertEqual(np.array(out)[0], target)
30+
class TestNotEmpty(TestEmpty):
31+
def setUp(self):
32+
self.op_type = "is_empty"
33+
self.inputs = {'X': np.array([])}
34+
self.outputs = {'Out': np.array([True])}
5335

5436

5537
if __name__ == "__main__":

0 commit comments

Comments
 (0)