
Commit 21ec93a

Author: Qingsheng Li

[WIP] Sequence Scatter Op (#12625)

Sequence Scatter Op

1 parent 103deb1 commit 21ec93a

File tree

6 files changed: +415 −0 lines changed


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@ paddle.fluid.layers.image_resize_short ArgSpec(args=['input', 'out_short_len', '
 paddle.fluid.layers.resize_bilinear ArgSpec(args=['input', 'out_shape', 'scale', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.gather ArgSpec(args=['input', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.sequence_scatter ArgSpec(args=['input', 'index', 'updates', 'name'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.random_crop ArgSpec(args=['x', 'shape', 'seed'], varargs=None, keywords=None, defaults=(None,))
 paddle.fluid.layers.mean_iou ArgSpec(args=['input', 'label', 'num_classes'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.relu ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle/fluid/operators/sequence_scatter_op.cc

Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/sequence_scatter_op.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

class SequenceScatterOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The source input of sequence scatter op");
    AddInput("Ids",
             "(LoDTensor) The index input of sequence scatter op where X"
             " will be updated, must be a LoDTensor");
    AddInput("Updates",
             "(LoDTensor) The values to scatter to the input tensor "
             "X, must be a LoDTensor with the same LoD information as Ids");
    AddOutput("Out",
              "(Tensor) The output tensor of sequence scatter op, which "
              "has the same dims as X");
    AddComment(R"DOC(
Sequence Scatter Operator.

This operator scatters the Updates tensor to the input X. It uses the LoD
information of Ids to select the rows to update, and uses the values in Ids
as the columns to update in each row of X.

The following case better explains how this works:

Example 1:
Given an all-ones Tensor input(X)
    X.data = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
              [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
              [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
    X.dims = [3, 6]
a LoDTensor input(Ids)
    Ids.data = [[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], [5], [4]]
    Ids.lod = [[0, 3, 8, 12]]
and a LoDTensor input(Updates)
    Updates.data = [[0.3], [0.3], [0.4], [0.1], [0.2], [0.3], [0.4], [0.0], [0.2], [0.3], [0.1], [0.4]]
    Updates.lod = [[0, 3, 8, 12]]
then we get an output Tensor
    Out.data = [[1.3, 1.3, 1.4, 1.0, 1.0, 1.0],
                [1.0, 1.0, 1.4, 1.3, 1.2, 1.1],
                [1.0, 1.0, 1.3, 1.2, 1.4, 1.1]]
    Out.dims = X.dims = [3, 6]
)DOC");
  }
};

class SequenceScatterOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    // Enforce that all inputs and the output are present
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of SequenceScatterOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Ids"),
                   "Input(Ids) of SequenceScatterOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Updates"),
                   "Input(Updates) of SequenceScatterOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of SequenceScatterOp should not be null.");

    // Set output dim the same as input
    auto ref_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", ref_dims);

    // Enforce that Updates and Ids have the same first dimension
    PADDLE_ENFORCE_EQ(ctx->GetInputDim("Updates")[0],
                      ctx->GetInputDim("Ids")[0],
                      "Updates and Ids should have the same shape.");

    // Enforce that the LoD of Ids and Updates is level 1
    if (ctx->IsRuntime()) {
      framework::Variable* ids_var =
          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
      framework::Variable* updates_var =
          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Updates")[0]);

      auto& ids_lod = ids_var->Get<LoDTensor>().lod();
      auto& updates_lod = updates_var->Get<LoDTensor>().lod();
      PADDLE_ENFORCE_EQ(ids_lod.size(), 1,
                        "Currently only level 1 LoD could be"
                        " processed by sequence scatter op.");
      PADDLE_ENFORCE_EQ(updates_lod.size(), 1,
                        "Currently only level 1 LoD "
                        "could be processed by sequence scatter op.");
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

class SequenceScatterGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim(framework::GradVarName("Updates"),
                      ctx->GetInputDim("Updates"));
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        framework::ToDataType(ctx.Input<Tensor>("X")->type()),
        platform::CPUPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(sequence_scatter, ops::SequenceScatterOp,
                  ops::SequenceScatterOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(sequence_scatter_grad, ops::SequenceScatterGradOp);
REGISTER_OP_CPU_KERNEL(sequence_scatter, ops::SequenceScatterOpKernel<float>,
                       ops::SequenceScatterOpKernel<double>,
                       ops::SequenceScatterOpKernel<int>,
                       ops::SequenceScatterOpKernel<int64_t>);
REGISTER_OP_CPU_KERNEL(sequence_scatter_grad,
                       ops::SequenceScatterGradientOpKernel<float>,
                       ops::SequenceScatterGradientOpKernel<double>,
                       ops::SequenceScatterGradientOpKernel<int>,
                       ops::SequenceScatterGradientOpKernel<int64_t>);
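
To make the semantics above concrete, here is a minimal NumPy sketch, not part of the commit, that reproduces Example 1 from the DOC comment. The helper name sequence_scatter_ref is made up for illustration; the level-1 LoD offsets [0, 3, 8, 12] mark where each of the three index sequences begins and ends.

    import numpy as np

    def sequence_scatter_ref(x, ids, updates, lod):
        # For each segment seg, add updates[i] to out[seg, ids[i]]
        # for every i in [lod[seg], lod[seg + 1]) -- a scatter-add.
        out = x.copy()
        for seg in range(len(lod) - 1):
            for i in range(lod[seg], lod[seg + 1]):
                out[seg, ids[i]] += updates[i]
        return out

    x = np.ones((3, 6), dtype=np.float32)
    ids = np.array([0, 1, 2, 5, 4, 3, 2, 1, 3, 2, 5, 4])  # Ids.data, flattened
    updates = np.array([0.3, 0.3, 0.4, 0.1, 0.2, 0.3,
                        0.4, 0.0, 0.2, 0.3, 0.1, 0.4], dtype=np.float32)
    print(sequence_scatter_ref(x, ids, updates, lod=[0, 3, 8, 12]))
    # [[1.3 1.3 1.4 1.  1.  1. ]
    #  [1.  1.  1.4 1.3 1.2 1.1]
    #  [1.  1.  1.3 1.2 1.4 1.1]]

Note the additive rule: the values in Updates are accumulated onto X rather than overwriting it, which is why the all-ones input becomes 1.3, 1.4, and so on.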
paddle/fluid/operators/sequence_scatter_op.h

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/gather.h"
#include "paddle/fluid/operators/scatter.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;

template <typename T>
class SequenceScatterOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* ids = ctx.Input<LoDTensor>("Ids");
    auto* updates = ctx.Input<LoDTensor>("Updates");
    auto* out = ctx.Output<Tensor>("Out");

    auto& ids_lod = ids->lod();

    // Initialize out with a copy of x
    out->mutable_data<T>(ctx.GetPlace());
    framework::TensorCopySync(*x, ctx.GetPlace(), out);

    auto x_dims = x->dims();
    auto out_dims = out->dims();

    for (int i = 0; i < x_dims.size(); ++i)
      PADDLE_ENFORCE(x_dims[i] == out_dims[i],
                     "Input and output shape of "
                     "sequence scatter op must exactly be the same.");

    size_t slice_size = 1;
    for (int i = 1; i < x_dims.size(); ++i) slice_size *= x_dims[i];

    auto lod_vec = ids_lod[0];
    unsigned int seg = 0;
    // Walk the rows of Ids once; seg tracks which LoD segment
    // (i.e. which row of X) the current index i belongs to.
    for (int i = 0; i < ids->dims()[0]; ++i) {
      PADDLE_ENFORCE_LT(seg, lod_vec.size() - 1,
                        "Segment num must not exceed batch size.\n");
      int lower_bound = lod_vec[seg];
      int upper_bound = lod_vec[seg + 1];
      if (i >= lower_bound && i < upper_bound) {
        T* p_out = out->data<T>();
        const T* p_updates = updates->data<T>();
        const int64_t* p_index = ids->data<int64_t>();
        // Scatter-add: Out[seg][Ids[i]] += Updates[i]
        p_out[seg * slice_size + p_index[i]] += p_updates[i];
      } else {
        // i has passed the end of the current segment:
        // advance to the next segment and retry the same i.
        ++seg;
        --i;
      }
    }
  }
};

template <typename T>
class SequenceScatterGradientOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
                   "This kernel only runs on CPU.");
    auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dUpdates = ctx.Output<LoDTensor>(framework::GradVarName("Updates"));
    auto* ids = ctx.Input<LoDTensor>("Ids");
    auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));

    auto& ids_lod = ids->lod();

    // Since Out = X + scattered Updates, dX is a copy of dOut
    dX->mutable_data<T>(ctx.GetPlace());
    framework::TensorCopySync(*dOut, ctx.GetPlace(), dX);
    dUpdates->mutable_data<T>(ctx.GetPlace());

    auto dx_dims = dX->dims();
    auto dout_dims = dOut->dims();

    for (int i = 0; i < dx_dims.size(); ++i)
      PADDLE_ENFORCE(dx_dims[i] == dout_dims[i],
                     "Input and output shape of "
                     "sequence scatter grad op must exactly be the same.");

    size_t slice_size = 1;
    for (int i = 1; i < dx_dims.size(); ++i) slice_size *= dx_dims[i];

    auto lod_vec = ids_lod[0];
    unsigned int seg = 0;

    for (int i = 0; i < ids->dims()[0]; ++i) {
      PADDLE_ENFORCE_LT(seg, lod_vec.size() - 1,
                        "Segment num must not exceed batch size.\n");
      int lower_bound = lod_vec[seg];
      int upper_bound = lod_vec[seg + 1];
      if (i >= lower_bound && i < upper_bound) {
        const T* p_dOut = dOut->data<T>();
        const int64_t* p_index = ids->data<int64_t>();
        T* p_dUpdates = dUpdates->data<T>();
        // Gather: dUpdates[i] = dOut[seg][Ids[i]]
        p_dUpdates[i] = p_dOut[seg * slice_size + p_index[i]];
      } else {
        ++seg;
        --i;
      }
    }
  }
};
}  // namespace operators
}  // namespace paddle
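
The gradient kernel above follows directly from the additive forward rule: since Out = X + scattered Updates, dX is a verbatim copy of dOut, and each dUpdates[i] gathers the dOut entry that Updates[i] was added to. A minimal NumPy sketch of that rule, not part of the commit (sequence_scatter_grad_ref is a made-up name):

    import numpy as np

    def sequence_scatter_grad_ref(d_out, ids, lod):
        # dX: Out depends on X with coefficient 1, so dX copies dOut.
        d_x = d_out.copy()
        # dUpdates[i] gathers the gradient of the cell Updates[i] was added to.
        d_updates = np.empty(len(ids), dtype=d_out.dtype)
        for seg in range(len(lod) - 1):
            for i in range(lod[seg], lod[seg + 1]):
                d_updates[i] = d_out[seg, ids[i]]
        return d_x, d_updates

    # With dOut of all ones, dX is all ones and every dUpdates entry is 1.0,
    # matching the forward rule Out[seg][Ids[i]] += Updates[i].
    d_x, d_upd = sequence_scatter_grad_ref(
        np.ones((3, 6), dtype=np.float32),
        np.array([0, 1, 2, 5, 4, 3, 2, 1, 3, 2, 5, 4]),
        [0, 3, 8, 12])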

python/paddle/fluid/layers/nn.py

Lines changed: 61 additions & 0 deletions
@@ -100,6 +100,7 @@
     'resize_bilinear',
     'gather',
     'scatter',
+    'sequence_scatter',
     'random_crop',
     'mean_iou',
     'relu',
@@ -5425,6 +5426,66 @@ def scatter(input, index, updates, name=None):
     return out


+def sequence_scatter(input, index, updates, name=None):
+    """
+    **Sequence Scatter Layer**
+
+    This operator scatters the updates tensor to the input. It uses the LoD
+    information of index to select the rows to update, and uses the values
+    in index as the columns to update in each row of input.
+
+    Here is an example.
+
+    Given the following input:
+
+    .. code-block:: text
+
+        input.data = [[1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
+                      [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
+        input.dims = [3, 6]
+
+        index.data = [[0], [1], [2], [5], [4], [3], [2], [1], [3], [2], [5], [4]]
+        index.lod = [[0, 3, 8, 12]]
+
+        updates.data = [[0.3], [0.3], [0.4], [0.1], [0.2], [0.3], [0.4], [0.0], [0.2], [0.3], [0.1], [0.4]]
+        updates.lod = [[0, 3, 8, 12]]
+
+    Then we have the output:
+
+    .. code-block:: text
+
+        out.data = [[1.3, 1.3, 1.4, 1.0, 1.0, 1.0],
+                    [1.0, 1.0, 1.4, 1.3, 1.2, 1.1],
+                    [1.0, 1.0, 1.3, 1.2, 1.4, 1.1]]
+        out.dims = input.dims = [3, 6]
+
+    Args:
+        input (Variable): The source input with rank >= 1.
+        index (Variable): A LoDTensor. The index input of sequence scatter op
+            where input will be updated. It should be of rank 1, and its
+            dtype should be int32 or int64 as it is used as indexes.
+        updates (Variable): A LoDTensor. The values to scatter to the input
+            tensor, must be a LoDTensor with the same LoD information as index.
+        name (str|None): The output variable name. Default None.
+
+    Returns:
+        output (Variable): The output tensor, with the same shape as input.
+
+    Examples:
+
+        .. code-block:: python
+
+            output = fluid.layers.sequence_scatter(input, index, updates)
+
+    """
+    helper = LayerHelper('sequence_scatter', **locals())
+    dtype = helper.input_dtype()
+    out = helper.create_tmp_variable(dtype)
+    helper.append_op(
+        type="sequence_scatter",
+        inputs={"X": input,
+                "Ids": index,
+                "Updates": updates},
+        outputs={"Out": out})
+    return out
+
+
 @templatedoc()
 def random_crop(x, shape, seed=None):
     """

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 24 additions & 0 deletions
@@ -382,6 +382,30 @@ def test_scatter(self):
         self.assertIsNotNone(out)
         print(str(program))

+    def test_sequence_scatter(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(
+                name='x',
+                shape=[3, 6],
+                append_batch_size=False,
+                dtype='float32')
+            idx = layers.data(
+                name='idx',
+                shape=[12, 1],
+                append_batch_size=False,
+                dtype='int32',
+                lod_level=1)
+            updates = layers.data(
+                name='updates',
+                shape=[12, 1],
+                append_batch_size=False,
+                dtype='float32',
+                lod_level=1)
+            out = layers.sequence_scatter(input=x, index=idx, updates=updates)
+            self.assertIsNotNone(out)
+            print(str(program))
+
     def test_lod_reset(self):
         program = Program()
         with program_guard(program):
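
The commit header counts 6 changed files, but only 5 appear above; a new operator like this would normally also ship with an OpTest-based unit test under python/paddle/fluid/tests/unittests. Purely as a hedged sketch of what such a test might look like (the class name, the (ndarray, lod) tuple format for LoD inputs, and the check_grad arguments are assumptions here, not contents shown in this commit):

    import numpy as np
    from op_test import OpTest


    class TestSequenceScatterOp(OpTest):
        def setUp(self):
            self.op_type = "sequence_scatter"
            x = np.ones((3, 6)).astype('float32')
            ids = np.array([[0], [1], [2], [5], [4], [3],
                            [2], [1], [3], [2], [5], [4]]).astype('int64')
            updates = np.random.random((12, 1)).astype('float32')
            ids_lod = [[3, 5, 4]]  # length-based form of offsets [0, 3, 8, 12]

            # Build the expected output with the same scatter-add rule
            # as the CPU kernel.
            out = np.copy(x)
            offset = 0
            for seg, length in enumerate(ids_lod[0]):
                for i in range(offset, offset + length):
                    out[seg][ids[i][0]] += updates[i][0]
                offset += length

            self.inputs = {
                'X': x,
                'Ids': (ids, ids_lod),
                'Updates': (updates, ids_lod)
            }
            self.outputs = {'Out': out}

        def test_check_output(self):
            self.check_output()

        def test_check_grad(self):
            self.check_grad(['Updates'], 'Out', in_place=True)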
