Skip to content

Commit ae07cb6

Browse files
Support Half/BFloat16 in native_group_norm
Differential Revision: D81351792. Pull Request resolved: #13823.
1 parent ab44d06 commit ae07cb6

File tree

2 files changed: +335 additions, -120 deletions

kernels/portable/cpu/op_native_group_norm.cpp

Lines changed: 22 additions & 16 deletions
Original file line number · Diff line number · Diff line change
@@ -30,7 +30,7 @@ void group_norm(
3030
int64_t sC,
3131
int64_t sHxW,
3232
int64_t group,
33-
CTYPE eps,
33+
double eps,
3434
Tensor& out,
3535
Tensor& mean,
3636
Tensor& rstd) {
@@ -77,37 +77,43 @@ void group_norm(
7777
const CTYPE* x = input_data + i * inner_size;
7878

7979
// compute E[X] and Var[x] = E[x^2] - E[x]^2
80-
CTYPE sum = reduce_add(x, inner_size);
81-
CTYPE sq_sum = vec_powerf(x, inner_size);
82-
CTYPE mean_value = sum / inner_size;
83-
CTYPE variance = sq_sum / inner_size - mean_value * mean_value;
84-
CTYPE std = std::sqrt(variance + eps);
85-
CTYPE rstd_value = 1.0 / std;
80+
CTYPE sum = reduce_add(x, static_cast<CTYPE>(inner_size));
81+
CTYPE sq_sum = vec_powerf(x, static_cast<CTYPE>(inner_size));
82+
double mean_value =
83+
static_cast<double>(sum) / static_cast<double>(inner_size);
84+
double variance =
85+
static_cast<double>(sq_sum) / static_cast<double>(inner_size) -
86+
mean_value * mean_value;
87+
double std = std::sqrt(variance + eps);
88+
double rstd_value = 1.0 / std;
8689

8790
// Calculate the elements of output
8891
if (weight_data == nullptr && bias_data == nullptr) {
8992
CTYPE* y = out_data + i * inner_size;
9093
for (const auto j : c10::irange(inner_size)) {
91-
y[j] = (x[j] - mean_value) * rstd_value;
94+
y[j] = static_cast<CTYPE>(
95+
(static_cast<double>(x[j]) - mean_value) * rstd_value);
9296
}
9397
} else {
9498
const size_t g = i % G;
9599
for (const auto j : c10::irange(D)) {
96100
const size_t ch = g * D + j;
97-
const CTYPE scale =
98-
rstd_value * (weight_data == nullptr ? 1.0 : weight_data[ch]);
99-
const CTYPE beta =
100-
-scale * mean_value + (bias_data == nullptr ? 0.0 : bias_data[ch]);
101+
const double scale = rstd_value *
102+
(weight_data == nullptr ? double(1.0)
103+
: static_cast<double>(weight_data[ch]));
104+
const double beta = -scale * mean_value +
105+
(bias_data == nullptr ? double(0.0)
106+
: static_cast<double>(bias_data[ch]));
101107
x = input_data + (i * D + j) * HxW;
102108
CTYPE* y = out_data + (i * D + j) * HxW;
103109
for (const auto k : c10::irange(HxW)) {
104-
y[k] = scale * x[k] + beta;
110+
y[k] = static_cast<CTYPE>(scale * static_cast<double>(x[k]) + beta);
105111
}
106112
}
107113
}
108114

109-
mean_data[i] = mean_value;
110-
rstd_data[i] = rstd_value;
115+
mean_data[i] = static_cast<CTYPE>(mean_value);
116+
rstd_data[i] = static_cast<CTYPE>(rstd_value);
111117
}
112118
}
113119

@@ -186,7 +192,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> native_group_norm_out(
186192

187193
constexpr auto name = "native_group_norm.out";
188194

189-
ET_SWITCH_FLOAT_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
195+
ET_SWITCH_FLOATHBF16_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
190196
group_norm<CTYPE>(
191197
input, weight, bias, N, C, HxW, group, eps, out, mean_out, rstd_out);
192198
});

0 commit comments

Comments (0)