Fix RMSNorm epsilon value type for BF16 or FP16 (pytorch#142848)
Fixes pytorch#140092
Here's what this PR does:
Previously, we declared `scalar_t eps_val;`, while `eps` is usually a double scalar passed from the Python frontend, such as 1e-6.
When we assign `eps_val = std::numeric_limits<at::scalar_value_type<scalar_t>::type>::epsilon();` or `eps_val = eps.value();`, the epsilon is downcast to the input tensor dtype (`scalar_t`). For BFloat16, for example, a double 1e-5 is rounded to `1.00136e-05`.
However, when we compute `auto rqrst_input = rsqrt(at::pow(upcasted_input, 2).mean(dims_to_reduce_ref, /*keepdim=*/true).add_(eps_val));`, `eps_val` is upcast again to `opmath_t`. The round trip through `scalar_t` is unnecessary and only loses precision, so we declare `opmath_t eps_val` instead of `scalar_t eps_val`.
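To make the precision loss concrete, here is a minimal standalone sketch that models the two code paths without ATen. It assumes `opmath_t` is `float` and `scalar_t` is BFloat16, and it emulates the BFloat16 rounding by hand. The helper names (`round_to_bf16`, `eps_via_scalar_t`, `eps_via_opmath_t`, `relative_error`) are hypothetical and do not exist in PyTorch.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not an ATen API): round a float to the nearest
// bfloat16 value (round to nearest, ties to even) and widen it back to float.
inline float round_to_bf16(float x) {
    assert(std::isfinite(x));  // NaN/Inf handling is omitted in this sketch
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    // bfloat16 keeps the upper 16 bits of the float32 encoding.
    const std::uint32_t lsb = (bits >> 16) & 1u;
    bits += 0x7FFFu + lsb;  // rounding bias: nearest, ties to even
    bits &= 0xFFFF0000u;    // drop the low 16 mantissa bits
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// Old behavior, modeled: eps is stored as scalar_t (bfloat16) and then
// widened to opmath_t (float) when added to the mean of squares.
inline float eps_via_scalar_t(double eps) {
    return round_to_bf16(static_cast<float>(eps));
}

// New behavior, modeled: eps is stored directly as opmath_t (float).
inline float eps_via_opmath_t(double eps) {
    return static_cast<float>(eps);
}

// Relative error of the epsilon actually used versus the requested one.
inline double relative_error(float used, double requested) {
    return std::fabs(static_cast<double>(used) - requested) / requested;
}
```

Under these assumptions, `eps_via_scalar_t(1e-5)` yields 1.3125 * 2^-17, which prints as `1.00136e-05`, a relative error of roughly 0.1%, whereas `eps_via_opmath_t(1e-5)` stays within float rounding error of the requested value.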
Pull Request resolved: pytorch#142848
Approved by: https://github.com/mikaylagawarecki