Skip to content

Commit b03a44e

Browse files
authored
Merge pull request #14026 from JiabinYang/add_reorg_op
Add reorg op
2 parents ff6c809 + 9f65b61 commit b03a44e

File tree

8 files changed

+498
-0
lines changed

8 files changed

+498
-0
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
174174
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
175175
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
176176
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
177+
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
177178
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
178179
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
179180
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/space_to_depth_op.h"
16+
#include <string>
17+
#include <vector>
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
class SpaceToDepthOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext* ctx) const override {
27+
PADDLE_ENFORCE(ctx->HasInput("X"),
28+
"Input(X) of SpaceToDepthOp should not be null.");
29+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
30+
"Output(Out) of SpaceToDepthOp should not be null.");
31+
32+
auto x_dims = ctx->GetInputDim("X");
33+
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
34+
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
35+
36+
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
37+
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
38+
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
39+
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
40+
41+
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
42+
"input channel should be divisible of the square of "
43+
"SpaceToDepthOp blocksize");
44+
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
45+
"input Height should be divisible of the square of "
46+
"SpaceToDepthOp blocksize");
47+
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
48+
"input Width should be divisible of the square of "
49+
"SpaceToDepthOp blocksize");
50+
51+
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
52+
<< "Attribute blocksize" << blocksize << std::endl;
53+
54+
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
55+
output_shape[0] = x_dims[0];
56+
output_shape[1] = x_dims[1] * blocksize * blocksize;
57+
output_shape[2] = x_dims[2] / blocksize;
58+
output_shape[3] = x_dims[3] / blocksize;
59+
60+
auto out_dims = framework::make_ddim(output_shape);
61+
62+
ctx->SetOutputDim("Out", out_dims);
63+
64+
if (x_dims[0] == out_dims[0]) {
65+
// Only pass LoD when the first dimension of output and Input(X)
66+
// are the same.
67+
ctx->ShareLoD("X", /*->*/ "Out");
68+
}
69+
}
70+
};
71+
72+
class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
73+
public:
74+
void Make() override {
75+
AddInput("X",
76+
"(Tensor). The input should be a 4D tensor B * C * W * H of "
77+
"SpaceToDepthOp "
78+
"operator.");
79+
AddOutput("Out",
80+
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
81+
"SpaceToDepthOp operator.");
82+
AddAttr<int64_t>(
83+
"blocksize",
84+
"(int64_t, default 2) blocksize used to do change Space To Depth.")
85+
.SetDefault(2)
86+
.GreaterThan(1);
87+
AddComment(R"DOC(
88+
reorg operator used in Yolo v2.
89+
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize,
90+
91+
Reshape Input(X) into the shape according to Attr(blocksize). The
92+
data in Input(X) are unchanged.
93+
94+
Examples:
95+
96+
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
97+
into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged.
98+
99+
)DOC");
100+
}
101+
};
102+
103+
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
104+
public:
105+
using framework::OperatorWithKernel::OperatorWithKernel;
106+
107+
void InferShape(framework::InferShapeContext* ctx) const override {
108+
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
109+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
110+
"Input(Out@GRAD) shouldn't be null.");
111+
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
112+
}
113+
};
114+
} // namespace operators
115+
} // namespace paddle
116+
117+
namespace ops = paddle::operators;
118+
119+
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
120+
paddle::framework::DefaultGradOpDescMaker<true>);
121+
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
122+
REGISTER_OP_CPU_KERNEL(
123+
space_to_depth,
124+
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
125+
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, double>,
126+
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, int64_t>);
127+
REGISTER_OP_CPU_KERNEL(
128+
space_to_depth_grad,
129+
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, float>,
130+
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, double>,
131+
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/space_to_depth_op.h"
16+
17+
namespace plat = paddle::platform;
18+
namespace ops = paddle::operators;
19+
20+
REGISTER_OP_CUDA_KERNEL(
21+
space_to_depth,
22+
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
23+
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
24+
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);
25+
26+
REGISTER_OP_CUDA_KERNEL(
27+
space_to_depth_grad,
28+
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
29+
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
30+
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
15+
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
16+
#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
17+
18+
#include "paddle/fluid/framework/op_registry.h"
19+
#include "paddle/fluid/platform/for_range.h"
20+
21+
namespace paddle {
22+
namespace operators {
23+
24+
template <typename T>
25+
class space_to_depth_compute {
26+
public:
27+
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
28+
int64_t batch, int64_t blocksize,
29+
int64_t forward, T *out)
30+
: x_(x),
31+
w_(w),
32+
h_(h),
33+
c_(c),
34+
batch_(batch),
35+
blocksize_(blocksize),
36+
forward_(forward),
37+
out_(out) {}
38+
39+
HOSTDEVICE void operator()(int64_t in_index) {
40+
int64_t out_c = c_ / (blocksize_ * blocksize_);
41+
// calculate each dim position with index of tensor
42+
int64_t b = in_index / (c_ * h_ * w_);
43+
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
44+
int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
45+
int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;
46+
47+
int64_t c2 = k % out_c;
48+
int64_t offset = k / out_c;
49+
int64_t w2 = i * blocksize_ + offset % blocksize_;
50+
int64_t h2 = j * blocksize_ + offset / blocksize_;
51+
int64_t out_index =
52+
w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
53+
if (forward_)
54+
out_[out_index] = x_[in_index];
55+
else
56+
out_[in_index] = x_[out_index];
57+
}
58+
59+
private:
60+
const T *x_;
61+
int64_t w_, h_, c_, batch_, blocksize_, forward_;
62+
T *out_;
63+
};
64+
65+
template <typename DeviceContext, typename T>
66+
class SpaceToDepthKernel : public framework::OpKernel<T> {
67+
public:
68+
void Compute(const framework::ExecutionContext &context) const override {
69+
auto *out = context.Output<framework::LoDTensor>("Out");
70+
auto *x = context.Input<framework::LoDTensor>("X");
71+
auto blocksize = context.Attr<int64_t>("blocksize");
72+
auto in_dims = x->dims();
73+
out->mutable_data(context.GetPlace(), x->type());
74+
75+
auto out_dims = out->dims();
76+
auto B = in_dims[0];
77+
auto C = in_dims[1];
78+
auto H = in_dims[2];
79+
auto W = in_dims[3];
80+
platform::ForRange<DeviceContext> for_range(
81+
context.template device_context<DeviceContext>(),
82+
static_cast<size_t>(x->numel()));
83+
84+
auto *x_data = x->data<T>();
85+
auto *out_data = out->data<T>();
86+
paddle::operators::space_to_depth_compute<T> computer(
87+
x_data, W, H, C, B, blocksize, 1, out_data);
88+
for_range(computer);
89+
90+
out->Resize(out_dims);
91+
}
92+
};
93+
94+
template <typename DeviceContext, typename T>
95+
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
96+
public:
97+
void Compute(const framework::ExecutionContext &context) const override {
98+
auto *d_out =
99+
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
100+
auto *d_x =
101+
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
102+
auto blocksize = context.Attr<int64_t>("blocksize");
103+
auto in_dims = d_x->dims();
104+
d_x->mutable_data(context.GetPlace(), d_out->type());
105+
106+
auto B = in_dims[0];
107+
auto C = in_dims[1];
108+
auto H = in_dims[2];
109+
auto W = in_dims[3];
110+
111+
platform::ForRange<DeviceContext> for_range(
112+
context.template device_context<DeviceContext>(),
113+
static_cast<size_t>(d_x->numel()));
114+
115+
auto *dx_data = d_x->data<T>();
116+
auto *dout_data = d_out->data<T>();
117+
118+
paddle::operators::space_to_depth_compute<T> computer(
119+
dout_data, W, H, C, B, blocksize, 0, dx_data);
120+
for_range(computer);
121+
122+
d_x->Resize(in_dims);
123+
}
124+
};
125+
126+
} // namespace operators
127+
} // namespace paddle

