@@ -30,7 +30,7 @@ void group_norm(
3030 int64_t sC ,
3131 int64_t sHxW ,
3232 int64_t group,
33- CTYPE eps,
33+ double eps,
3434 Tensor& out,
3535 Tensor& mean,
3636 Tensor& rstd) {
@@ -77,37 +77,43 @@ void group_norm(
7777 const CTYPE* x = input_data + i * inner_size;
7878
7979 // compute E[X] and Var[x] = E[x^2] - E[x]^2
80- CTYPE sum = reduce_add (x, inner_size);
81- CTYPE sq_sum = vec_powerf (x, inner_size);
82- CTYPE mean_value = sum / inner_size;
83- CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
84- CTYPE std = std::sqrt (variance + eps);
85- CTYPE rstd_value = 1.0 / std;
80+ CTYPE sum = reduce_add (x, static_cast <CTYPE>(inner_size));
81+ CTYPE sq_sum = vec_powerf (x, static_cast <CTYPE>(inner_size));
82+ double mean_value =
83+ static_cast <double >(sum) / static_cast <double >(inner_size);
84+ double variance =
85+ static_cast <double >(sq_sum) / static_cast <double >(inner_size) -
86+ mean_value * mean_value;
87+ double std = std::sqrt (variance + eps);
88+ double rstd_value = 1.0 / std;
8689
8790 // Calculate the elements of output
8891 if (weight_data == nullptr && bias_data == nullptr ) {
8992 CTYPE* y = out_data + i * inner_size;
9093 for (const auto j : c10::irange (inner_size)) {
91- y[j] = (x[j] - mean_value) * rstd_value;
94+ y[j] = static_cast <CTYPE>(
95+ (static_cast <double >(x[j]) - mean_value) * rstd_value);
9296 }
9397 } else {
9498 const size_t g = i % G;
9599 for (const auto j : c10::irange (D)) {
96100 const size_t ch = g * D + j;
97- const CTYPE scale =
98- rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
99- const CTYPE beta =
100- -scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
101+ const double scale = rstd_value *
102+ (weight_data == nullptr ? double (1.0 )
103+ : static_cast <double >(weight_data[ch]));
104+ const double beta = -scale * mean_value +
105+ (bias_data == nullptr ? double (0.0 )
106+ : static_cast <double >(bias_data[ch]));
101107 x = input_data + (i * D + j) * HxW;
102108 CTYPE* y = out_data + (i * D + j) * HxW;
103109 for (const auto k : c10::irange (HxW)) {
104- y[k] = scale * x[k] + beta;
110+ y[k] = static_cast <CTYPE>( scale * static_cast < double >( x[k]) + beta) ;
105111 }
106112 }
107113 }
108114
109- mean_data[i] = mean_value;
110- rstd_data[i] = rstd_value;
115+ mean_data[i] = static_cast <CTYPE>( mean_value) ;
116+ rstd_data[i] = static_cast <CTYPE>( rstd_value) ;
111117 }
112118}
113119
@@ -186,7 +192,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
186192
187193 constexpr auto name = " native_group_norm.out" ;
188194
189- ET_SWITCH_FLOAT_TYPES (input.scalar_type (), ctx, name, CTYPE, [&]() {
195+ ET_SWITCH_FLOATHBF16_TYPES (input.scalar_type (), ctx, name, CTYPE, [&]() {
190196 group_norm<CTYPE>(
191197 input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
192198 });
0 commit comments