Skip to content

Commit 159c097

Browse files
authored
override group norm (#5019) (#5022)
1 parent 2d56b82 commit 159c097

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

csrc/gpu/aten/operators/GroupNorm.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
11
#include "Norm.h"
2+
#ifdef USE_OVERRIDE_OP
3+
#include "utils/CustomOperatorRegistration.h"
4+
#endif
25

36
using namespace at::AtenIpexTypeXPU::normalization;
47

@@ -987,3 +990,17 @@ std::tuple<Tensor, Tensor, Tensor> native_group_norm_backward(
987990

988991
} // namespace AtenIpexTypeXPU
989992
} // namespace at
993+
994+
#ifdef USE_OVERRIDE_OP
995+
namespace {
996+
997+
IPEX_TORCH_LIBRARY_IMPL(aten, XPU, m) {
998+
m.impl(
999+
"native_group_norm", TORCH_FN((&at::AtenIpexTypeXPU::native_group_norm)));
1000+
m.impl(
1001+
"native_group_norm_backward",
1002+
TORCH_FN((&at::AtenIpexTypeXPU::native_group_norm_backward)));
1003+
}
1004+
1005+
} // namespace
1006+
#endif

scripts/tools/torchgen/yaml/xpu_functions.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ supported:
6363
# - sum.dim_IntList
6464
# - nansum
6565
# - nansum.out
66+
# - native_group_norm
67+
# - native_group_norm_backward
6668
####################################
6769
# - _aminmax
6870
# - _aminmax.dim
@@ -459,8 +461,6 @@ supported:
459461
# - nanmedian.dim_values
460462
# - native_dropout
461463
# - native_dropout_backward
462-
# - native_group_norm
463-
# - native_group_norm_backward
464464
# - native_layer_norm
465465
# - native_layer_norm_backward
466466
# - ne.Scalar

0 commit comments

Comments
 (0)