Skip to content

Commit 6cf810e

Browse files
[Rebase] add autocast for fastlayernorm (#4773) (#4950) (#4955)
Co-authored-by: Liangliang Ma <[email protected]>
1 parent a387517 commit 6cf810e

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

csrc/gpu/aten/operators/AddNormFusion.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#include <ATen/ATen.h>
22
#include <ATen/Config.h>
33
#include <ATen/NativeFunctions.h>
4-
4+
#include <ATen/autocast_mode.h>
55
#include <oneDNN/oneDNN.h>
66
#include "Norm.h"
77
#include "comm/RegistrationDeclarations.h"
@@ -12,6 +12,7 @@ using namespace at::AtenIpexTypeXPU::normalization;
1212

1313
namespace at {
1414
namespace AtenIpexTypeXPU {
15+
using autocast::cached_cast;
1516

1617
// Decalre the rms_norm_fwd from RMSNorm.cpp for naive implementation fallback
1718
std::tuple<Tensor, Tensor> rms_norm_fw(
@@ -581,6 +582,32 @@ Tensor fast_layer_norm(
581582
epsilon);
582583
}
583584

585+
Tensor fast_layer_norm_autocast(
586+
const Tensor& input,
587+
at::IntArrayRef normalized_shape,
588+
const c10::optional<at::Tensor>& weight_opt,
589+
const c10::optional<at::Tensor>& bias_opt,
590+
double epsilon) {
591+
c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::AutocastXPU);
592+
auto to_type = input.scalar_type();
593+
if (input.scalar_type() == at::ScalarType::Half ||
594+
weight_opt->scalar_type() == at::ScalarType::Half ||
595+
bias_opt->scalar_type() == at::ScalarType::Half) {
596+
to_type = at::ScalarType::Half;
597+
} else if (
598+
input.scalar_type() == at::ScalarType::BFloat16 ||
599+
weight_opt->scalar_type() == at::ScalarType::BFloat16 ||
600+
bias_opt->scalar_type() == at::ScalarType::BFloat16) {
601+
to_type = at::ScalarType::BFloat16;
602+
}
603+
return fast_layer_norm(
604+
cached_cast(to_type, input, c10::DeviceType::XPU),
605+
normalized_shape,
606+
cached_cast(to_type, *weight_opt, c10::DeviceType::XPU),
607+
cached_cast(to_type, *bias_opt, c10::DeviceType::XPU),
608+
epsilon);
609+
}
610+
584611
Tensor add_add_rms_norm(
585612
const Tensor& add1,
586613
const Tensor& add2,
@@ -691,6 +718,12 @@ IPEX_LIBRARY_FRAGMENT() {
691718
IPEX_OP_REGISTER_DISPATCH(
692719
"fast_layer_norm", fast_layer_norm, c10::DispatchKey::XPU);
693720
}
721+
IPEX_LIBRARY_FRAGMENT() {
722+
IPEX_OP_REGISTER_DISPATCH(
723+
"fast_layer_norm",
724+
fast_layer_norm_autocast,
725+
c10::DispatchKey::AutocastXPU);
726+
}
694727
IPEX_LIBRARY_FRAGMENT() {
695728
IPEX_OP_REGISTER_DISPATCH(
696729
"add_add_rms_norm", add_add_rms_norm, c10::DispatchKey::XPU);

0 commit comments

Comments
 (0)