
Commit dfbac60

Merge remote-tracking branch 'upstream/develop' into windows/build

2 parents: 7c8c9dc + dd6fd4c
File tree

16 files changed: +1447 −24 lines

AUTHORS.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
 | kexinzhao | Ke-Xin Zhao |
 | kuke | Yi-Bing Liu |
 | lcy-seso | Ying Cao |
+| cjld | Dun Liang |
 | lipeng-unisound | Peng Li |
 | liuyuan | Yuan Liu |
 | livc | Zhao Li |

paddle/fluid/API.spec

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 's
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
+paddle.fluid.layers.group_norm ArgSpec(args=['input', 'groups', 'epsilon', 'param_attr', 'bias_attr', 'act', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None, 'NCHW', None))
 paddle.fluid.layers.softmax_with_cross_entropy ArgSpec(args=['logits', 'label', 'soft_label', 'ignore_index', 'numeric_stable_mode', 'return_softmax'], varargs=None, keywords=None, defaults=(False, -100, False, False))
 paddle.fluid.layers.smooth_l1 ArgSpec(args=['x', 'y', 'inside_weight', 'outside_weight', 'sigma'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.one_hot ArgSpec(args=['input', 'depth'], varargs=None, keywords=None, defaults=None)
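
For orientation, a minimal sketch of how the new Python layer declared in the spec above might be called. Argument names and defaults come straight from the ArgSpec; the input declaration via fluid.layers.data is an assumption about typical usage, and the constraint that 'groups' lies between 1 and the channel count mirrors the shape checks in the new operator below.

import paddle.fluid as fluid

# A 4-D NCHW feature map: the batch dimension is implicit in fluid.layers.data,
# so shape=[32, 8, 8] means 32 channels over an 8x8 spatial grid.
x = fluid.layers.data(name='x', shape=[32, 8, 8], dtype='float32')

# Normalize over 8 groups of 4 channels each; 'groups' must be at least 1
# and no larger than the channel count (see GroupNormOp::InferShape below).
y = fluid.layers.group_norm(input=x, groups=8, epsilon=1e-05,
                            param_attr=None, bias_attr=None,
                            act=None, data_layout='NCHW')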
paddle/fluid/operators/group_norm_op.cc

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/group_norm_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;

class GroupNormOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Y"),
                   "Output(Y) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
                   "Output(Mean) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
                   "Output(Variance) of GroupNormOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    auto channel_num = x_dim[1];
    auto batch_size = x_dim[0];
    auto groups = ctx->Attrs().Get<int>("groups");
    PADDLE_ENFORCE_LE(
        groups, channel_num,
        "'groups' must be less than or equal to the number of channels.");
    PADDLE_ENFORCE_GE(groups, 1,
                      "'groups' must be greater than or equal to 1.");

    if (ctx->HasInput("Scale")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], channel_num);
    }
    if (ctx->HasInput("Bias")) {
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], channel_num);
    }

    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    ctx->SetOutputDim("Mean", {batch_size, groups});
    ctx->SetOutputDim("Variance", {batch_size, groups});
    ctx->ShareLoD("X", "Y");
  }
};

class GroupNormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddInput("Scale",
             "Scale is a 1-dimensional tensor of size C "
             "that is applied to the output.")
        .AsDispensable();
    AddInput("Bias",
             "Bias is a 1-dimensional tensor of size C "
             "that is applied to the output.")
        .AsDispensable();
    AddOutput("Y", "Result after normalization.");
    AddOutput("Mean", "Mean of each group.").AsIntermediate();
    AddOutput("Variance", "Variance of each group.").AsIntermediate();

    AddAttr<float>("epsilon",
                   "Constant for numerical stability [default 1e-5].")
        .SetDefault(1e-5)
        .AddCustomChecker([](const float &epsilon) {
          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 1.0f,
                         "'epsilon' should be between 0.0 and 1.0.");
        });
    AddAttr<int>("groups",
                 "The number of groups that the channels are divided into.")
        .AddCustomChecker([](const int &groups) {
          PADDLE_ENFORCE_GT(groups, 0, "'groups' should be greater than zero.");
        });

    AddComment(R"DOC(
Group Normalization

Refer to `Group Normalization <https://arxiv.org/abs/1803.08494>`_
)DOC");
  }
};

class GroupNormGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    // check input
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Mean"),
                   "Input(Mean) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Variance"),
                   "Input(Variance) of GroupNormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
                   "Input(Y@GRAD) of GroupNormOp should not be null.");

    // check output
    if (ctx->HasOutput(framework::GradVarName("X"))) {
      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
    }
    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
      ctx->SetOutputDim(framework::GradVarName("Scale"),
                        ctx->GetInputDim("Scale"));
    }
    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
      ctx->SetOutputDim(framework::GradVarName("Bias"),
                        ctx->GetInputDim("Bias"));
    }
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
    if (var == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    const Tensor *t = nullptr;
    if (var->IsType<Tensor>()) {
      t = &var->Get<Tensor>();
    } else if (var->IsType<LoDTensor>()) {
      t = &var->Get<LoDTensor>();
    }
    if (t == nullptr) {
      PADDLE_THROW("can't find Y@GRAD");
    }
    return framework::OpKernelType(framework::ToDataType(t->type()),
                                   ctx.GetPlace());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(group_norm, ops::GroupNormOp, ops::GroupNormOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(group_norm_grad, ops::GroupNormGradOp);
REGISTER_OP_CPU_KERNEL(
    group_norm, ops::GroupNormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GroupNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    group_norm_grad,
    ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GroupNormGradKernel<paddle::platform::CPUDeviceContext, double>);
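
As a rough reference for what the operator above computes, here is a NumPy sketch of group normalization in NCHW layout. It is illustrative only, not the PaddlePaddle kernel, and it assumes the channel count is divisible by 'groups' for simplicity; note how the Mean and Variance outputs end up with shape (batch_size, groups), matching GroupNormOp::InferShape.

import numpy as np

def group_norm_ref(x, groups, scale=None, bias=None, epsilon=1e-5):
    """Reference group norm for x of shape (N, C, H, W); scale/bias are optional (C,) vectors."""
    n, c, h, w = x.shape
    assert 1 <= groups <= c and c % groups == 0
    # Fold channels into (groups, C // groups) and reduce over each group's
    # channels and spatial positions.
    xg = x.reshape(n, groups, c // groups, h, w)
    mean = xg.mean(axis=(2, 3, 4), keepdims=True)  # one mean per (sample, group)
    var = xg.var(axis=(2, 3, 4), keepdims=True)    # one variance per (sample, group)
    y = ((xg - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    if scale is not None:                          # per-channel affine, like Scale/Bias above
        y = y * scale.reshape(1, c, 1, 1)
    if bias is not None:
        y = y + bias.reshape(1, c, 1, 1)
    return y, mean.reshape(n, groups), var.reshape(n, groups)

y, mean, var = group_norm_ref(np.random.rand(2, 32, 8, 8).astype('float32'), groups=8)
print(y.shape, mean.shape, var.shape)  # (2, 32, 8, 8) (2, 8) (2, 8)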
