Skip to content

Commit e12b1d1

Browse files
Bai Yifan
authored and qingqing01 committed
Add flatten op (#12341)
* add flatten op
1 parent 4dbcb97 commit e12b1d1

File tree

4 files changed

+239
-0
lines changed

4 files changed

+239
-0
lines changed
16 KB
Binary file not shown.

paddle/fluid/operators/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,8 @@ op_library(parallel_do_op DEPS executor)
271271
op_library(unsqueeze_op DEPS reshape_op)
272272
op_library(squeeze_op DEPS reshape_op)
273273
op_library(extract_rows_op DEPS memory)
274+
op_library(flatten_op DEPS reshape_op)
275+
274276

275277
if (WITH_GPU)
276278
op_library(conv_op DEPS vol2col depthwise_conv im2col)

paddle/fluid/operators/flatten_op.cc

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <vector>
16+
#include "paddle/fluid/framework/op_registry.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
using Tensor = framework::Tensor;
22+
23+
// Shape inference for the flatten op.  Validates the `axis` attribute and
// sets Out to the 2-D shape {prod(in_dims[0:axis]), prod(in_dims[axis:])}.
class FlattenOpInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input (X) of Flatten op should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output (Output) of Flatten op should not be null.");
    const auto &axis = ctx->Attrs().Get<int>("axis");
    const auto &in_dims = ctx->GetInputDim("X");
    // Valid range is [0, rank]: axis == rank collapses every dimension into
    // the outer dimension (inner dimension becomes 1).
    PADDLE_ENFORCE(axis >= 0, "The axis should be greater than or equal to 0.");
    PADDLE_ENFORCE(
        axis <= in_dims.size(),
        "The axis should be less than or equal to input tensor's rank.");

    const auto &out_dims = GetOutputShape(axis, in_dims);
    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
    if (in_dims[0] == out_dims[0]) {
      // Only pass LoD when the first dimension of output and Input(X)
      // are the same.
      ctx->ShareLoD("X", "Out");
    }
  }

  // Computes {outer, inner} where outer = prod(in_dims[0:axis]) and
  // inner = prod(in_dims[axis:]).  Products are accumulated in int64_t but
  // narrowed into int32_t on return — NOTE(review): could overflow for very
  // large tensors; confirm this matches the reshape op's "shape" attr type.
  static std::vector<int32_t> GetOutputShape(const int axis,
                                             const framework::DDim &in_dims) {
    int64_t outer = 1, inner = 1;
    for (int i = 0; i < in_dims.size(); ++i) {
      if (i < axis) {
        outer *= in_dims[i];
      } else {
        inner *= in_dims[i];
      }
    }
    std::vector<int32_t> out_shape(2);
    out_shape[0] = outer;
    out_shape[1] = inner;
    return out_shape;
  }
};
62+
63+
// Runtime implementation of flatten.  The op has no kernel of its own: it
// computes the 2-D target shape from the input's dims and delegates the
// actual work to the already-registered reshape op.
class FlattenOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &axis = Attr<int>("axis");
    // Read the runtime dims of X from the scope (InferShape may not have
    // known them statically).
    auto in_dims =
        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);

    framework::AttributeMap attrs;
    attrs["shape"] = out_dims;
    attrs["inplace"] = false;
    // Invoke Reshape Op
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
        {{"Out", {Output("Out")}}}, attrs);
    reshape_op->Run(scope, place);
  }
};
85+
86+
// Proto maker for flatten: declares the op's inputs, outputs, attributes,
// and user-facing documentation.
// Fix: the original adjacent string literals were concatenated without
// separating spaces, producing garbled doc text such as "dimensionsup to
// axis" and "(int)Indicate"; the axis description also had an unbalanced
// parenthesis in "(1, (d_0 X d_1 ... d_n)".
class FlattenOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) A tensor of rank >= axis.");
    AddOutput("Out",
              "A 2D tensor is reshaped input tensor. The input dimensions "
              "up to axis are flattened to the outer dimension of the output "
              "and the remaining input dimensions are flattened into the "
              "inner dimension of the output.");
    AddAttr<int>("axis",
                 "(int) "
                 "Indicate up to which input dimensions (exclusive) should be "
                 "flattened to the outer dimension of the output. The value "
                 "for axis must be in the range [0, R], where R is the rank "
                 "of the input tensor. When axis = 0, the shape of the output "
                 "tensor is (1, (d_0 X d_1 ... d_n)), where the shape of the "
                 "input tensor is (d_0, d_1, ... d_n).")
        .SetDefault(1);
    AddComment(R"DOC(
Flatten Operator

Flattens the input tensor into a 2D matrix.

Examples:
Case 1:
Given
X.shape = (3, 100, 100, 4)
and
axis = 2
We get:
Out.shape = (3 * 100, 4 * 100)

Case 2:
Given
X.shape = (3, 100, 100, 4)
and
axis = 0
We get:
Out.shape = (1, 3 * 100 * 100 * 4)
)DOC");
  }
};
128+
129+
// Gradient shape inference for flatten: dX has exactly the same dims as X,
// and X's LoD is propagated to dX.
class FlattenGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *ctx) const override {
    const auto dx_name = framework::GradVarName("X");
    ctx->SetOutputDim(dx_name, ctx->GetInputDim("X"));
    ctx->ShareLoD("X", dx_name);
  }
};
137+
138+
// Gradient of flatten.  Like the forward op it delegates to reshape:
// dOut is reshaped back into X's original dims to produce dX.
class FlattenGradOp : public framework::OperatorBase {
 public:
  using OperatorBase::OperatorBase;

