Commit 07e2365

fmo-mt authored and pytorchmergebot committed
Fix RMSNorm epsilon value type for BF16 or FP16 (pytorch#142848)
Fixes pytorch#140092

Here's what this PR does: previously the code declared a `scalar_t eps_val;`, while `eps` is usually a double scalar passed from the Python frontend (e.g. 1e-6). When assigning either `eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();` or `eps_val = eps.value();`, the epsilon was down-cast to match the input tensor dtype (`scalar_t`); in the BFloat16 case, the double 1e-6 turns into `1.00136e-05`. But in `auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));` the value is up-cast again to match `opmath_t`, so the round trip between the two dtypes is unnecessary and lossy. The fix is to hold the epsilon at `opmath_t` precision instead of `scalar_t`.

Pull Request resolved: pytorch#142848
Approved by: https://github.com/mikaylagawarecki
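The down-cast described above can be reproduced without PyTorch. Below is a minimal sketch (not part of the PR) using a hypothetical `to_bf16` helper that emulates a bfloat16 round-trip of a double, showing that an eps such as 1e-6 does not survive the `scalar_t` down-cast intact:

```python
import struct

def to_bf16(x: float) -> float:
    """Hypothetical helper: round a Python float to bfloat16 precision
    (round-to-nearest-even) by keeping the top 16 bits of its float32
    bit pattern."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

eps = 1e-6
eps_cast = to_bf16(eps)  # what the old code effectively stored in eps_val
print(eps, eps_cast)     # the two values differ: bf16 has only 8 mantissa bits
```

Because bfloat16 keeps only 8 bits of significand, any eps that is not exactly representable is perturbed before the fp32/fp64 math ever sees it.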
1 parent a8ef423 commit 07e2365

File tree

1 file changed (+17, −8 lines)


aten/src/ATen/native/layer_norm.cpp

Lines changed: 17 additions & 8 deletions
```diff
@@ -278,18 +278,27 @@ Tensor rms_norm_symint(
       input.scalar_type(),
       "rms_norm",
       [&] {
-        scalar_t eps_val;
-        if (!eps.has_value()) {
-          eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();
-        } else {
-          eps_val = eps.value();
-        }
-
         // upcast is needed for fp16 and bf16
         c10::ScalarType opmath_t = toOpMathType(input.scalar_type());
         Tensor upcasted_input = input.to(opmath_t);
 
-        auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
+        Tensor rqrst_input;
+
+        // opmath_t would be one of [Double, Float, ComplexFloat, ComplexDouble]
+        if (opmath_t == at::ScalarType::Float || opmath_t == at::ScalarType::ComplexFloat) {
+          float eps_val = std::numeric_limits<float>::epsilon();
+          if (eps.has_value()) {
+            eps_val = eps.value();
+          }
+          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
+        } else {
+          double eps_val = std::numeric_limits<double>::epsilon();
+          if (eps.has_value()) {
+            eps_val = eps.value();
+          }
+          rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));
+        }
+
         Tensor result = upcasted_input.mul(rqrst_input).type_as(input);
 
         if (weight_opt.has_value()) {
```
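The numerical effect of the change can be sketched outside PyTorch as well (again with a hypothetical `to_bf16` helper, not the library's code): adding a bf16-rounded eps versus the original double eps yields slightly different rsqrt normalizers for the same activations:

```python
import math
import struct

def to_bf16(x: float) -> float:
    # hypothetical helper: emulate a bfloat16 round-trip (round-to-nearest-even)
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) & 0xFFFF0000
    return struct.unpack(">f", struct.pack(">I", bits))[0]

xs = [1e-4, -2e-4, 3e-4]                # small activations, as if upcast from bf16
mean_sq = sum(v * v for v in xs) / len(xs)

old_rsqrt = 1.0 / math.sqrt(mean_sq + to_bf16(1e-6))  # eps down-cast first (old path)
new_rsqrt = 1.0 / math.sqrt(mean_sq + 1e-6)           # eps kept at high precision (fixed path)
print(old_rsqrt, new_rsqrt)
```

The discrepancy is small but systematic, and it matters most exactly where RMSNorm relies on eps: when the mean of squares is tiny relative to eps.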

0 commit comments
