Skip to content

Commit 2f7b093

Browse files
committed
Merge remote-tracking branch 'ups/develop' into feature/libxsmm
2 parents 908b534 + 1617fe2 commit 2f7b093

20 files changed

+723
-54
lines changed

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ op_library(recurrent_op DEPS executor)
265265
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
266266
op_library(cos_sim_op DEPS cos_sim_functor)
267267
op_library(parallel_do_op DEPS executor)
268+
op_library(unsqueeze_op DEPS reshape_op)
269+
op_library(squeeze_op DEPS reshape_op)
268270

269271
if (WITH_GPU)
270272
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/squeeze_op.cc

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <string>
16+
#include <vector>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
class SqueezeOpInferShape : public framework::InferShapeBase {
23+
public:
24+
void operator()(framework::InferShapeContext *ctx) const override {
25+
PADDLE_ENFORCE(ctx->HasInput("X"),
26+
"Input(X) of SqueezeOp should not be null.");
27+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
28+
"Output(Out) of SqueezeOp should not be null.");
29+
30+
const auto &x_dims = ctx->GetInputDim("X");
31+
// Check input tensor dims (<6) Eigen limit.
32+
PADDLE_ENFORCE(x_dims.size() <= 6,
33+
"Invalid dimnesions, the rank of Input(X) "
34+
"should be in the range of [1, 6] (Eigen limit).");
35+
36+
const auto &axes = ctx->Attrs().Get<std::vector<int>>("axes");
37+
for (int a : axes) {
38+
PADDLE_ENFORCE_LT(a, x_dims.size(),
39+
"The squeeze axis should be less than input "
40+
"tensor's rank.");
41+
}
42+
43+
auto out_dims = GetOutputShape(axes, x_dims);
44+
ctx->SetOutputDim("Out", out_dims);
45+
if (x_dims[0] == out_dims[0]) {
46+
// Only pass LoD when the first dimension of output and Input(X)
47+
// are the same.
48+
ctx->ShareLoD("X", "Out");
49+
}
50+
}
51+
52+
static framework::DDim GetOutputShape(const std::vector<int> squeeze_dims,
53+
const framework::DDim &in_dims) {
54+
size_t num_squeeze_dims = squeeze_dims.size();
55+
int cnt_squeezed_dims = 0;
56+
bool should_squeeze[9] = {false};
57+
58+
// Determines number of dimensions of output tensor after squeeze.
59+
// Mark and count the dimensions need to be squeezed
60+
if (num_squeeze_dims == 0) {
61+
for (int idx = 0; idx < in_dims.size(); ++idx) {
62+
if (in_dims[idx] == 1) {
63+
should_squeeze[idx] = true;
64+
++cnt_squeezed_dims;
65+
}
66+
}
67+
} else {
68+
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
69+
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
70+
: squeeze_dims[idx];
71+
// Check current index, the upper limit has beed checked in line 36.
72+
PADDLE_ENFORCE(current >= 0,
73+
"Invalid axis, the negative axis is out of range.");
74+
PADDLE_ENFORCE(in_dims[current] == 1,
75+
"Invalid axis index, the axis that will be squeezed "
76+
"should be equal to 1.");
77+
78+
if (!(should_squeeze[current])) {
79+
++cnt_squeezed_dims;
80+
}
81+
should_squeeze[current] = true;
82+
}
83+
}
84+
85+
// Make output dimensions
86+
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
87+
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
88+
if (!should_squeeze[in_idx]) {
89+
output_shape[out_idx++] = in_dims[in_idx];
90+
}
91+
}
92+
93+
return framework::make_ddim(output_shape);
94+
}
95+
};
96+
97+
class SqueezeOp : public framework::OperatorBase {
98+
public:
99+
using OperatorBase::OperatorBase;
100+
101+
private:
102+
void RunImpl(const framework::Scope &scope,
103+
const platform::Place &place) const override {
104+
auto &axes = Attr<std::vector<int>>("axes");
105+
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
106+
auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims);
107+
108+
framework::AttributeMap attrs;
109+
attrs["shape"] = framework::vectorize2int(out_dims);
110+
attrs["inplace"] = Attr<bool>("inplace");
111+
// Invoke Reshape Op
112+
auto reshape_op = framework::OpRegistry::CreateOp(
113+
"reshape", {{"X", {Input("X")}}, {"Shape", {}}},
114+
{{"Out", {Output("Out")}}}, attrs);
115+
reshape_op->Run(scope, place);
116+
}
117+
};
118+
119+
class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker {
120+
public:
121+
void Make() override {
122+
AddInput("X", "(Tensor). The input tensor of squeeze operator.");
123+
AddOutput("Out", "(Tensor). The output tensor of squeeze operator.");
124+
AddAttr<std::vector<int>>("axes",
125+
"(std::vector<int>). List of integers,"
126+
" indicating the dimensions to squeeze.")
127+
.SetDefault({});
128+
AddAttr<bool>("inplace",
129+
"(default: false) Squeeze the source tensor's shape without "
130+
"memory copy. When Attr(inplace) is set true, the output "
131+
"tensor shares memory with Input(X), otherwise, a new output "
132+
"tensor is created, and its data are copied from Input(x).")
133+
.SetDefault(false);
134+
AddComment(R"DOC(
135+
Squeeze Operator.
136+
137+
Remove single-dimensional entries from the shape of a tensor.
138+
Takes a parameter axes with a list of axes to squeeze.
139+
If axes is not provided, all the single dimensions will be removed from the shape.
140+
If an axis is selected with shape entry not equal to one, an error is raised.
141+
142+
Examples:
143+
Case 1:
144+
Given
145+
X.shape = (1, 3, 1, 5)
146+
and
147+
axes = [0]
148+
we get:
149+
Out.shape = (3, 1, 5)
150+
151+
Case 2:
152+
Given
153+
X.shape = (1, 3, 1, 5)
154+
and
155+
axes = []
156+
we get:
157+
Out.shape = (3, 5)
158+
)DOC");
159+
}
160+
};
161+
162+
class SqueezeGradInferShape : public framework::InferShapeBase {
163+
public:
164+
void operator()(framework::InferShapeContext *context) const override {
165+
context->SetOutputDim(framework::GradVarName("X"),
166+
context->GetInputDim("X"));
167+
context->ShareLoD("X", framework::GradVarName("X"));
168+
}
169+
};
170+
171+
class SqueezeGradOp : public framework::OperatorBase {
172+
public:
173+
using OperatorBase::OperatorBase;
174+
175+
private:
176+
void RunImpl(const framework::Scope &scope,
177+
const platform::Place &place) const override {
178+
auto dx_name = Output(framework::GradVarName("X"));
179+
auto dout_name = Input(framework::GradVarName("Out"));
180+
auto x_dims = scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
181+
framework::AttributeMap attrs;
182+
attrs["shape"] = framework::vectorize2int(x_dims);
183+
attrs["inplace"] = Attr<bool>("inplace");
184+
185+
auto reshape_op = framework::OpRegistry::CreateOp(
186+
"reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
187+
attrs);
188+
reshape_op->Run(scope, place);
189+
}
190+
};
191+
192+
} // namespace operators
193+
} // namespace paddle
194+
195+
// Tell linker to use reshape op
196+
USE_OP(reshape);
197+
198+
namespace ops = paddle::operators;
199+
REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker,
200+
ops::SqueezeOpInferShape,
201+
paddle::framework::DefaultGradOpDescMaker<true>);
202+
REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape);

0 commit comments

Comments
 (0)