Skip to content

Commit 6618ad5

Browse files
committed
enable dil_layer_norm
1 parent f0cdcc6 commit 6618ad5

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

scripts/cpu/gen-dense-cpu-ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,8 @@
7373
'aten::view(Tensor(a) self, int[] size) -> Tensor(a)',
7474
'aten::index_select(Tensor self, int dim, Tensor index) -> Tensor',
7575
'aten::_unsafe_view(Tensor self, int[] size) -> Tensor',
76-
#'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)',
77-
#'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)'
76+
'aten::native_layer_norm(Tensor input, Tensor? weight, Tensor? bias, int M, int N, float eps) -> (Tensor, Tensor, Tensor)',
77+
'aten::native_layer_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, int M, int N, bool[3] output_mask) -> (Tensor, Tensor, Tensor)'
7878
]
7979

8080
_FN_IPEX_FUNCS_WITH_SIMPLE_ATEN_SIG = [

tests/cpu/test_lazy_reorder.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,11 @@ def test_layer_norm(self):
998998
input = torch.randn(2, 5, 10, 10, dtype=torch.float32)
999999
input_dpcpp=input.to(device=device)
10001000
m = torch.nn.LayerNorm([10, 10])
1001-
self.assertEqual(m(input), m(input_dpcpp))
1001+
m_dpcpp = copy.deepcopy(m).to(device=device)
1002+
output = m(input)
1003+
output_dpcpp = m_dpcpp(input_dpcpp)
1004+
self.assertTrue(ipex.core.is_dil_tensor(output_dpcpp))
1005+
self.assertEqual(output, output_dpcpp)
10021006

10031007
def test_layer_norm_backward(self):
10041008
with AutoDNNL(True):

torch_ipex/csrc/cpu/DevOPs.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,6 +1937,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_layer_
19371937
double eps) {
19381938
DEBUG("AtenIpexCPUDev::dil_native_layer_norm\n");
19391939
CHECK_DNNL_OP_PRE_COND(X);
1940+
//It's a temporary solution to fall back to fp32 since bf16 layer_norm is not ready for dnnl path now.
1941+
dbl::comm::reorder_to_dtype(X, at::kFloat);
19401942
dil::tensor x = dbl::comm::try_gen_dil_tensor(X);
19411943
const dil::tensor scale = dbl::comm::try_gen_dil_tensor(gamma);
19421944
const dil::tensor shift = dbl::comm::try_gen_dil_tensor(beta);
@@ -1976,6 +1978,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> AtenIpexCPUDev::dil_native_layer_
19761978
DEBUG("AtenIpexCPUDev::dil_native_layer_norm_backward\n");
19771979
CHECK_DNNL_OP_PRE_COND(dY);
19781980
CHECK_DNNL_OP_PRE_COND(X);
1981+
//it's a temporary solution to fall back to fp32 since bf16 layer_norm is not ready for dnnl path now.
1982+
dbl::comm::reorder_to_dtype(dY, at::kFloat);
1983+
dbl::comm::reorder_to_dtype(X, at::kFloat);
19791984
dil::tensor dy = dbl::comm::try_gen_dil_tensor(dY);
19801985
dil::tensor x = dbl::comm::try_gen_dil_tensor(X);
19811986
dil::tensor m = dbl::comm::try_gen_dil_tensor(mean);

0 commit comments

Comments
 (0)