
Commit 9709e4f

[ARM]fix group_norm compute error when compared with paddle (#5683) (#5701)
* fix group_norm compute error when compared with paddle. test=develop
1 parent 82849f2 commit 9709e4f

4 files changed: 118 additions, 63 deletions

lite/kernels/arm/group_norm_compute.cc

Lines changed: 16 additions & 6 deletions
@@ -35,15 +35,23 @@ void GroupNormCompute::Run() {
   float epsilon = param.epsilon;
   int groups = param.groups;
   int channels = param.channels;
-  int n = param.x->dims()[0];
-  int c = param.x->dims()[1];
+  auto x_dims = param.x->dims();
+  int n = x_dims[0];
+  int c = x_dims[1];
+  if (channels == -1) {
+    CHECK_EQ(param.data_layout_str, "NCHW")
+        << "it only support NCHW layout!, but recived layout is "
+        << param.data_layout_str;
+    channels = c;
+  }
+  int height = x_dims[2];
+  int width = x_dims[3];
   int ch_per_group = channels / groups;
-  int height = param.x->dims()[2];
-  int width = param.x->dims()[3];
   int spatial_size = ch_per_group * height * width;
   int ngroup = n * groups;
   int cnt = spatial_size >> 4;
   int remain = spatial_size % 16;
+  float* std_vec = new float[param.saved_variance->numel()];
   // compute saved_mean and saved_variance
 #pragma omp parallel for
   for (int n = 0; n < ngroup; ++n) {
@@ -103,7 +111,8 @@ void GroupNormCompute::Run() {
     float variance = (summ - mean * mean * spatial_size) / spatial_size;
     float std = 1.f / sqrtf(variance + epsilon);
     saved_mean[n] = mean;
-    saved_variance[n] = std;
+    saved_variance[n] = variance;
+    std_vec[n] = std;
   }
   int in_size = height * width;
   cnt = in_size >> 4;
@@ -117,7 +126,7 @@ void GroupNormCompute::Run() {
     numc *= ch_per_group;
     for (int c = 0; c < ch_per_group; c++) {
       int chin = numc + c;
-      const float sstd_val = scale[chin] * saved_variance[i];
+      const float sstd_val = scale[chin] * std_vec[i];
       const float bias_val = bias[chin];
       const float mean_val = saved_mean[i];
       const float32x4_t vsstd = vdupq_n_f32(sstd_val);
@@ -158,6 +167,7 @@ void GroupNormCompute::Run() {
       }
     }
   }
+  delete[] std_vec;
 }

 }  // namespace arm
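The substance of the kernel change, in standard group-norm terms, with m = ch_per_group * height * width elements behind each (batch, group) statistic:

  \mu_g = \frac{1}{m} \sum_{i \in g} x_i, \qquad
  \sigma_g^2 = \frac{1}{m} \sum_{i \in g} x_i^2 - \mu_g^2, \qquad
  y_i = \gamma_c \cdot \frac{x_i - \mu_g}{\sqrt{\sigma_g^2 + \epsilon}} + \beta_c

Before the fix the kernel wrote the inverse standard deviation 1/sqrt(sigma^2 + eps) into saved_variance; after the fix it stores the raw variance sigma^2 there, matching Paddle's Variance output, and keeps the inverse standard deviation in the scratch buffer std_vec for the normalization pass. That buffer is heap-allocated and released with delete[] at the end of Run().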

lite/operators/group_norm_op.cc

Lines changed: 19 additions & 8 deletions
@@ -34,27 +34,35 @@ bool GroupNormOp::CheckShape() const {
   auto scale_dims = param_.scale->dims();
   auto bias_dims = param_.bias->dims();
   if (param_.channels == -1) {
-    param_.channels = x_dims[1];
+    param_.channels = (param_.data_layout_str == "NCHW")
+                          ? x_dims[1]
+                          : x_dims[x_dims.size() - 1];
   }
+  // only support NCHW
+  CHECK_EQ(param_.data_layout_str, "NCHW") << "data_layout must be NCHW";
   CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
       << "Input X must have 2 to 5 dimensions.";
   CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimensions.";
   CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimensions.";
   CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
   CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
-  CHECK_EQ(param_.channels, x_dims[1])
-      << "Input channels must be equal input_shape[1]";
-  CHECK_EQ(param_.channels % param_.groups, 0)
-      << "channels must be divide groups";
+  CHECK_LE(param_.groups, param_.channels)
+      << "groups should be less than channels";
+  CHECK_GE(param_.groups, 1) << "groups should be greater than 1";
+  CHECK_EQ(param_.channels, scale_dims[0])
+      << "The Input(Scale)'s first dimension size of Op(group_norm) must be "
+         "equal to the number of channels";
+  CHECK_EQ(param_.channels, bias_dims[0])
+      << "The Input(Bias)'s first dimension size of Op(group_norm) must be "
+         "equal to the number of channels";
   return true;
 }

 bool GroupNormOp::InferShapeImpl() const {
   auto x_dims = param_.x->dims();
   int64_t batch_size = x_dims[0];
-  int64_t num = param_.channels / param_.groups;
-  param_.saved_mean->Resize({batch_size * num});
-  param_.saved_variance->Resize({batch_size * num});
+  param_.saved_mean->Resize({batch_size, param_.groups});
+  param_.saved_variance->Resize({batch_size, param_.groups});
   param_.out->Resize(x_dims);
   return true;
 }
@@ -82,6 +90,9 @@ bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
   }
   param_.out =
       scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
+  if (op_desc.HasAttr("data_layout")) {
+    param_.data_layout_str = op_desc.GetAttr<std::string>("data_layout");
+  }
   param_.epsilon = op_desc.GetAttr<float>("epsilon");
   param_.groups = op_desc.GetAttr<int>("groups");
   if (op_desc.HasAttr("channels")) {
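A minimal stand-alone sketch of the revised CheckShape contract, with plain integers in place of the Lite param_ struct (the helper below is hypothetical, not part of the patch): the NCHW-only guard and the per-channel Scale/Bias size checks are new, and the old channels % groups == 0 requirement is dropped.

#include <cassert>
#include <string>

// Hypothetical condensation of GroupNormOp::CheckShape after this patch;
// not the Lite API, just the same predicates over plain values.
bool CheckGroupNorm(const std::string& layout, int channels, int groups,
                    int scale_dim0, int bias_dim0, float epsilon) {
  if (layout != "NCHW") return false;        // data_layout must be NCHW
  if (groups < 1 || groups > channels) return false;
  if (scale_dim0 != channels) return false;  // Scale is one value per channel
  if (bias_dim0 != channels) return false;   // Bias is one value per channel
  return epsilon > 0.f && epsilon < 0.01f;
}

int main() {
  assert(CheckGroupNorm("NCHW", 32, 8, 32, 32, 1e-5f));
  assert(!CheckGroupNorm("NHWC", 32, 8, 32, 32, 1e-5f));
  // CheckShape no longer requires channels % groups == 0:
  assert(CheckGroupNorm("NCHW", 5, 2, 5, 5, 1e-5f));
  return 0;
}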

lite/operators/op_params.h

Lines changed: 1 addition & 0 deletions
@@ -1750,6 +1750,7 @@ struct GroupNormParam : ParamBase {
   lite::Tensor* scale{};
   lite::Tensor* saved_mean{};
   lite::Tensor* saved_variance{};
+  std::string data_layout_str{"NCHW"};
   float epsilon;
   int groups;
   int channels;
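A small sketch of why the new field carries a default initializer (a hypothetical mirror, not the Lite struct): op descs that lack a data_layout attribute keep the NCHW behavior, matching the HasAttr guard in AttachImpl above.

#include <cassert>
#include <string>

struct GroupNormParamSketch {
  std::string data_layout_str{"NCHW"};  // same default as the patch
};

int main() {
  GroupNormParamSketch p;               // no data_layout attr attached
  assert(p.data_layout_str == "NCHW");  // older models keep working
  return 0;
}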

lite/tests/kernels/group_norm_compute_test.cc

Lines changed: 82 additions & 49 deletions
@@ -34,20 +34,20 @@ class GroupNormComputeTest : public arena::TestCase {
   DDim dims_{{4, 5, 19, 19}};
   float epsilon_ = 1e-5f;
   int groups_ = 1;
-  int channels_ = dims_[1];
+  std::string data_layout_str_ = "NCHW";

  public:
   GroupNormComputeTest(const Place& place,
                        const std::string& alias,
                        DDim dims,
                        float epsilon,
                        int groups,
-                       int channels)
+                       std::string data_layout_str)
       : TestCase(place, alias),
         dims_(dims),
         epsilon_(epsilon),
         groups_(groups),
-        channels_(channels) {}
+        data_layout_str_(data_layout_str) {}

   void RunBaseline(Scope* scope) override {
     auto x = scope->FindTensor(x_);
@@ -59,7 +59,7 @@ class GroupNormComputeTest : public arena::TestCase {
     CHECK(y);
     CHECK(saved_mean);
     CHECK(saved_variance);
-    DDim saved_dim({dims_[0] * groups_});
+    DDim saved_dim({dims_[0], groups_});
     y->Resize(dims_);
     saved_mean->Resize(saved_dim);
     saved_variance->Resize(saved_dim);
@@ -68,49 +68,82 @@ class GroupNormComputeTest : public arena::TestCase {
     auto scale_data = scale->data<float>();
     auto bias_data = bias->data<float>();
     auto y_data = y->mutable_data<float>();
-    auto saved_mean_data = saved_mean->mutable_data<float>();
-    auto saved_variance_data = saved_variance->mutable_data<float>();
-
-    int n = x->dims()[0];
-    int ch_per_group = channels_ / groups_;
-    CHECK_EQ(x->dims()[1], channels_);
-    int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
-    // compute mean
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float sum = 0.f;
-      for (int j = 0; j < spatial_size; ++j) {
-        sum += x_ptr[j];
-      }
-      saved_mean_data[i] = sum / spatial_size;
-    }
-    // compute variance
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float sum = 0.f;
-      for (int j = 0; j < spatial_size; ++j) {
-        sum +=
-            (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
-      }
-      saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
-    }
-    int in_size = x->dims()[2] * x->dims()[3];
-    // compute out
-    for (int i = 0; i < n * groups_; ++i) {
-      const float* x_ptr = x_data + i * spatial_size;
-      float* y_ptr = y_data + i * spatial_size;
-      int c_num = i % groups_;
-      for (int c = 0; c < ch_per_group; c++) {
-        int chin = c_num * ch_per_group + c;
-        float scale_val = scale_data[chin];
-        float bias_val = bias_data[chin];
-        const float* x_ch_ptr = x_ptr + c * in_size;
-        float* y_ch_ptr = y_ptr + c * in_size;
-        for (int j = 0; j < in_size; j++) {
-          y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
-                            saved_variance_data[i] +
-                        bias_val;
+    auto mean_data = saved_mean->mutable_data<float>();
+    auto var_data = saved_variance->mutable_data<float>();
+
+    auto x_dims = x->dims();
+    int groups = groups_;
+    int channels =
+        (data_layout_str_ == "NCHW") ? x_dims[1] : x_dims[x_dims.size() - 1];
+    int group_size = (channels - 1) / groups + 1;
+    int imsize = (data_layout_str_ == "NCHW") ? (x_dims[2] * x_dims[3])
+                                              : (x_dims[1] * x_dims[2]);
+
+    auto* iter_x_data = x_data;
+    auto* iter_y_data = y_data;
+    for (int bid = 0; bid < x_dims[0]; bid++) {
+      for (int gid = 0; gid < groups; gid++) {
+        float x_mean = 0;
+        float x_var = 0;
+        int number =
+            std::min(group_size, static_cast<int>(channels - gid * group_size));
+        auto* tmp_x = iter_x_data;
+        auto* x_src_data = iter_x_data;
+        auto* tmp_y = iter_y_data;
+        auto* y_src_data = iter_y_data;
+
+        if (data_layout_str_ == "NCHW") {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
+              x_mean += iter_x_data[0];
+              x_var += iter_x_data[0] * iter_x_data[0];
+            }
+          }
+        } else {
+          for (int cid = 0; cid < number; cid++) {
+            iter_x_data = tmp_x + cid;
+            for (int imid = 0; imid < imsize; imid++, iter_x_data += channels) {
+              x_mean += iter_x_data[0];
+              x_var += iter_x_data[0] * iter_x_data[0];
+            }
+          }
+          iter_x_data = tmp_x + group_size;
         }
+
+        x_mean /= number * imsize;
+        x_var /= number * imsize;
+        x_var = x_var - x_mean * x_mean;
+        float var_inv = 1.0 / std::sqrt(x_var + epsilon_);
+        mean_data[bid * groups + gid] = x_mean;
+        var_data[bid * groups + gid] = x_var;
+
+        if (data_layout_str_ == "NCHW") {
+          for (int cid = 0; cid < number; cid++) {
+            for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
+              float val = (tmp_x[0] - x_mean) * var_inv;
+              if (scale_data) val *= scale_data[gid * group_size + cid];
+              if (bias_data) val += bias_data[gid * group_size + cid];
+              iter_y_data[0] = val;
+            }
+          }
+        } else {
+          for (int cid = 0; cid < number; cid++) {
+            tmp_x = x_src_data + cid;
+            iter_y_data = y_src_data + cid;
+            for (int imid = 0; imid < imsize;
+                 imid++, tmp_x += channels, iter_y_data += channels) {
+              float val = (tmp_x[0] - x_mean) * var_inv;
+              if (scale_data) val *= scale_data[gid * group_size + cid];
+              if (bias_data) val += bias_data[gid * group_size + cid];
+              iter_y_data[0] = val;
+            }
+          }
+          iter_y_data = tmp_y + group_size;
+        }
+      }
+      if (data_layout_str_ == "NCHW") {
+        iter_x_data = x_data + (bid + 1) * channels * imsize;
+        iter_y_data = y_data + (bid + 1) * channels * imsize;
       }
     }
   }
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
     op_desc->SetOutput("Variance", {saved_variance_});
     op_desc->SetAttr("epsilon", epsilon_);
     op_desc->SetAttr("groups", groups_);
-    op_desc->SetAttr("channels", channels_);
+    op_desc->SetAttr("data_layout", data_layout_str_);
   }

   void PrepareData() override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
                    float abs_error = 6e-5,
                    std::vector<std::string> ignored_outs = {}) {
   for (auto& n : {1, 3, 16}) {
-    for (auto& c : {1}) {
+    for (auto& c : {1, 2}) {
       for (auto& h : {1, 16, 33, 56}) {
         for (auto& w : {1, 17, 55}) {
           for (auto& groups : {1, 2, 4}) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
             DDim dim_in({n, c, h, w});
             float epsilon = 1e-5f;
             std::unique_ptr<arena::TestCase> tester(new GroupNormComputeTest(
-                place, "def", dim_in, epsilon, groups, c));
+                place, "def", dim_in, epsilon, groups, "NCHW"));
 #ifdef LITE_WITH_ARM
             if (place == TARGET(kARM)) {
               auto& ctx = tester->context()->As<ARMContext>();
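For intuition, a minimal scalar sketch of what the rewritten baseline computes for a single (batch, group) slice in NCHW layout; the toy sizes, gamma, and beta below are illustrative assumptions, not values from the test:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  // One group: ch_per_group channels of h*w pixels, stored contiguously.
  const int ch_per_group = 2, hw = 4;
  const float eps = 1e-5f, gamma = 1.f, beta = 0.f;
  std::vector<float> x = {1, 2, 3, 4, 5, 6, 7, 8};

  float sum = 0.f, sqsum = 0.f;
  for (float v : x) {
    sum += v;
    sqsum += v * v;
  }
  const int m = ch_per_group * hw;
  const float mean = sum / m;                        // -> Mean output
  const float var = sqsum / m - mean * mean;         // -> Variance output
  const float inv_std = 1.f / std::sqrt(var + eps);  // used only internally

  for (float v : x) std::printf("%g ", gamma * (v - mean) * inv_std + beta);
  std::printf("\nmean=%g variance=%g\n", mean, var);
  return 0;
}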