python/paddle/fluid/layers/nn.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
'mul',
155155
'sigmoid_cross_entropy_with_logits',
156156
'maxout',
157+
'space_to_depth',
157158
'affine_grid',
158159
'sequence_reverse',
159160
'affine_channel',
@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None):
76747675
return out
76757676

76767677

7678+
def space_to_depth(x, blocksize, name=None):
7679+
"""
7680+
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
7681+
7682+
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
7683+
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
7684+
The attr blocksize indicates the input block size.
7685+
7686+
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
7687+
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
7688+
7689+
space_to_depth is used to This operation is useful for resizing the activations between convolutions
7690+
(but keeping all data)
7691+
7692+
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
7693+
- The depth of the output tensor is block_size * block_size * input channel
7694+
- The Y, X coordinates within each block of the input become the high order component of the output channel index
7695+
- channel should be divisible by square of blocksize
7696+
- height, width should be divsible by blocksize
7697+
7698+
7699+
Args:
7700+
x(variable): The input LoDtensor.
7701+
blocksize(variable): The blocksize to select the element on each feature map should be > 2
7702+
7703+
Returns:
7704+
Variable: The output LoDtensor.
7705+
7706+
Raises:
7707+
TypeError: blocksize type must be a long.
7708+
7709+
Examples:
7710+
.. code-block:: python
7711+
7712+
data = fluid.layers.data(
7713+
name='data', shape=[1, 4, 2, 2], dtype='float32')
7714+
space_to_depthed = fluid.layers.space_to_depth(
7715+
x=data, blocksize=2)
7716+
"""
7717+
7718+
helper = LayerHelper("space_to_depth", **locals())
7719+
7720+
if not (isinstance(blocksize, int)):
7721+
raise ValueError("blocksize must be a python Int")
7722+
7723+
if name is None:
7724+
out = helper.create_variable_for_type_inference(
7725+
dtype=x.dtype) #fix create
7726+
else:
7727+
out = helper.create_variable(
7728+
name=name, dtype=x.dtype, persistable=False)
7729+
7730+
helper.append_op(
7731+
type="space_to_depth",
7732+
inputs={"X": x},
7733+
attrs={"blocksize": blocksize},
7734+
outputs={"Out": out})
7735+
return out
7736+
7737+
76777738
@templatedoc()
76787739
def sequence_reverse(x, name=None):
76797740
"""

python/paddle/fluid/op.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ def __call__(self, *args, **kwargs):
108108
new_attr.i = user_defined_attr
109109
elif attr.type == framework_pb2.FLOAT:
110110
new_attr.f = user_defined_attr
111+
elif attr.type == framework_pb2.LONG:
112+
new_attr.l = user_defined_attr
111113
elif attr.type == framework_pb2.STRING:
112114
new_attr.s = user_defined_attr
113115
elif attr.type == framework_pb2.BOOLEAN:

0 commit comments

Comments
 (0)