Skip to content

Commit b162d61

Browse files
authored
Use same sacle in qsigmoid UT (#2094)
1 parent 19ac29e commit b162d61

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

tests/gpu/examples/test_qsigmoid.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@
66

77

88
class TestTorchMethod(TestCase):
9-
@pytest.mark.skip(reason="Need create a PR for torch to align output_scale")
109
def test_qsigmoid(self, dtype=torch.float):
1110
dtype = torch.qint8
1211

1312
input0 = torch.randn(1, 1, 5, 5, device="xpu")
1413
q_input = torch.quantize_per_tensor(input0, 0.4, 0, dtype=dtype)
1514
q_input_cpu = torch.quantize_per_tensor(input0.to("cpu"), 0.4, 0, dtype=dtype)
15+
1616
result_functional = torch.dequantize(torch.sigmoid(q_input))
1717
result_inplace = torch.dequantize(torch.sigmoid_(q_input))
1818
result_output = torch.randn(1, 1, 5, 5, device="xpu")
1919
result_out = torch.dequantize(torch.sigmoid(q_input, out=result_output))
20-
result_cpu = torch.dequantize(torch.sigmoid(q_input_cpu))
2120

22-
self.assertEqual(result_functional.to("cpu"), result_cpu)
23-
self.assertEqual(result_inplace.to("cpu"), result_cpu)
24-
self.assertEqual(result_out.to("cpu"), result_cpu)
21+
dqX = torch.dequantize(q_input_cpu)
22+
Y_ref = torch.sigmoid(dqX)
23+
qY_ref = torch.quantize_per_tensor(Y_ref, 1.0/255.0, 0, torch.quint8)
24+
dqY_ref = qY_ref.dequantize()
25+
26+
self.assertEqual(result_functional.to("cpu"), dqY_ref)
27+
self.assertEqual(result_inplace.to("cpu"), dqY_ref)
28+
self.assertEqual(result_out.to("cpu"), dqY_ref)
29+

0 commit comments

Comments
 (0)