Skip to content

Commit 1617fe2

Browse files
author
Yibing Liu
authored
Merge pull request #11897 from chenwhql/unsqueeze_op
Add Unsqueeze operator and unit testing
2 parents 99cdfed + 938319b commit 1617fe2

File tree

3 files changed

+303
-0
lines changed

3 files changed

+303
-0
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor)
265265
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
266266
op_library(cos_sim_op DEPS cos_sim_functor)
267267
op_library(parallel_do_op DEPS executor)
268+
op_library(unsqueeze_op DEPS reshape_op)
268269
op_library(squeeze_op DEPS reshape_op)
269270

270271
if (WITH_GPU)
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <string>
16+
#include <vector>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
class UnsqueezeOpInferShape : public framework::InferShapeBase {
23+
public:
24+
void operator()(framework::InferShapeContext *ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of UnsqueezeOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
28+
"Output(Out) of UnsqueezeOp should not be null.");
29+
30+
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
31+
const auto &x_dims = ctx->GetInputDim("X");
32+
// Validity Check: input tensor dims (<6).
33+
PADDLE_ENFORCE(x_dims.size() <= 6,
34+
"Invalid dimensions, the rank of Input(X) "
35+
"should be in the range of [1, 6] (Eigen limit)");
36+
auto out_dims = GetOutputShape(axes, x_dims);
37+
ctx->SetOutputDim("Out", out_dims);
38+
if (x_dims[0] == out_dims[0]) {
39+
// Only pass LoD when the first dimension of output and Input(X)
40+
// are the same.
41+
ctx->ShareLoD("X", "Out");
42+
}
43+
}
44+
45+
static framework::DDim GetOutputShape(const std::vector<int> unsqz_dims,
46+
const framework::DDim &in_dims) {
47+
int output_size = in_dims.size() + static_cast<int>(unsqz_dims.size());
48+
int cur_output_size = in_dims.size();
49+
std::vector<int64_t> output_shape(output_size, 0);
50+
51+
// Validity Check: rank range.
52+
PADDLE_ENFORCE(output_size <= 6,
53+
"The output tensor's rank should be less than 6.");
54+
55+
for (int axis : unsqz_dims) {
56+
int cur = axis < 0 ? axis + cur_output_size + 1 : axis;
57+
// Vaildity Check: the axis bound
58+
PADDLE_ENFORCE(
59+
cur >= 0 && cur <= cur_output_size,
60+
"The unsqueeze dims must be within range of current rank.");
61+
// Move old axis, and insert new axis
62+
for (int i = cur_output_size; i >= cur; --i) {
63+
if (output_shape[i] == 1) {
64+
// Move axis
65+
output_shape[i + 1] = 1;
66+
output_shape[i] = 0;
67+
}
68+
}
69+
output_shape[cur] = 1;
70+
// Add the output size.
71+
cur_output_size++;
72+
}
73+
74+
// Make output shape
75+
for (int in_idx = 0, out_idx = 0; out_idx < output_size; ++out_idx) {
76+
if (output_shape[out_idx] == 0) {
77+
output_shape[out_idx] = in_dims[in_idx++];
78+
}
79+
}
80+
81+
return framework::make_ddim(output_shape);
82+
}
83+
};
84+
85+
class UnsqueezeOp : public framework::OperatorBase {
86+
public:
87+
using OperatorBase::OperatorBase;
88+
89+
private:
90+
void RunImpl(const framework::Scope &scope,
91+
const platform::Place &place) const override {
92+
auto &axes = Attr<std::vector<int>>("axes");
93+
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
94+
auto out_dims = UnsqueezeOpInferShape::GetOutputShape(axes, x_dims);
95+
96+
framework::AttributeMap attrs;
97+
attrs["shape"] = framework::vectorize2int(out_dims);
98+
attrs["inplace"] = Attr<bool>("inplace");
99+
// Invoke Reshape op.
100+
auto reshape_op = framework::OpRegistry::CreateOp(
101+
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
102+
{{"Out", {Output("Out")}}}, attrs);
103+
reshape_op->Run(scope, place);
104+
}
105+
};
106+
107+
class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
108+
public:
109+
void Make() override {
110+
AddInput("X", "(Tensor). The input tensor of unsqueeze operator.");
111+
AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator.");
112+
AddAttr<std::vector<int>>("axes",
113+
"(std::vector<int>). List of integers,"
114+
" indicating the dimensions to be inserted")
115+
.AddCustomChecker([](const std::vector<int> &axes) {
116+
PADDLE_ENFORCE(!axes.empty(),
117+
"Invalid axes, The unsqueeze axes is empty.");
118+
// Validity Check: axes dims (<6).
119+
PADDLE_ENFORCE(static_cast<int>(axes.size()) < 6,
120+
"Invalid dimensions, dynamic dimensions should be "
121+
"within [1, 6] dimensions (Eigen limit).");
122+
// Validity Check: the range of unsqueeze aixs.
123+
for (int axis : axes) {
124+
PADDLE_ENFORCE(axis < 6,
125+
"Invalid dimensions, input axis should be"
126+
" within [1, 6] dimensions (Eigen limit).");
127+
}
128+
});
129+
AddAttr<bool>(
130+
"inplace",
131+
"(default: false) Unsqueeze the source tensor's shape without "
132+
"memory copy. When Attr(inplace) is set true, the output "
133+
"tensor shares memory with Input(X), otherwise, a new output "
134+
"tensor is created, and its data are copied from Input(x).")
135+
.SetDefault(false);
136+
AddComment(R"DOC(
137+
Unsqueeze Operator.
138+
139+
Insert single-dimensional entries to the shape of a tensor.
140+
Takes one required argument axes, a list of dimensions that will be inserted.
141+
Dimension indices in axes are as seen in the output tensor.
142+
143+
For example:
144+
Given a tensor such that tensor with shape [3, 4, 5],
145+
then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1]
146+
)DOC");
147+
}
148+
};
149+
150+
class UnsqueezeGradInferShape : public framework::InferShapeBase {
151+
public:
152+
void operator()(framework::InferShapeContext *ctx) const override {
153+
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
154+
ctx->ShareLoD("X", framework::GradVarName("X"));
155+
}
156+
};
157+
158+
class UnsqueezeGradOp : public framework::OperatorBase {
159+
public:
160+
using OperatorBase::OperatorBase;
161+
162+
private:
163+
void RunImpl(const framework::Scope &scope,
164+
const platform::Place &place) const override {
165+
auto dx_name = Output(framework::GradVarName("X"));
166+
auto dout_name = Input(framework::GradVarName("Out"));
167+
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
168+
169+
framework::AttributeMap attrs;
170+
attrs["shape"] = framework::vectorize2int(x_dims);
171+
attrs["inplace"] = Attr<bool>("inplace");
172+
173+
auto reshape_op = framework::OpRegistry::CreateOp(
174+
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
175+
attrs);
176+
reshape_op->Run(scope, place);
177+
}
178+
};
179+
180+
} // namespace operators
181+
} // namespace paddle
182+
183+
// Tell linker to use reshape op.
184+
USE_OP(reshape);
185+
186+
namespace ops = paddle::operators;
187+
REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker,
188+
ops::UnsqueezeOpInferShape,
189+
paddle::framework::DefaultGradOpDescMaker<true>);
190+
REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp,
191+
ops::UnsqueezeGradInferShape);
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
18+
from op_test import OpTest
19+
20+
21+
# Correct: General.
22+
class TestUnsqueezeOp(OpTest):
23+
def setUp(self):
24+
self.init_test_case()
25+
self.op_type = "unsqueeze"
26+
self.inputs = {"X": np.random.random(self.ori_shape).astype("float32")}
27+
self.init_attrs()
28+
self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
29+
30+
def test_check_output(self):
31+
self.check_output()
32+
33+
def test_check_grad(self):
34+
self.check_grad(["X"], "Out")
35+
36+
def init_test_case(self):
37+
self.ori_shape = (3, 5)
38+
self.axes = (1, 2)
39+
self.new_shape = (3, 1, 1, 5)
40+
41+
def init_attrs(self):
42+
self.attrs = {"axes": self.axes, "inplace": False}
43+
44+
45+
# Correct: Single input index.
46+
class TestUnsqueezeOp1(TestUnsqueezeOp):
47+
def init_test_case(self):
48+
self.ori_shape = (3, 5)
49+
self.axes = (-1, )
50+
self.new_shape = (3, 5, 1)
51+
52+
53+
# Correct: Mixed input axis.
54+
class TestUnsqueezeOp2(TestUnsqueezeOp):
55+
def init_test_case(self):
56+
self.ori_shape = (3, 5)
57+
self.axes = (0, -1)
58+
self.new_shape = (1, 3, 5, 1)
59+
60+
61+
# Correct: There is duplicated axis.
62+
class TestUnsqueezeOp3(TestUnsqueezeOp):
63+
def init_test_case(self):
64+
self.ori_shape = (3, 2, 5)
65+
self.axes = (0, 3, 3)
66+
self.new_shape = (1, 3, 2, 1, 1, 5)
67+
68+
69+
# Correct: Reversed axes.
70+
class TestUnsqueezeOp4(TestUnsqueezeOp):
71+
def init_test_case(self):
72+
self.ori_shape = (3, 2, 5)
73+
self.axes = (3, 1, 1)
74+
self.new_shape = (3, 1, 1, 2, 5, 1)
75+
76+
77+
# Correct: Inplace.
78+
class TestUnsqueezeOpInplace1(TestUnsqueezeOp):
79+
def init_test_case(self):
80+
self.ori_shape = (3, 5)
81+
self.axes = (0, 2)
82+
self.new_shape = (1, 3, 1, 5)
83+
84+
def init_attrs(self):
85+
self.attrs = {"axes": self.axes, "inplace": True}
86+
87+
88+
# Correct: Inplace. There is mins index.
89+
class TestUnsqueezeOpInplace2(TestUnsqueezeOp):
90+
def init_test_case(self):
91+
self.ori_shape = (3, 5)
92+
self.axes = (0, -2)
93+
self.new_shape = (1, 3, 1, 5)
94+
95+
def init_attrs(self):
96+
self.attrs = {"axes": self.axes, "inplace": True}
97+
98+
99+
# Correct: Inplace. There is duplicated axis.
100+
class TestUnsqueezeOpInplace3(TestUnsqueezeOp):
101+
def init_test_case(self):
102+
self.ori_shape = (3, 2, 5)
103+
self.axes = (0, 3, 3)
104+
self.new_shape = (1, 3, 2, 1, 1, 5)
105+
106+
def init_attrs(self):
107+
self.attrs = {"axes": self.axes, "inplace": True}
108+
109+
110+
if __name__ == "__main__":
111+
unittest.main()

0 commit comments

Comments
 (0)