Commit d189d4d

Merge pull request #12884 from sneaxiy/sequence_mask_op
Add sequence_mask_op for DAM model
2 parents 7240e72 + 7df74a5 commit d189d4d

File tree

6 files changed: 366 additions, 0 deletions

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
+paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
 paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
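Read as a Python signature, the new entry is sequence_mask(x, maxlen=None, dtype='int64', name=None). A minimal sketch of the calls this permits (the lengths variable is illustrative, assuming the fluid API of this era):

import paddle.fluid as fluid

lengths = fluid.layers.data(name='lengths', shape=[3], dtype='int64',
                            append_batch_size=False)
mask_a = fluid.layers.sequence_mask(lengths)              # maxlen inferred as max(lengths)
mask_b = fluid.layers.sequence_mask(lengths, maxlen=10,
                                    dtype='float32')      # fixed width, float mask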
paddle/fluid/operators/sequence_mask_op.cc

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/sequence_mask_op.h"

REGISTER_OPERATOR(sequence_mask, paddle::operators::SequenceMaskOp,
                  paddle::operators::SequenceMaskOpMaker,
                  paddle::framework::EmptyGradOpMaker);

REGISTER_OP_CPU_KERNEL(
    sequence_mask,
    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
                                          int>,
    paddle::operators::SequenceMaskKernel<paddle::platform::CPUDeviceContext,
                                          int64_t>);
paddle/fluid/operators/sequence_mask_op.cu

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/sequence_mask_op.h"

REGISTER_OP_CUDA_KERNEL(
    sequence_mask,
    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
                                          int>,
    paddle::operators::SequenceMaskKernel<paddle::platform::CUDADeviceContext,
                                          int64_t>);
paddle/fluid/operators/sequence_mask_op.h

Lines changed: 154 additions & 0 deletions
@@ -0,0 +1,154 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef __NVCC__
#include <thrust/device_ptr.h>
#include <thrust/functional.h>
#include <thrust/reduce.h>
#else
#include <algorithm>
#endif

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

class SequenceMaskOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist");
    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) must exist");

    auto maxlen = ctx->Attrs().Get<int>("maxlen");
    if (maxlen > 0) {  // We can only infer the shape when maxlen > 0
      auto dim = framework::vectorize2int(ctx->GetInputDim("X"));
      dim.push_back(maxlen);
      ctx->SetOutputDim("Y", framework::make_ddim(dim));
    }
  }
};

class SequenceMaskOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor of sequence_mask op.");
    AddOutput("Y", "The output mask of sequence_mask op.");
    AddAttr<int>("maxlen",
                 "The maximum length of the sequence. If maxlen < 0, maxlen "
                 "= max(Input(X)).")
        .SetDefault(-1)
        .AddCustomChecker([](int &v) {
          PADDLE_ENFORCE(v < 0 || v >= 1,
                         "Attr(maxlen) must be negative or no less than 1");
        });
    AddAttr<int>("out_dtype", "Output data type");
    AddComment(R"DOC(
SequenceMask Operator

This operator outputs a Mask according to Input(X) and Attr(maxlen).
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:

Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))

If maxlen < 0, maxlen = max(X)
    )DOC");
  }
};

template <typename Tx, typename Ty>
struct SequenceMaskForRangeFunctor {
  HOSTDEVICE SequenceMaskForRangeFunctor(const Tx *x, Ty *y, int maxlen)
      : x_(x), y_(y), maxlen_(maxlen) {}

  HOSTDEVICE void operator()(int y_idx) const {
    int x_idx = y_idx / maxlen_;
    int j = y_idx % maxlen_;
    y_[y_idx] = static_cast<Ty>(j < x_[x_idx] ? 1 : 0);
  }

 private:
  const Tx *x_;
  Ty *y_;
  int maxlen_;
};

template <typename DeviceContext, typename Tx>
struct SequenceMaskFunctor {
  using Tensor = framework::LoDTensor;

  SequenceMaskFunctor(const DeviceContext &ctx, const Tx *x, Tensor *y,
                      int limits, int maxlen)
      : ctx_(ctx), x_(x), y_(y), limits_(limits), maxlen_(maxlen) {}

  template <typename Ty>
  void operator()() const {
    auto *y_data = y_->mutable_data<Ty>(ctx_.GetPlace());
    platform::ForRange<DeviceContext> for_range(ctx_, limits_);
    for_range(SequenceMaskForRangeFunctor<Tx, Ty>(x_, y_data, maxlen_));
  }

 private:
  const DeviceContext &ctx_;
  const Tx *x_;
  Tensor *y_;
  int limits_;
  int maxlen_;
};

template <typename DeviceContext, typename Tx>
class SequenceMaskKernel : public framework::OpKernel<Tx> {
  using Tensor = framework::LoDTensor;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Y");
    auto maxlen = ctx.Attr<int>("maxlen");

    auto *x_data = x->data<Tx>();
    auto x_numel = x->numel();
    if (maxlen < 0) {
#ifdef __NVCC__
      VLOG(10)
          << "SequenceMaskOp on GPU may be slow when maxlen is not provided.";
      maxlen = static_cast<int>(
          thrust::reduce(thrust::device_pointer_cast(x_data),
                         thrust::device_pointer_cast(x_data) + x_numel,
                         static_cast<Tx>(0), thrust::maximum<Tx>()));
#else
      maxlen = static_cast<int>(*std::max_element(x_data, x_data + x_numel));
#endif
      auto y_dim = framework::vectorize2int(x->dims());
      y_dim.push_back(maxlen);
      y->Resize(framework::make_ddim(y_dim));
    }

    auto out_dtype = static_cast<framework::proto::VarType::Type>(
        ctx.Attr<int>("out_dtype"));
    auto &dev_ctx = ctx.template device_context<DeviceContext>();
    framework::VisitDataType(out_dtype,
                             SequenceMaskFunctor<DeviceContext, Tx>(
                                 dev_ctx, x_data, y, x_numel * maxlen, maxlen));
  }
};

}  // namespace operators
}  // namespace paddle
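As a reading aid, not part of the diff: SequenceMaskForRangeFunctor flattens the output so that every index y_idx decomposes as (x_idx, j) = (y_idx / maxlen, y_idx % maxlen), and the kernel's maxlen < 0 branch replaces maxlen with max(X) before resizing Y. A minimal NumPy sketch of the same semantics:

import numpy as np

def sequence_mask_ref(x, maxlen=-1):
    # Mirror of the kernel: Y[i_1, ..., i_n, j] = (j < X[i_1, ..., i_n]).
    x = np.asarray(x)
    if maxlen < 0:                        # the maxlen < 0 branch in Compute()
        maxlen = int(x.max())             # thrust::reduce / std::max_element
    y = np.zeros(x.shape + (maxlen,), dtype='int64')
    for y_idx in range(x.size * maxlen):  # ForRange over x_numel * maxlen
        x_idx, j = divmod(y_idx, maxlen)  # SequenceMaskForRangeFunctor
        y.flat[y_idx] = 1 if j < x.flat[x_idx] else 0
    return y

print(sequence_mask_ref([[2, 0], [1, 3]]))
# y[0][0] = [1 1 0], y[0][1] = [0 0 0], y[1][0] = [1 0 0], y[1][1] = [1 1 1]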

python/paddle/fluid/layers/nn.py

Lines changed: 69 additions & 0 deletions
@@ -103,6 +103,7 @@
     'rank_loss',
     'prelu',
     'flatten',
+    'sequence_mask',
     'stack',
 ]

@@ -5520,7 +5521,75 @@ def flatten(x, axis=1, name=None):
     return out


+def sequence_mask(x, maxlen=None, dtype='int64', name=None):
+    """
+    **SequenceMask Layer**
+
+    This layer outputs a mask according to the input :code:`x` and
+    :code:`maxlen` with the data type of :code:`dtype`.
+
+    Supposing :code:`x` is a Tensor with shape [d_1, d_2, ..., d_n], the
+    output :code:`y` is a mask with shape [d_1, d_2, ..., d_n, maxlen], where:
+
+    .. math::
+
+        y(i_1, i_2,..., i_n, j) = (j < x(i_1, i_2,..., i_n))
+
+    Args:
+        x (Variable): Input tensor of sequence_mask layer, whose elements
+            are integers less than :code:`maxlen`.
+        maxlen (int|None): Maximum length of the sequence. If :code:`maxlen`
+            is None, it would be replaced with :math:`max(x)`.
+        dtype (np.dtype|core.VarDesc.VarType|str): Data type of the output.
+        name (str|None): A name for this layer (optional). If set to None,
+            the layer will be named automatically.
+
+    Returns:
+        Variable: The output sequence mask.
+
+    """
+
+    helper = LayerHelper('sequence_mask', **locals())
+    if name is None:
+        out = helper.create_tmp_variable(dtype=dtype)
+    else:
+        out = helper.create_tmp_variable(dtype=dtype, name=name)
+
+    helper.append_op(
+        type='sequence_mask',
+        inputs={'X': [x]},
+        outputs={'Y': out},
+        attrs={
+            # Attr name must match the C++ op definition ("maxlen").
+            'maxlen': maxlen if maxlen is not None else -1,
+            'out_dtype': out.dtype
+        })
+    return out
+
+
 def stack(x, axis=0):
+    """
+    **Stack Layer**
+
+    This layer stacks all of the input :code:`x` along the :code:`axis`
+    dimension.
+
+    Input :code:`x` can be a single variable, a :code:`list` of variables,
+    or a :code:`tuple` of variables. If :code:`x` is a :code:`list` or
+    :code:`tuple`, the shapes of all these variables must be the same.
+    Supposing the shape of each input is :math:`[d_0, d_1, ..., d_{n-1}]`,
+    the shape of the output variable would be
+    :math:`[d_0, d_1, ..., d_{axis}=len(x), ..., d_{n-1}]`.
+    If :code:`axis` < 0, it would be replaced with :code:`axis+rank(x[0])+1`.
+    If :code:`axis` is None, it would be replaced with 0.
+
+    Args:
+        x (Variable|list(Variable)|tuple(Variable)): Input variables.
+        axis (int|None): The axis along which all inputs are stacked.
+
+    Returns:
+        Variable: The stacked variable.
+
+    """
+
     helper = LayerHelper('stack', **locals())
     axis = 0 if axis is None else axis
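A minimal end-to-end sketch of the new layer, assuming the 2018-era fluid executor API (names are illustrative):

import numpy as np
import paddle.fluid as fluid

lengths = fluid.layers.data(name='lengths', shape=[3], dtype='int64',
                            append_batch_size=False)
mask = fluid.layers.sequence_mask(lengths, maxlen=5, dtype='float32')

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
out, = exe.run(feed={'lengths': np.array([1, 3, 5], dtype='int64')},
               fetch_list=[mask])
# out -> [[1 0 0 0 0], [1 1 1 0 0], [1 1 1 1 1]] as float32

The stack docstring's shape rule matches NumPy's np.stack, which makes it easy to sanity-check (again, an aside rather than part of the diff):

import numpy as np

a, b = np.zeros([3, 4]), np.ones([3, 4])
print(np.stack([a, b], axis=0).shape)   # (2, 3, 4): len(x) inserted at axis 0
print(np.stack([a, b], axis=1).shape)   # (3, 2, 4): len(x) inserted at axis 1
print(np.stack([a, b], axis=-1).shape)  # (3, 4, 2): axis=-1 -> -1+2+1 = 2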

python/paddle/fluid/tests/unittests/test_sequence_mask.py

Lines changed: 94 additions & 0 deletions
@@ -0,0 +1,94 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_test import OpTest
import paddle.fluid as fluid
from paddle.fluid.framework import convert_np_dtype_to_dtype_
import paddle.fluid.core as core
import numpy as np
import copy
import unittest


class SequenceMaskTestBase(OpTest):
    def initDefaultParameters(self):
        self.op_type = 'sequence_mask'
        self.maxlen = 10
        self.mask_dtype = 'int64'
        self.x = [[0, 3, 4], [5, 7, 9]]

    def initParameters(self):
        pass

    def setUp(self):
        self.initDefaultParameters()
        self.initParameters()
        if not isinstance(self.x, np.ndarray):
            self.x = np.array(self.x)

        self.inputs = {'X': self.x}
        self.outputs = {'Y': self.calc_ground_truth_mask()}
        self.attrs = {
            'maxlen': self.maxlen,
            'out_dtype': convert_np_dtype_to_dtype_(self.mask_dtype)
        }

    def calc_ground_truth_mask(self):
        maxlen = np.max(self.x) if self.maxlen < 0 else self.maxlen
        shape = self.x.shape + (maxlen, )
        # Broadcast the index vector [0, 1, ..., maxlen - 1] and the input
        # against each other, so that mask[..., j] = (j < x[...]).
        index_broadcast = np.broadcast_to(
            np.reshape(
                range(maxlen), newshape=[1] * self.x.ndim + [-1]),
            shape=shape)
        x_broadcast = np.broadcast_to(
            np.reshape(
                self.x, newshape=self.x.shape + (-1, )), shape=shape)
        return (index_broadcast < x_broadcast).astype(self.mask_dtype)

    def test_check_output(self):
        self.check_output()


class SequenceMaskTest1(SequenceMaskTestBase):
    def initParameters(self):
        self.mask_dtype = 'bool'


class SequenceMaskTest2(SequenceMaskTestBase):
    def initParameters(self):
        self.mask_dtype = 'uint8'


class SequenceMaskTest3(SequenceMaskTestBase):
    def initParameters(self):
        self.mask_dtype = 'int32'


class SequenceMaskTest4(SequenceMaskTestBase):
    def initParameters(self):
        self.mask_dtype = 'float32'


class SequenceMaskTest5(SequenceMaskTestBase):
    def initParameters(self):
        self.mask_dtype = 'float64'


class SequenceMaskTest6(SequenceMaskTestBase):
    def initParameters(self):
        self.maxlen = -1


if __name__ == '__main__':
    unittest.main()
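For concreteness, an aside rather than part of the test: with the defaults above, x = [[0, 3, 4], [5, 7, 9]] and maxlen = 10, the broadcasting in calc_ground_truth_mask reduces to a single comparison:

import numpy as np

x = np.array([[0, 3, 4], [5, 7, 9]])
mask = (np.arange(10) < x[..., np.newaxis]).astype('int64')  # shape (2, 3, 10)
print(mask[0, 1])  # x[0, 1] == 3 -> [1 1 1 0 0 0 0 0 0 0]
print(mask[1, 2])  # x[1, 2] == 9 -> [1 1 1 1 1 1 1 1 1 0]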
