Skip to content

Commit c010692

Browse files
hoshibarajiayisunx
andauthored
Remove the overide of aten::instance_norm by torch_ipex:instance_norm (#2289) (#3569) (#3575)
* Cancel the overide of aten::instance_norm by torch_ipex::cpu::instance_norm * Not running the tests * fix format issue Co-authored-by: jiayisunx <[email protected]>
1 parent 006bcfc commit c010692

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

csrc/cpu/aten/InstanceNorm.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,11 +444,11 @@ TORCH_LIBRARY_FRAGMENT(torch_ipex, m) {
444444
torch_ipex::cpu::instance_norm_backward);
445445
}
446446

447-
IPEX_TORCH_LIBRARY_IMPL(aten, CPU, m) {
448-
m.impl(
449-
TORCH_SELECTIVE_NAME("aten::instance_norm"),
450-
TORCH_FN((&torch_ipex::cpu::instance_norm)));
451-
}
447+
// IPEX_TORCH_LIBRARY_IMPL(aten, CPU, m) {
448+
// m.impl(
449+
// TORCH_SELECTIVE_NAME("aten::instance_norm"),
450+
// TORCH_FN((&torch_ipex::cpu::instance_norm)));
451+
// }
452452

453453
} // namespace cpu
454454
} // namespace torch_ipex

tests/cpu/test_instance_norm.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
22
import torch
3-
import unittest
3+
4+
# import unittest
45
from common_utils import TestCase
56
from torch.nn import InstanceNorm2d, InstanceNorm3d, BatchNorm2d, BatchNorm3d
67

@@ -53,7 +54,7 @@ def test_instance_norm(self):
5354
y2 = m(x2)
5455
self.assertTrue(y2.dtype == torch.float32)
5556
self.assertEqual(y2, y1)
56-
self.assertTrue(y2.is_contiguous(memory_format=memory_format))
57+
self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))
5758

5859
y2.mean().backward()
5960
self.assertTrue(x2.grad.dtype == torch.float32)
@@ -109,7 +110,7 @@ def test_instance_norm_bfloat16(self):
109110
)
110111
y2 = m(x2)
111112
self.assertTrue(y2.dtype == torch.bfloat16)
112-
self.assertTrue(y2.is_contiguous(memory_format=memory_format))
113+
self.assertTrue(y2.is_contiguous(memory_format=torch.contiguous_format))
113114
self.assertEqual(y2, y1, prec=0.1)
114115

115116
y2.mean().backward()
@@ -118,5 +119,5 @@ def test_instance_norm_bfloat16(self):
118119
self.assertEqual(x2.grad, x1.grad)
119120

120121

121-
if __name__ == "__main__":
122-
test = unittest.main()
122+
# if __name__ == "__main__":
123+
# test = unittest.main()

0 commit comments

Comments
 (0)