Skip to content

Commit 6cfcf62

Browse files
authored
Adding logical operators for beam search and control flow (#5708)
1 parent bce1c03 commit 6cfcf62

File tree

6 files changed

+315
-0
lines changed

6 files changed

+315
-0
lines changed

paddle/framework/data_type.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) {
4646
return typeid(int);
4747
case DataType::INT64:
4848
return typeid(int64_t);
49+
case DataType::BOOL:
50+
return typeid(bool);
4951
default:
5052
PADDLE_THROW("Not support type %d", type);
5153
}
@@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) {
6668
case DataType::INT64:
6769
visitor.template operator()<int64_t>();
6870
break;
71+
case DataType::BOOL:
72+
visitor.template operator()<bool>();
73+
break;
6974
default:
7075
PADDLE_THROW("Not supported");
7176
}

paddle/operators/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ function(op_library TARGET)
8787
file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n")
8888
endif()
8989

90+
if ("${TARGET}" STREQUAL "logical_op")
91+
set(pybind_flag 1)
92+
file(APPEND ${pybind_file} "USE_OP(logical_and);\n")
93+
endif()
94+
9095
# pool_with_index_op contains several operators
9196
if ("${TARGET}" STREQUAL "pool_with_index_op")
9297
set(pybind_flag 1)

paddle/operators/logical_op.cc

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/logical_op.h"
16+
#include "paddle/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
template <typename OpComment>
21+
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
22+
public:
23+
BinaryLogicalOpProtoMaker(framework::OpProto *proto,
24+
framework::OpAttrChecker *op_checker)
25+
: OpProtoAndCheckerMaker(proto, op_checker) {
26+
OpComment comment;
27+
AddInput("X",
28+
string::Sprintf("(LoDTensor) Left hand operand of %s operator",
29+
comment.type));
30+
AddInput("Y",
31+
string::Sprintf("(LoDTensor) Right hand operand of %s operator",
32+
comment.type));
33+
AddOutput("Out", string::Sprintf(
34+
"(LoDTensor) n-dim bool tensor. Each element is %s",
35+
comment.equation));
36+
AddComment(string::Sprintf(R"DOC(%s Operator
37+
38+
It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors.
39+
Each element of Out is calculated by %s
40+
)DOC",
41+
comment.type, comment.equation));
42+
}
43+
};
44+
45+
template <typename OpComment>
46+
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
47+
public:
48+
UnaryLogicalOpProtoMaker(framework::OpProto *proto,
49+
framework::OpAttrChecker *op_checker)
50+
: OpProtoAndCheckerMaker(proto, op_checker) {
51+
OpComment comment;
52+
AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator",
53+
comment.type));
54+
AddOutput("Out", string::Sprintf(
55+
"(LoDTensor) n-dim bool tensor. Each element is %s",
56+
comment.equation));
57+
AddComment(string::Sprintf(R"DOC(%s Operator
58+
59+
It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors.
60+
Each element of Out is calculated by %s
61+
)DOC",
62+
comment.type, comment.equation));
63+
}
64+
};
65+
66+
template <typename OpComment>
67+
class BinaryLogicalOpInferShape : public framework::InferShapeBase {
68+
public:
69+
void operator()(framework::InferShapeContext *context) const override {
70+
OpComment comment;
71+
PADDLE_ENFORCE(context->HasInput("X"),
72+
"Input(X) of %s operator must not be null", comment.type);
73+
PADDLE_ENFORCE(context->HasInput("Y"),
74+
"Input(Y) of %s operator must not be null", comment.type);
75+
auto dim_x = context->GetInputDim("X");
76+
auto dim_y = context->GetInputDim("Y");
77+
PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y),
78+
"The number of elements in X and Y should be same");
79+
80+
context->SetOutputDim("Out", context->GetInputDim("X"));
81+
context->ShareLoD("X", "Out");
82+
}
83+
};
84+
85+
template <typename OpComment>
86+
class UnaryLogicalOpInferShape : public framework::InferShapeBase {
87+
public:
88+
void operator()(framework::InferShapeContext *context) const override {
89+
OpComment comment;
90+
PADDLE_ENFORCE(context->HasInput("X"),
91+
"Input(X) of %s operator must not be null", comment.type);
92+
auto dim_x = context->GetInputDim("X");
93+
94+
context->SetOutputDim("Out", context->GetInputDim("X"));
95+
context->ShareLoD("X", "Out");
96+
}
97+
};
98+
99+
class LogicalOp : public framework::OperatorWithKernel {
100+
public:
101+
using framework::OperatorWithKernel::OperatorWithKernel;
102+
103+
protected:
104+
framework::OpKernelType GetKernelType(
105+
const framework::ExecutionContext &ctx) const override {
106+
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
107+
// LogicalOp kernel's device type is decided by input tensor place
108+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
109+
return kt;
110+
}
111+
};
112+
113+
} // namespace operators
114+
} // namespace paddle
115+
116+
#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation) \
117+
struct _##op_type##Comment { \
118+
static char type[]; \
119+
static char equation[]; \
120+
}; \
121+
char _##op_type##Comment::type[]{#op_type}; \
122+
char _##op_type##Comment::equation[]{_equation}; \
123+
REGISTER_OPERATOR( \
124+
op_type, ::paddle::operators::LogicalOp, \
125+
::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
126+
::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \
127+
::paddle::framework::EmptyGradOpMaker);
128+
129+
#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation) \
130+
struct _##op_type##Comment { \
131+
static char type[]; \
132+
static char equation[]; \
133+
}; \
134+
char _##op_type##Comment::type[]{#op_type}; \
135+
char _##op_type##Comment::equation[]{_equation}; \
136+
REGISTER_OPERATOR( \
137+
op_type, ::paddle::operators::LogicalOp, \
138+
::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
139+
::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \
140+
::paddle::framework::EmptyGradOpMaker);
141+
142+
REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y");
143+
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
144+
paddle::operators::LogicalAndFunctor);
145+
REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X && Y");
146+
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
147+
paddle::operators::LogicalOrFunctor);
148+
REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X");
149+
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
150+
paddle::operators::LogicalNotFunctor);
151+
REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)");
152+
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
153+
paddle::operators::LogicalXorFunctor);

