@@ -30,7 +30,7 @@ void group_norm(
30
30
int64_t sC ,
31
31
int64_t sHxW ,
32
32
int64_t group,
33
- CTYPE eps,
33
+ double eps,
34
34
Tensor& out,
35
35
Tensor& mean,
36
36
Tensor& rstd) {
@@ -77,37 +77,43 @@ void group_norm(
77
77
const CTYPE* x = input_data + i * inner_size;
78
78
79
79
// compute E[X] and Var[X] = E[X^2] - E[X]^2
80
- CTYPE sum = reduce_add (x, inner_size);
81
- CTYPE sq_sum = vec_powerf (x, inner_size);
82
- CTYPE mean_value = sum / inner_size;
83
- CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
84
- CTYPE std = std::sqrt (variance + eps);
85
- CTYPE rstd_value = 1.0 / std;
80
+ CTYPE sum = reduce_add (x, static_cast <CTYPE>(inner_size));
81
+ CTYPE sq_sum = vec_powerf (x, static_cast <CTYPE>(inner_size));
82
+ double mean_value =
83
+ static_cast <double >(sum) / static_cast <double >(inner_size);
84
+ double variance =
85
+ static_cast <double >(sq_sum) / static_cast <double >(inner_size) -
86
+ mean_value * mean_value;
87
+ double std = std::sqrt (variance + eps);
88
+ double rstd_value = 1.0 / std;
86
89
87
90
// Calculate the elements of output
88
91
if (weight_data == nullptr && bias_data == nullptr ) {
89
92
CTYPE* y = out_data + i * inner_size;
90
93
for (const auto j : c10::irange (inner_size)) {
91
- y[j] = (x[j] - mean_value) * rstd_value;
94
+ y[j] = static_cast <CTYPE>(
95
+ (static_cast <double >(x[j]) - mean_value) * rstd_value);
92
96
}
93
97
} else {
94
98
const size_t g = i % G;
95
99
for (const auto j : c10::irange (D)) {
96
100
const size_t ch = g * D + j;
97
- const CTYPE scale =
98
- rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
99
- const CTYPE beta =
100
- -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
101
+ const double scale = rstd_value *
102
+ (weight_data == nullptr ? double (1.0 )
103
+ : static_cast <double >(weight_data[ch]));
104
+ const double beta = -scale * mean_value +
105
+ (bias_data == nullptr ? double (0.0 )
106
+ : static_cast <double >(bias_data[ch]));
101
107
x = input_data + (i * D + j) * HxW;
102
108
CTYPE* y = out_data + (i * D + j) * HxW;
103
109
for (const auto k : c10::irange (HxW)) {
104
- y[k] = scale * x[k] + beta;
110
+ y[k] = static_cast <CTYPE>( scale * static_cast < double >( x[k]) + beta) ;
105
111
}
106
112
}
107
113
}
108
114
109
- mean_data[i] = mean_value;
110
- rstd_data[i] = rstd_value;
115
+ mean_data[i] = static_cast <CTYPE>( mean_value) ;
116
+ rstd_data[i] = static_cast <CTYPE>( rstd_value) ;
111
117
}
112
118
}
113
119
@@ -186,7 +192,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
186
192
187
193
constexpr auto name = " native_group_norm.out" ;
188
194
189
- ET_SWITCH_FLOAT_TYPES (input.scalar_type (), ctx, name, CTYPE, [&]() {
195
+ ET_SWITCH_FLOATHBF16_TYPES (input.scalar_type (), ctx, name, CTYPE, [&]() {
190
196
group_norm<CTYPE>(
191
197
input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
192
198
});
0 commit comments