11#include < ATen/ATen.h>
22#include < ATen/Config.h>
33#include < ATen/NativeFunctions.h>
4-
4+ # include < ATen/autocast_mode.h >
55#include < oneDNN/oneDNN.h>
66#include " Norm.h"
77#include " comm/RegistrationDeclarations.h"
@@ -12,6 +12,7 @@ using namespace at::AtenIpexTypeXPU::normalization;
1212
1313namespace at {
1414namespace AtenIpexTypeXPU {
15+ using autocast::cached_cast;
1516
// Declare rms_norm_fw from RMSNorm.cpp for the naive implementation fallback
1718std::tuple<Tensor, Tensor> rms_norm_fw (
@@ -581,6 +582,32 @@ Tensor fast_layer_norm(
581582 epsilon);
582583}
583584
585+ Tensor fast_layer_norm_autocast (
586+ const Tensor& input,
587+ at::IntArrayRef normalized_shape,
588+ const c10::optional<at::Tensor>& weight_opt,
589+ const c10::optional<at::Tensor>& bias_opt,
590+ double epsilon) {
591+ c10::impl::ExcludeDispatchKeyGuard no_autocast (c10::DispatchKey::AutocastXPU);
592+ auto to_type = input.scalar_type ();
593+ if (input.scalar_type () == at::ScalarType::Half ||
594+ weight_opt->scalar_type () == at::ScalarType::Half ||
595+ bias_opt->scalar_type () == at::ScalarType::Half) {
596+ to_type = at::ScalarType::Half;
597+ } else if (
598+ input.scalar_type () == at::ScalarType::BFloat16 ||
599+ weight_opt->scalar_type () == at::ScalarType::BFloat16 ||
600+ bias_opt->scalar_type () == at::ScalarType::BFloat16) {
601+ to_type = at::ScalarType::BFloat16;
602+ }
603+ return fast_layer_norm (
604+ cached_cast (to_type, input, c10::DeviceType::XPU),
605+ normalized_shape,
606+ cached_cast (to_type, *weight_opt, c10::DeviceType::XPU),
607+ cached_cast (to_type, *bias_opt, c10::DeviceType::XPU),
608+ epsilon);
609+ }
610+
584611Tensor add_add_rms_norm (
585612 const Tensor& add1,
586613 const Tensor& add2,
@@ -691,6 +718,12 @@ IPEX_LIBRARY_FRAGMENT() {
691718 IPEX_OP_REGISTER_DISPATCH (
692719 " fast_layer_norm" , fast_layer_norm, c10::DispatchKey::XPU);
693720}
721+ IPEX_LIBRARY_FRAGMENT () {
722+ IPEX_OP_REGISTER_DISPATCH (
723+ " fast_layer_norm" ,
724+ fast_layer_norm_autocast,
725+ c10::DispatchKey::AutocastXPU);
726+ }
694727IPEX_LIBRARY_FRAGMENT () {
695728 IPEX_OP_REGISTER_DISPATCH (
696729 " add_add_rms_norm" , add_add_rms_norm, c10::DispatchKey::XPU);
0 commit comments