paddle/operators/logical_op.cu

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/logical_op.h"
16+
17+
// GPU kernel registrations for the logical operators. The element-wise
// functors are shared with the CPU build via logical_op.h.
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU,
                               paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU,
                               paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU,
                              paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU,
                               paddle::operators::LogicalXorFunctor);

paddle/operators/logical_op.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
#include <math.h>
17+
#include <type_traits>
18+
#include "paddle/framework/op_registry.h"
19+
#include "paddle/platform/transform.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
template <typename T>
25+
struct LogicalAndFunctor {
26+
using ELEM_TYPE = T;
27+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; }
28+
};
29+
30+
template <typename T>
31+
struct LogicalOrFunctor {
32+
using ELEM_TYPE = T;
33+
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; }
34+
};
35+
36+
template <typename T>
37+
struct LogicalNotFunctor {
38+
using ELEM_TYPE = T;
39+
HOSTDEVICE bool operator()(const T& a) const { return !a; }
40+
};
41+
42+
template <typename T>
43+
struct LogicalXorFunctor {
44+
using ELEM_TYPE = T;
45+
HOSTDEVICE bool operator()(const T& a, const T& b) const {
46+
return (a || b) && !(a && b);
47+
}
48+
};
49+
50+
template <typename Place, typename Functor>
51+
class BinaryLogicalOpKernel
52+
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
53+
public:
54+
void Compute(const framework::ExecutionContext& context) const override {
55+
using T = typename Functor::ELEM_TYPE;
56+
auto* x = context.Input<framework::Tensor>("X");
57+
auto* y = context.Input<framework::Tensor>("Y");
58+
auto* out = context.Output<framework::Tensor>("Out");
59+
Functor binary_func;
60+
platform::Transform<Place> trans;
61+
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
62+
y->data<T>(), out->mutable_data<bool>(context.GetPlace()),
63+
binary_func);
64+
}
65+
};
66+
67+
template <typename Place, typename Functor>
68+
class UnaryLogicalOpKernel
69+
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
70+
public:
71+
void Compute(const framework::ExecutionContext& context) const override {
72+
using T = typename Functor::ELEM_TYPE;
73+
auto* x = context.Input<framework::Tensor>("X");
74+
auto* out = context.Output<framework::Tensor>("Out");
75+
Functor unary_func;
76+
platform::Transform<Place> trans;
77+
trans(context.device_context(), x->data<T>(), x->data<T>() + x->numel(),
78+
out->mutable_data<bool>(context.GetPlace()), unary_func);
79+
}
80+
};
81+
82+
} // namespace operators
83+
} // namespace paddle
84+
85+
// Instantiates the binary logical kernel for one device. `dev` is the
// CPU/GPU token used both in REGISTER_OP_<dev>_KERNEL and in
// platform::<dev>Place; the functor is always instantiated for bool.
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
  REGISTER_OP_##dev##_KERNEL(                                 \
      op_type,                                                \
      ::paddle::operators::BinaryLogicalOpKernel<             \
          ::paddle::platform::dev##Place, functor<bool>>);

// Unary counterpart of REGISTER_BINARY_LOGICAL_KERNEL.
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
  REGISTER_OP_##dev##_KERNEL(                                \
      op_type,                                               \
      ::paddle::operators::UnaryLogicalOpKernel<             \
          ::paddle::platform::dev##Place, functor<bool>>);
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import op_test
2+
import unittest
3+
import numpy as np
4+
5+
6+
def create_test_class(op_type, callback, binary_op=True):
    """Generate and register an OpTest subclass for one logical operator.

    Args:
        op_type: operator name; also becomes the generated class's name.
        callback: numpy reference implementation the op output is checked
            against.
        binary_op: True if the operator takes two operands (X, Y), False
            if it takes only X.
    """

    class Cls(op_test.OpTest):
        def setUp(self):
            shape = (10, 7)
            x = np.random.choice(a=[True, False], size=shape).astype(bool)
            if binary_op:
                y = np.random.choice(a=[True, False], size=shape).astype(bool)
                expected = callback(x, y)
                self.inputs = {'X': x, 'Y': y}
            else:
                expected = callback(x)
                self.inputs = {'X': x}
            self.op_type = op_type
            self.outputs = {'Out': expected}

        def test_output(self):
            self.check_output()

    # Register the class under the op's name so unittest discovers it.
    Cls.__name__ = op_type
    globals()[op_type] = Cls
27+
28+
29+
# The numpy ufuncs serve directly as the reference implementations.
create_test_class('logical_and', np.logical_and)
create_test_class('logical_or', np.logical_or)
create_test_class('logical_not', np.logical_not, False)
create_test_class('logical_xor', np.logical_xor)

if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)