@@ -34,20 +34,20 @@ class GroupNormComputeTest : public arena::TestCase {
// Default test configuration held by the test case. Each default below can be
// overridden through the constructor.
34
34
// Input tensor dimensions; four entries by default.
DDim dims_{{4 , 5 , 19 , 19 }};
35
35
// Added to the variance before the square root in the baseline.
float epsilon_ = 1e-5f ;
36
36
// Number of normalization groups.
int groups_ = 1 ;
37
// Removed by this change: the channel count is no longer stored as a member;
// the baseline derives it from the input dims and the layout string instead.
- int channels_ = dims_[ 1 ] ;
37
// Added by this change: layout string compared against the NCHW literal in
// the baseline and forwarded to the op as an attribute.
+ std::string data_layout_str_ = " NCHW " ;
38
38
39
39
public:
// Constructs a group-norm test case.
//   place / alias    - forwarded unchanged to the arena::TestCase base.
//   dims             - stored in dims_.
//   epsilon          - stored in epsilon_.
//   groups           - stored in groups_.
//   data_layout_str  - (new in this change, replaces the former `channels`
//                      argument) stored in data_layout_str_.
40
40
GroupNormComputeTest (const Place& place,
41
41
const std::string& alias,
42
42
DDim dims,
43
43
float epsilon,
44
44
int groups,
45
- int channels )
45
+ std::string data_layout_str )
46
46
: TestCase(place, alias),
47
47
dims_ (dims),
48
48
epsilon_(epsilon),
49
49
groups_(groups),
50
- channels_(channels ) {}
50
+ data_layout_str_(data_layout_str ) {}
51
51
52
52
void RunBaseline (Scope* scope) override {
53
53
auto x = scope->FindTensor (x_);
@@ -59,7 +59,7 @@ class GroupNormComputeTest : public arena::TestCase {
59
59
CHECK (y);
60
60
CHECK (saved_mean);
61
61
CHECK (saved_variance);
62
- DDim saved_dim ({dims_[0 ] * groups_});
62
+ DDim saved_dim ({dims_[0 ], groups_});
63
63
y->Resize (dims_);
64
64
saved_mean->Resize (saved_dim);
65
65
saved_variance->Resize (saved_dim);
@@ -68,49 +68,82 @@ class GroupNormComputeTest : public arena::TestCase {
68
68
auto scale_data = scale->data <float >();
69
69
auto bias_data = bias->data <float >();
70
70
auto y_data = y->mutable_data <float >();
71
- auto saved_mean_data = saved_mean->mutable_data <float >();
72
- auto saved_variance_data = saved_variance->mutable_data <float >();
73
-
74
- int n = x->dims ()[0 ];
75
- int ch_per_group = channels_ / groups_;
76
- CHECK_EQ (x->dims ()[1 ], channels_);
77
- int spatial_size = ch_per_group * x->dims ()[2 ] * x->dims ()[3 ];
78
- // compute mean
79
- for (int i = 0 ; i < n * groups_; ++i) {
80
- const float * x_ptr = x_data + i * spatial_size;
81
- float sum = 0 .f ;
82
- for (int j = 0 ; j < spatial_size; ++j) {
83
- sum += x_ptr[j];
84
- }
85
- saved_mean_data[i] = sum / spatial_size;
86
- }
87
- // compute variance
88
- for (int i = 0 ; i < n * groups_; ++i) {
89
- const float * x_ptr = x_data + i * spatial_size;
90
- float sum = 0 .f ;
91
- for (int j = 0 ; j < spatial_size; ++j) {
92
- sum +=
93
- (x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
94
- }
95
- saved_variance_data[i] = 1 .f / sqrtf (sum / spatial_size + epsilon_);
96
- }
97
- int in_size = x->dims ()[2 ] * x->dims ()[3 ];
98
- // compute out
99
- for (int i = 0 ; i < n * groups_; ++i) {
100
- const float * x_ptr = x_data + i * spatial_size;
101
- float * y_ptr = y_data + i * spatial_size;
102
- int c_num = i % groups_;
103
- for (int c = 0 ; c < ch_per_group; c++) {
104
- int chin = c_num * ch_per_group + c;
105
- float scale_val = scale_data[chin];
106
- float bias_val = bias_data[chin];
107
- const float * x_ch_ptr = x_ptr + c * in_size;
108
- float * y_ch_ptr = y_ptr + c * in_size;
109
- for (int j = 0 ; j < in_size; j++) {
110
- y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
111
- saved_variance_data[i] +
112
- bias_val;
// Reference group-norm computation (added side of the diff). For every
// (batch, group) pair it accumulates sum and sum-of-squares, derives mean and
// variance, stores both, and then writes the normalized (optionally scaled and
// shifted) values to the output buffer.
71
+ auto mean_data = saved_mean->mutable_data <float >();
72
+ auto var_data = saved_variance->mutable_data <float >();
73
+
74
+ auto x_dims = x->dims ();
75
+ int groups = groups_;
// Channel count: dim 1 when the layout string matches the NCHW literal,
// otherwise the last dim.
76
+ int channels =
77
+ (data_layout_str_ == " NCHW" ) ? x_dims[1 ] : x_dims[x_dims.size () - 1 ];
// Channels per group, rounded up (integer ceil of channels / groups).
78
+ int group_size = (channels - 1 ) / groups + 1 ;
// Elements per channel: dims 2*3 for the NCHW case, dims 1*2 otherwise
// (presumably H*W in both layouts - confirm).
79
+ int imsize = (data_layout_str_ == " NCHW" ) ? (x_dims[2 ] * x_dims[3 ])
80
+ : (x_dims[1 ] * x_dims[2 ]);
81
+
// Running cursors into the input and output buffers.
82
+ auto * iter_x_data = x_data;
83
+ auto * iter_y_data = y_data;
84
+ for (int bid = 0 ; bid < x_dims[0 ]; bid++) {
85
+ for (int gid = 0 ; gid < groups; gid++) {
86
+ float x_mean = 0 ;
87
+ float x_var = 0 ;
// Channels actually present in this group; the last group may hold fewer
// than group_size.
// NOTE(review): when gid * group_size >= channels this is <= 0, and the
// divisions by number * imsize below then divide by zero - confirm that
// callers never use such a groups/channels combination.
88
+ int number =
89
+ std::min (group_size, static_cast <int >(channels - gid * group_size));
// Remember where this group starts so the second pass can re-read the input
// and so the strided branch can rebase its cursors.
90
+ auto * tmp_x = iter_x_data;
91
+ auto * x_src_data = iter_x_data;
92
+ auto * tmp_y = iter_y_data;
93
+ auto * y_src_data = iter_y_data;
94
+
// Pass 1: accumulate sum (x_mean) and sum of squares (x_var) over the group.
95
+ if (data_layout_str_ == " NCHW" ) {
// Contiguous walk: the cursor advances by number * imsize elements.
96
+ for (int cid = 0 ; cid < number; cid++) {
97
+ for (int imid = 0 ; imid < imsize; imid++, iter_x_data++) {
98
+ x_mean += iter_x_data[0 ];
99
+ x_var += iter_x_data[0 ] * iter_x_data[0 ];
100
+ }
101
+ }
102
+ } else {
// Strided walk: start at channel offset cid and step by `channels` per
// element.
103
+ for (int cid = 0 ; cid < number; cid++) {
104
+ iter_x_data = tmp_x + cid;
105
+ for (int imid = 0 ; imid < imsize; imid++, iter_x_data += channels) {
106
+ x_mean += iter_x_data[0 ];
107
+ x_var += iter_x_data[0 ] * iter_x_data[0 ];
108
+ }
109
+ }
// Move the cursor to the first channel of the next group.
110
+ iter_x_data = tmp_x + group_size;
113
111
}
112
+
// Turn the sums into mean and variance (E[x^2] - mean^2), then compute the
// inverse standard deviation used for normalization.
113
+ x_mean /= number * imsize;
114
+ x_var /= number * imsize;
115
+ x_var = x_var - x_mean * x_mean;
116
+ float var_inv = 1.0 / std::sqrt (x_var + epsilon_);
// The stored statistics are the mean and the variance itself (not var_inv),
// indexed by batch * groups + group.
117
+ mean_data[bid * groups + gid] = x_mean;
118
+ var_data[bid * groups + gid] = x_var;
119
+
// Pass 2: write normalized values; scale and bias are applied per channel
// (index gid * group_size + cid) only when their pointers are non-null.
120
+ if (data_layout_str_ == " NCHW" ) {
121
+ for (int cid = 0 ; cid < number; cid++) {
122
+ for (int imid = 0 ; imid < imsize; imid++, tmp_x++, iter_y_data++) {
123
+ float val = (tmp_x[0 ] - x_mean) * var_inv;
124
+ if (scale_data) val *= scale_data[gid * group_size + cid];
125
+ if (bias_data) val += bias_data[gid * group_size + cid];
126
+ iter_y_data[0 ] = val;
127
+ }
128
+ }
129
+ } else {
// Strided variant of pass 2, mirroring the strided read in pass 1.
130
+ for (int cid = 0 ; cid < number; cid++) {
131
+ tmp_x = x_src_data + cid;
132
+ iter_y_data = y_src_data + cid;
133
+ for (int imid = 0 ; imid < imsize;
134
+ imid++, tmp_x += channels, iter_y_data += channels) {
135
+ float val = (tmp_x[0 ] - x_mean) * var_inv;
136
+ if (scale_data) val *= scale_data[gid * group_size + cid];
137
+ if (bias_data) val += bias_data[gid * group_size + cid];
138
+ iter_y_data[0 ] = val;
139
+ }
140
+ }
// Move the output cursor to the first channel of the next group.
141
+ iter_y_data = tmp_y + group_size;
142
+ }
143
+ }
// Rebase both cursors to the start of the next batch item.
// NOTE(review): this reset runs only when the layout matches the NCHW
// literal. In the strided (else) branches the cursors advance by just
// group_size per group, so after the last group they do not sit at
// (bid + 1) * channels * imsize. The reset looks like it is needed for that
// layout instead - confirm the intended condition.
144
+ if (data_layout_str_ == " NCHW" ) {
145
+ iter_x_data = x_data + (bid + 1 ) * channels * imsize;
146
+ iter_y_data = y_data + (bid + 1 ) * channels * imsize;
114
147
}
115
148
}
116
149
}
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
125
158
op_desc->SetOutput (" Variance" , {saved_variance_});
126
159
op_desc->SetAttr (" epsilon" , epsilon_);
127
160
op_desc->SetAttr (" groups" , groups_);
128
- op_desc->SetAttr (" channels " , channels_ );
161
+ op_desc->SetAttr (" data_layout " , data_layout_str_ );
129
162
}
130
163
131
164
void PrepareData () override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
148
181
float abs_error = 6e-5 ,
149
182
std::vector<std::string> ignored_outs = {}) {
150
183
for (auto & n : {1 , 3 , 16 }) {
151
- for (auto & c : {1 }) {
184
+ for (auto & c : {1 , 2 }) {
152
185
for (auto & h : {1 , 16 , 33 , 56 }) {
153
186
for (auto & w : {1 , 17 , 55 }) {
154
187
for (auto & groups : {1 , 2 , 4 }) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
158
191
DDim dim_in ({n, c, h, w});
159
192
float epsilon = 1e-5f ;
160
193
std::unique_ptr<arena::TestCase> tester (new GroupNormComputeTest (
161
- place, " def" , dim_in, epsilon, groups, c ));
194
+ place, " def" , dim_in, epsilon, groups, " NCHW " ));
162
195
#ifdef LITE_WITH_ARM
163
196
if (place == TARGET (kARM )) {
164
197
auto & ctx = tester->context ()->As <ARMContext>();
0 commit comments