
Commit 599a326

Merge pull request #12971 from sneaxiy/unstack_op

Add unstack op

2 parents: 0b77518 + 52a480b

File tree: 6 files changed, +286 -0 lines


paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions

@@ -164,6 +164,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara
 paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
 paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
+paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -291,6 +291,7 @@ op_library(unsqueeze_op DEPS reshape_op)
 op_library(squeeze_op DEPS reshape_op)
 op_library(extract_rows_op DEPS memory)
 op_library(flatten_op DEPS reshape_op)
+op_library(unstack_op DEPS stack_op)
 
 if (WITH_GPU)
 op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/unstack_op.cc

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/unstack_op.h"
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+USE_OP(stack);
+
+REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
+                  ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker);
+
+REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp,
+                  ops::UnStackOpGradInferShape);

paddle/fluid/operators/unstack_op.h

Lines changed: 135 additions & 0 deletions

@@ -0,0 +1,135 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class UnStackOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist.");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    int num = ctx->Attrs().Get<int>("num");
+    auto x_dim = ctx->GetInputDim("X");
+    int rank = x_dim.size();
+    PADDLE_ENFORCE(axis >= -rank && axis < rank,
+                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
+                   rank);
+    if (axis < 0) axis += rank;
+
+    PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
+                      "Number of Outputs(Y) is wrong");
+    if (x_dim[axis] > 0) {
+      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
+    }
+    auto vec = framework::vectorize2int(x_dim);
+    vec.erase(vec.begin() + axis);
+    ctx->SetOutputsDim("Y", std::vector<framework::DDim>(  // NOLINT
+                                x_dim[axis], framework::make_ddim(vec)));
+  }
+};
+
+class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of unstack op.");
+    AddOutput("Y", "The output of unstack op.").AsDuplicable();
+    AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
+        .SetDefault(0);
+    AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
+    AddComment(R"DOC(
+      UnStack Operator.
+
+      UnStack Input(X) into several tensors along Attr(axis).
+    )DOC");
+  }
+};
+
+class UnStackOp : public framework::OperatorBase {
+ public:
+  using OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto stack_grad_op = framework::OpRegistry::CreateOp(
+        "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}},
+        {{framework::GradVarName("X"), Outputs("Y")}}, Attrs());
+    stack_grad_op->Run(scope, place);
+  }
+};
+
+class UnStackOpGradInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
+                      "Number of Inputs(Y@Grad) must be larger than 0");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@Grad) must exist.");
+
+    auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
+    for (size_t i = 1; i < input_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
+                        "Dims of all Inputs(Y@Grad) must be the same");
+    }
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    int rank = input_dims[0].size();
+    PADDLE_ENFORCE(
+        axis >= -(rank + 1) && axis < rank + 1,
+        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    if (axis < 0) axis += (rank + 1);
+
+    auto vec = framework::vectorize2int(input_dims[0]);
+    vec.insert(vec.begin() + axis, input_dims.size());
+    ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
+  }
+};
+
+class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("unstack_grad");
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+class UnStackGradOp : public framework::OperatorBase {
+ public:
+  using OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto stack_op = framework::OpRegistry::CreateOp(
+        "stack", {{"X", Inputs(framework::GradVarName("Y"))}},
+        {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs());
+    stack_op->Run(scope, place);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
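Note (not part of the commit): the header exploits the fact that unstacking along an axis is the exact inverse of stacking, which is why UnStackOp::RunImpl can delegate to a stack_grad op and UnStackGradOp::RunImpl to a stack op. A minimal NumPy sketch of that identity:

import numpy as np

# Sketch only: emulate unstack with split + squeeze.
x = np.random.rand(5, 6, 7).astype('float32')
axis = 1

# "unstack": split x.shape[axis] ways along axis, then drop that dimension
ys = [np.squeeze(s, axis=axis) for s in np.split(x, x.shape[axis], axis=axis)]
assert all(y.shape == (5, 7) for y in ys)

# "stack" rebuilds x exactly, so stack serves as unstack's gradient op
assert np.array_equal(np.stack(ys, axis=axis), x)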

python/paddle/fluid/layers/nn.py

Lines changed: 42 additions & 0 deletions

@@ -105,6 +105,7 @@
     'flatten',
     'sequence_mask',
     'stack',
+    'unstack',
 ]
 
 
@@ -5601,3 +5602,44 @@ def stack(x, axis=0):
         type='stack', inputs={'X': x}, outputs={'Y': out},
         attrs={'axis': axis})
     return out
+
+
+def unstack(x, axis=0, num=None):
+    """
+    **UnStack Layer**
+
+    This layer unstacks input :code:`x` into several tensors along :code:`axis`.
+
+    If :code:`axis` < 0, it is replaced with :code:`axis + rank(x)`.
+    If :code:`num` is None, it is inferred from :code:`x.shape[axis]`;
+    if :code:`x.shape[axis]` <= 0 or is unknown, :code:`ValueError` is
+    raised.
+
+    Args:
+        x (Variable): Input variable.
+        axis (int): The axis along which the input is unstacked.
+        num (int|None): The number of output variables.
+
+    Returns:
+        list(Variable): The unstacked variables.
+
+    """
+
+    helper = LayerHelper('unstack', **locals())
+    if num is None:
+        if axis is None or x.shape[axis] <= 0:
+            raise ValueError('unknown unstack number')
+        else:
+            num = x.shape[axis]
+
+    outs = []
+    for _ in range(num):
+        outs.append(helper.create_tmp_variable(x.dtype))
+
+    helper.append_op(
+        type='unstack',
+        inputs={'X': [x]},
+        outputs={'Y': outs},
+        attrs={'axis': axis,
+               'num': num})
+    return outs
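A hedged usage sketch of the new layer (not part of the commit; variable names are illustrative, and shapes follow the docstring above — with num left as None it is inferred from x.shape[axis]):

import paddle.fluid as fluid

# Illustrative only: unstack a [5, 6, 7] tensor along axis=1
# into six variables of shape [5, 7].
x = fluid.layers.data(
    name='x', shape=[5, 6, 7], append_batch_size=False, dtype='float32')
ys = fluid.layers.unstack(x, axis=1)  # num inferred from x.shape[1] == 6
assert len(ys) == 6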
python/paddle/fluid/tests/unittests/test_unstack_op.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from op_test import OpTest
+import numpy as np
+import unittest
+
+
+class TestUnStackOpBase(OpTest):
+    def initDefaultParameters(self):
+        self.input_dim = (5, 6, 7)
+        self.axis = 0
+        self.dtype = 'float32'
+
+    def initParameters(self):
+        pass
+
+    def get_y_names(self):
+        y_names = []
+        for i in range(self.input_dim[self.axis]):
+            y_names.append('y{}'.format(i))
+        return y_names
+
+    def setUp(self):
+        self.initDefaultParameters()
+        self.initParameters()
+        self.op_type = 'unstack'
+        self.x = np.random.random(size=self.input_dim).astype(self.dtype)
+
+        outs = np.split(self.x, self.input_dim[self.axis], self.axis)
+        new_shape = list(self.input_dim)
+        del new_shape[self.axis]
+        y_names = self.get_y_names()
+        tmp = []
+        for i in range(self.input_dim[self.axis]):
+            tmp.append((y_names[i], np.reshape(outs[i], new_shape)))
+
+        self.inputs = {'X': self.x}
+        self.outputs = {'Y': tmp}
+        self.attrs = {'axis': self.axis, 'num': self.input_dim[self.axis]}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad('X', self.get_y_names())
+
+
+class TestStackOp3(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = -1
+
+
+class TestStackOp4(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = -3
+
+
+class TestStackOp5(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = 1
+
+
+class TestStackOp6(TestUnStackOpBase):
+    def initParameters(self):
+        self.axis = 2
+
+
+if __name__ == '__main__':
+    unittest.main()