 private:
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto dx_name = Output(framework::GradVarName("X"));
    auto dout_name = Input(framework::GradVarName("Out"));
    // X's runtime dims give the target shape to restore.
    auto in_dims =
        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
    framework::AttributeMap attrs;
    attrs["shape"] = framework::vectorize2int(in_dims);
    attrs["inplace"] = false;

    // Reshape dOut back to X's shape, writing the result into dX.
    auto reshape_op = framework::OpRegistry::CreateOp(
        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
        attrs);
    reshape_op->Run(scope, place);
  }
};
159+
160+
}  // namespace operators
}  // namespace paddle

// flatten is implemented in terms of reshape, so ensure the reshape op is
// linked into any binary that uses flatten.
USE_OP(reshape);

namespace ops = paddle::operators;
// DefaultGradOpDescMaker<true> auto-generates the grad op desc from the
// forward op (forwarding all inputs/outputs).
REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
                  ops::FlattenOpInferShape,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
18+
from op_test import OpTest
19+
20+
21+
class TestFlattenOp(OpTest):
    # Base case: flatten a 4-D tensor with axis=1 into shape (3, 20).
    # Subclasses override init_test_case/init_attrs to cover other shapes.
    def setUp(self):
        self.op_type = "flatten"
        self.init_test_case()
        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
        self.init_attrs()
        # The expected output is just a numpy reshape to the flattened shape.
        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")

    def init_test_case(self):
        # in_shape/axis/new_shape are the knobs subclasses override.
        self.in_shape = (3, 2, 2, 5)
        self.axis = 1
        self.new_shape = (3, 20)

    def init_attrs(self):
        self.attrs = {"axis": self.axis}
44+
# Fix: this class was also named `TestFlattenOp`, shadowing the base class
# at module level, so unittest discovery never ran the base (axis=1) case.
# Renamed so both test cases are collected.
class TestFlattenOpAxisZero(TestFlattenOp):
    # axis=0 flattens everything into the inner dimension: (1, 36).
    def init_test_case(self):
        self.in_shape = (3, 2, 2, 3)
        self.axis = 0
        self.new_shape = (1, 36)
49+
50+
51+
class TestFlattenOpWithDefaultAxis(TestFlattenOp):
    # Exercises the operator's default axis (SetDefault(1)) by sending
    # no "axis" attribute at all.
    def init_test_case(self):
        # self.axis is deliberately not set: init_attrs below never reads it.
        self.in_shape = (3, 2, 2, 3)
        self.new_shape = (3, 12)

    def init_attrs(self):
        # Empty attrs -> the op falls back to its default axis of 1.
        self.attrs = {}
58+
59+
60+
class TestFlattenOpSixDims(TestFlattenOp):
    # 6-D input with axis=4: outer = 3*2*3*2 = 36, inner = 4*4 = 16.
    def init_test_case(self):
        self.in_shape = (3, 2, 3, 2, 4, 4)
        self.axis = 4
        self.new_shape = (36, 16)
66+
67+
# Allow running this test module directly from the command line.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)