Skip to content

Commit 0f18fbf

Browse files
authored
cherry pick 15782,16763 (#16805)
* cherry-pick 16763,test=release/1.4 * cherry-pick 16763, test=release/1.4 * fix some comments, include cosine_decay,l2_normalize,pixel_shuffle * Add api.spec, test=develop * update api.spec, test=develop * add api.spec,test=develop * test=develop * test=develop * fix conflict,test=develop * Add Pixel shuffle OP (#15782) * add pixel_shuffle op * add pixel_shuffle op, test=develop * rewrite code, test=develop * delete useless comment, test=develop * Refine pixel_shuffle_op and unit testing * refine code,test=develop * refine .cu,test=develop * fix unittest,test=develop * Fix unit testing test=develop * resolve conflict, test=develop * fix test, test=develop * fix API, test=develop * fix test datatype bug,test=develop * polish comments,test=develop * add API,test=develop * test=develop * Add Pixel_Shuffle OP,test=develop * support python3,test=develop * add include memory to travis CI bug,test=develop * cherry-pick 16763,15782 , test=release/1.4
1 parent 60d4785 commit 0f18fbf

File tree

8 files changed

+376
-12
lines changed

8 files changed

+376
-12
lines changed

paddle/fluid/API.spec

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ paddle.fluid.layers.dropout (ArgSpec(args=['x', 'dropout_prob', 'is_test', 'seed
125125
paddle.fluid.layers.split (ArgSpec(args=['input', 'num_or_sections', 'dim', 'name'], varargs=None, keywords=None, defaults=(-1, None)), ('document', '652625345c2acb900029c78cc75f8aa6'))
126126
paddle.fluid.layers.ctc_greedy_decoder (ArgSpec(args=['input', 'blank', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'ebbf2adbd79683dc93db03454dfa18c2'))
127127
paddle.fluid.layers.edit_distance (ArgSpec(args=['input', 'label', 'normalized', 'ignored_tokens'], varargs=None, keywords=None, defaults=(True, None)), ('document', '97f0262f97602644c83142789d784571'))
128-
paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', '6e428384ce6a77207fa2c70d9f011990'))
128+
paddle.fluid.layers.l2_normalize (ArgSpec(args=['x', 'axis', 'epsilon', 'name'], varargs=None, keywords=None, defaults=(1e-12, None)), ('document', '35c6a241bcc1a1fc89508860d82ad62b'))
129129
paddle.fluid.layers.matmul (ArgSpec(args=['x', 'y', 'transpose_x', 'transpose_y', 'alpha', 'name'], varargs=None, keywords=None, defaults=(False, False, 1.0, None)), ('document', 'b4cbe1ac451005df6dad12e9ffdccca9'))
130130
paddle.fluid.layers.topk (ArgSpec(args=['input', 'k', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'd3570c02f71bcd78e60b3f31dc8f5b32'))
131131
paddle.fluid.layers.warpctc (ArgSpec(args=['input', 'label', 'blank', 'norm_by_times', 'use_cudnn'], varargs=None, keywords=None, defaults=(0, False, False)), ('document', 'aaba49c038ba927f0a8e45c0c9a686ab'))
@@ -236,6 +236,7 @@ paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], vararg
236236
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec'))
237237
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
238238
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
239+
paddle.fluid.layers.pixel_shuffle (ArgSpec(args=['x', 'upscale_factor'], varargs=None, keywords=None, defaults=None), ('document', '731b21c62a4add60a33bd76d802ffc5c'))
239240
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
240241
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '33bbd42027d872b3818b3d64ec52e139'))
241242
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'b1ae2e1cc0750e58726374061ea90ecc'))
@@ -361,7 +362,7 @@ paddle.fluid.layers.polynomial_decay (ArgSpec(args=['learning_rate', 'decay_step
361362
paddle.fluid.layers.piecewise_decay (ArgSpec(args=['boundaries', 'values'], varargs=None, keywords=None, defaults=None), ('document', 'c717d9d1d78a53c809d01b8bc56f3cae'))
362363
paddle.fluid.layers.noam_decay (ArgSpec(args=['d_model', 'warmup_steps'], varargs=None, keywords=None, defaults=None), ('document', 'd9a95746353fd574be36dc28d8726c28'))
363364
paddle.fluid.layers.append_LARS (ArgSpec(args=['params_grads', 'learning_rate', 'weight_decay'], varargs=None, keywords=None, defaults=None), ('document', 'd24fa1e7d62ac8a534fc6a86002f84f8'))
364-
paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', '9588c64c26ffaef3c466e404a6af9d9b'))
365+
paddle.fluid.layers.cosine_decay (ArgSpec(args=['learning_rate', 'step_each_epoch', 'epochs'], varargs=None, keywords=None, defaults=None), ('document', 'f8b2727bccf0f368c997d7cf05847e49'))
365366
paddle.fluid.layers.linear_lr_warmup (ArgSpec(args=['learning_rate', 'warmup_steps', 'start_lr', 'end_lr'], varargs=None, keywords=None, defaults=None), ('document', '2ef3f5ca5cd71ea4217c418e5a7a0565'))
366367
paddle.fluid.contrib.InitState.__init__ (ArgSpec(args=['self', 'init', 'shape', 'value', 'init_boot', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, None, False, 'float32')), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
367368
paddle.fluid.contrib.StateCell.__init__ (ArgSpec(args=['self', 'inputs', 'states', 'out_state', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#include "paddle/fluid/operators/pixel_shuffle_op.h"
13+
#include <memory>
14+
15+
namespace paddle {
16+
namespace operators {
17+
18+
class PixelShuffleOp : public framework::OperatorWithKernel {
19+
public:
20+
using framework::OperatorWithKernel::OperatorWithKernel;
21+
22+
void InferShape(framework::InferShapeContext* ctx) const override {
23+
PADDLE_ENFORCE(ctx->HasInput("X"),
24+
"Input(X) of PixelShuffleOp should not be null.");
25+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
26+
"Output(Out) of PixelShuffleOp should not be null.");
27+
28+
auto input_dims = ctx->GetInputDim("X");
29+
PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
30+
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
31+
32+
PADDLE_ENFORCE(input_dims[1] % (upscale_factor * upscale_factor) == 0,
33+
"Upscale_factor should devide the number of channel");
34+
35+
auto output_dims = input_dims;
36+
output_dims[0] = input_dims[0];
37+
output_dims[1] = input_dims[1] / (upscale_factor * upscale_factor);
38+
output_dims[2] = input_dims[2] * upscale_factor;
39+
output_dims[3] = input_dims[3] * upscale_factor;
40+
ctx->SetOutputDim("Out", output_dims);
41+
}
42+
};
43+
44+
// Declares the inputs, outputs, attributes and user-facing documentation of
// the pixel_shuffle operator.
class PixelShuffleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor, default Tensor<float>), "
        "the input feature data of PixelShuffleOp, the layout is [N C H W].");
    AddOutput(
        "Out",
        "(Tensor, default Tensor<float>), the output of "
        "PixelShuffleOp. The layout is [N,C/factor^2,H*factor,W*factor].");
    // The checker rejects any factor below 1; the default of 1 keeps the
    // input shape unchanged.
    AddAttr<int>("upscale_factor",
                 "the factor to increase spatial resolution by.")
        .SetDefault(1)
        .AddCustomChecker([](const int& upscale_factor) {
          PADDLE_ENFORCE_GE(upscale_factor, 1,
                            "upscale_factor should be larger than 0.");
        });

    AddComment(R"DOC(
Pixel Shuffle operator
This operator rearranges elements in a tensor of shape :math:`(*, C \times r^2, H, W)`
to a tensor of shape :math:`(*, C, H \times r, W \times r)`.

This is useful for implementing efficient sub-pixel convolution
with a stride of :math:`1/r`.

Please refer to the paper:
`Real-Time Single Image and Video Super-Resolution Using an Efficient
Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
by Shi et al. (2016) for more details.

)DOC");
  }
};
79+
80+
class PixelShuffleGradMaker : public framework::SingleGradOpDescMaker {
81+
public:
82+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
83+
84+
std::unique_ptr<framework::OpDesc> Apply() const override {
85+
auto* op = new framework::OpDesc();
86+
op->SetType("pixel_shuffle_grad");
87+
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
88+
op->SetAttrMap(Attrs());
89+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
90+
return std::unique_ptr<framework::OpDesc>(op);
91+
}
92+
};
93+
94+
class PixelShuffleGradOp : public framework::OperatorWithKernel {
95+
public:
96+
using framework::OperatorWithKernel::OperatorWithKernel;
97+
98+
void InferShape(framework::InferShapeContext* ctx) const override {
99+
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
100+
"Input(Out@Grad) should not be null");
101+
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
102+
"Output(X@Grad) should not be null");
103+
104+
auto do_dims = ctx->GetInputDim(framework::GradVarName("Out"));
105+
PADDLE_ENFORCE(do_dims.size() == 4, "The layout of input is NCHW.");
106+
107+
auto upscale_factor = ctx->Attrs().Get<int>("upscale_factor");
108+
109+
auto dx_dims = do_dims;
110+
dx_dims[0] = do_dims[0];
111+
dx_dims[1] = do_dims[1] * (upscale_factor * upscale_factor);
112+
dx_dims[2] = do_dims[2] / upscale_factor;
113+
dx_dims[3] = do_dims[3] / upscale_factor;
114+
ctx->SetOutputDim(framework::GradVarName("X"), dx_dims);
115+
}
116+
};
117+
118+
} // namespace operators
119+
} // namespace paddle
120+
121+
namespace ops = paddle::operators;
122+
REGISTER_OPERATOR(pixel_shuffle, ops::PixelShuffleOp, ops::PixelShuffleOpMaker,
123+
ops::PixelShuffleGradMaker);
124+
125+
REGISTER_OPERATOR(pixel_shuffle_grad, ops::PixelShuffleGradOp);
126+
127+
REGISTER_OP_CPU_KERNEL(
128+
pixel_shuffle,
129+
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, float>,
130+
ops::PixelShuffleOpKernel<paddle::platform::CPUDeviceContext, double>);
131+
132+
REGISTER_OP_CPU_KERNEL(
133+
pixel_shuffle_grad,
134+
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, float>,
135+
ops::PixelShuffleGradOpKernel<paddle::platform::CPUDeviceContext, double>);
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/pixel_shuffle_op.h"
16+
17+
namespace ops = paddle::operators;
18+
namespace plat = paddle::platform;
19+
20+
REGISTER_OP_CUDA_KERNEL(
21+
pixel_shuffle, ops::PixelShuffleOpKernel<plat::CUDADeviceContext, float>,
22+
ops::PixelShuffleOpKernel<plat::CUDADeviceContext, double>);
23+
REGISTER_OP_CUDA_KERNEL(
24+
pixel_shuffle_grad,
25+
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, float>,
26+
ops::PixelShuffleGradOpKernel<plat::CUDADeviceContext, double>);
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
Licensed under the Apache License, Version 2.0 (the "License");
3+
you may not use this file except in compliance with the License.
4+
You may obtain a copy of the License at
5+
http://www.apache.org/licenses/LICENSE-2.0
6+
Unless required by applicable law or agreed to in writing, software
7+
distributed under the License is distributed on an "AS IS" BASIS,
8+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
See the License for the specific language governing permissions and
10+
limitations under the License. */
11+
12+
#pragma once
13+
#include <algorithm>
14+
#include <vector>
15+
#include "paddle/fluid/framework/op_registry.h"
16+
#include "paddle/fluid/operators/math/math_function.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
21+
template <typename DeviceContext, typename T>
22+
class PixelShuffleOpKernel : public framework::OpKernel<T> {
23+
public:
24+
void Compute(const framework::ExecutionContext& ctx) const override {
25+
auto* in = ctx.Input<framework::Tensor>("X");
26+
auto* out = ctx.Output<framework::Tensor>("Out");
27+
out->mutable_data<T>(ctx.GetPlace());
28+
29+
int factor = ctx.Attr<int>("upscale_factor");
30+
31+
auto in_dims = in->dims();
32+
auto o_dims = out->dims();
33+
34+
framework::Tensor t;
35+
t.ShareDataWith(*in);
36+
t.Resize({in_dims[0], o_dims[1], factor, factor, in_dims[2], in_dims[3]});
37+
38+
std::vector<int> axis = {0, 1, 4, 2, 5, 3};
39+
40+
framework::Tensor o;
41+
o.ShareDataWith(*out);
42+
o.Resize({in_dims[0], o_dims[1], in_dims[2], factor, in_dims[3], factor});
43+
44+
math::Transpose<DeviceContext, T, 6> trans;
45+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
46+
trans(dev_ctx, t, &o, axis);
47+
out->Resize(o_dims);
48+
}
49+
};
50+
51+
template <typename DeviceContext, typename T>
52+
class PixelShuffleGradOpKernel : public framework::OpKernel<T> {
53+
public:
54+
void Compute(const framework::ExecutionContext& ctx) const override {
55+
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
56+
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
57+
dx->mutable_data<T>(ctx.GetPlace());
58+
59+
int factor = ctx.Attr<int>("upscale_factor");
60+
61+
auto do_dims = dout->dims();
62+
auto dx_dims = dx->dims();
63+
64+
framework::Tensor t;
65+
t.ShareDataWith(*dout);
66+
t.Resize({do_dims[0], do_dims[1], dx_dims[2], factor, dx_dims[3], factor});
67+
68+
std::vector<int> axis = {0, 1, 3, 5, 2, 4};
69+
70+
framework::Tensor o;
71+
o.ShareDataWith(*dx);
72+
o.Resize({do_dims[0], do_dims[1], factor, factor, dx_dims[2], dx_dims[3]});
73+
74+
math::Transpose<DeviceContext, T, 6> trans;
75+
auto& dev_ctx = ctx.template device_context<DeviceContext>();
76+
trans(dev_ctx, t, &o, axis);
77+
dx->Resize(dx_dims);
78+
}
79+
};
80+
81+
} // namespace operators
82+
} // namespace paddle

python/paddle/fluid/layers/learning_rate_scheduler.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -349,24 +349,26 @@ def cosine_decay(learning_rate, step_each_epoch, epochs):
349349
training progresses. By using this function, the learning rate will be decayed by
350350
following cosine decay strategy.
351351
352-
decayed_lr = learning_rate * 0.5 * (math.cos(epoch * math.pi / epochs) + 1)
352+
.. math::
353+
354+
decayed\_lr = learning\_rate * 0.5 * (\\cos(epoch * \\frac{\\pi}{epochs}) + 1)
353355
354356
Args:
355357
learning_rate(Variable|float): The initial learning rate.
356358
step_each_epoch(int): the number of steps in an epoch.
357359
epochs(int): the number of epochs.
358360
359-
Returns:
360-
Variable: The decayed learning rate.
361-
362-
Examples:
361+
Returns:
362+
Variable: The decayed learning rate.
363363
364-
..code-block:: python
364+
Examples:
365+
.. code-block:: python
365366
366-
base_lr = 0.1
367-
lr = fluid.layers.cosine_decay(
368-
learning_rate = base_lr, step_each_epoch=10000, epochs=120)
367+
base_lr = 0.1
368+
lr = fluid.layers.cosine_decay(
369+
learning_rate = base_lr, step_each_epoch=10000, epochs=120)
369370
"""
371+
370372
with default_main_program()._lr_schedule_guard():
371373
if imperative_base.enabled():
372374
decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,

python/paddle/fluid/layers/nn.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@
191191
'kldiv_loss',
192192
'tree_conv',
193193
'npair_loss',
194+
'pixel_shuffle',
194195
'fsp_matrix',
195196
]
196197

@@ -4792,7 +4793,7 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
47924793
the dimension to normalization is rank(X) + axis. -1 is the
47934794
last dimension.
47944795
epsilon(float): The epsilon value is used to avoid division by zero, \
4795-
the defalut value is 1e-10.
4796+
the default value is 1e-12.
47964797
name(str|None): A name for this layer(optional). If set None, the layer \
47974798
will be named automatically.
47984799
@@ -10923,6 +10924,65 @@ def npair_loss(anchor, positive, labels, l2_reg=0.002):
1092310924
return l2loss + celoss
1092410925

1092510926

10927+
def pixel_shuffle(x, upscale_factor):
    """
    **Pixel Shuffle Layer**

    This layer rearranges elements in a tensor of shape [N, C, H, W]
    to a tensor of shape [N, C/r**2, H*r, W*r], where r is upscale_factor.
    This is useful for implementing efficient sub-pixel convolution
    with a stride of 1/r.
    Please refer to the paper: `Real-Time Single Image and Video Super-Resolution
    Using an Efficient Sub-Pixel Convolutional Neural Network <https://arxiv.org/abs/1609.05158v2>`_
    by Shi et al. (2016) for more details.

    .. code-block:: text

        Given a 4-D tensor with the shape:
            x.shape = [1, 9, 4, 4]
        Given upscale_factor:
            upscale_factor = 3
        output shape is:
            [1, 1, 12, 12]

    Args:
        x(Variable): The input tensor variable, a 4-D tensor in
            [N, C, H, W] layout.
        upscale_factor(int): The factor to increase spatial resolution by.

    Returns:
        Variable: The rearranged tensor, with the same dtype as ``x``.

    Raises:
        TypeError: If upscale_factor is not an int.

    Note:
        The pixel_shuffle operator itself additionally enforces that the
        square of upscale_factor divides the number of input channels.

    Examples:
        .. code-block:: python

            x = fluid.layers.data(name="x", shape=[9, 4, 4], dtype="float32")
            output = fluid.layers.pixel_shuffle(x=x, upscale_factor=3)

    """
    # Validate before creating the helper or any variable, so an invalid
    # argument is rejected without first creating an output variable.
    if not isinstance(upscale_factor, int):
        raise TypeError("upscale factor must be int type")

    helper = LayerHelper("pixel_shuffle", **locals())
    out = helper.create_variable_for_type_inference(dtype=x.dtype)

    helper.append_op(
        type="pixel_shuffle",
        inputs={"X": x},
        outputs={"Out": out},
        attrs={"upscale_factor": upscale_factor})
    return out
10984+
10985+
1092610986
def fsp_matrix(x, y):
1092710987
"""
1092810988

0 commit comments

Comments
 (0)