Skip to content

Commit b71833e

Browse files
authored
[UT]fix test_poisson op random fail (#44763)
修复poisson op单测随机挂 原因:由于随机OP的无法直接验证数值正确性,该单测随机采样100万个样本,统计落到直方图各区间的数量,计算出粗略的概率密度函数,与标准概率密度函数对比,这种测试方式会有一定误差。 当采样数量越小,误差越大,因此该PR增大采样样本数量(100万->200万),误差进一步减小在rtol范围内。
1 parent 684b12e commit b71833e

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

python/paddle/fluid/tests/unittests/test_poisson_op.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,14 @@ def output_hist(out, lam, a, b):
3939

4040

4141
class TestPoissonOp1(OpTest):
42+
4243
def setUp(self):
4344
self.op_type = "poisson"
4445
self.config()
4546

4647
self.attrs = {}
47-
self.inputs = {'X': np.full([1024, 1024], self.lam, dtype=self.dtype)}
48-
self.outputs = {'Out': np.ones([1024, 1024], dtype=self.dtype)}
48+
self.inputs = {'X': np.full([2048, 1024], self.lam, dtype=self.dtype)}
49+
self.outputs = {'Out': np.ones([2048, 1024], dtype=self.dtype)}
4950

5051
def config(self):
5152
self.lam = 10
@@ -55,10 +56,8 @@ def config(self):
5556

5657
def verify_output(self, outs):
5758
hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b)
58-
self.assertTrue(
59-
np.allclose(
60-
hist, prob, rtol=0.01),
61-
"actual: {}, expected: {}".format(hist, prob))
59+
self.assertTrue(np.allclose(hist, prob, rtol=0.01),
60+
"actual: {}, expected: {}".format(hist, prob))
6261

6362
def test_check_output(self):
6463
self.check_output_customized(self.verify_output)
@@ -67,22 +66,23 @@ def test_check_grad_normal(self):
6766
self.check_grad(
6867
['X'],
6968
'Out',
70-
user_defined_grads=[np.zeros(
71-
[1024, 1024], dtype=self.dtype)],
69+
user_defined_grads=[np.zeros([2048, 1024], dtype=self.dtype)],
7270
user_defined_grad_outputs=[
73-
np.random.rand(1024, 1024).astype(self.dtype)
71+
np.random.rand(2048, 1024).astype(self.dtype)
7472
])
7573

7674

7775
class TestPoissonOp2(TestPoissonOp1):
76+
7877
def config(self):
7978
self.lam = 5
8079
self.a = 1
81-
self.b = 9
80+
self.b = 8
8281
self.dtype = "float32"
8382

8483

8584
class TestPoissonAPI(unittest.TestCase):
85+
8686
def test_static(self):
8787
with paddle.static.program_guard(paddle.static.Program(),
8888
paddle.static.Program()):

0 commit comments

Comments
 (0)