@@ -39,13 +39,14 @@ def output_hist(out, lam, a, b):
39
39
40
40
41
41
class TestPoissonOp1 (OpTest ):
42
+
42
43
def setUp (self ):
43
44
self .op_type = "poisson"
44
45
self .config ()
45
46
46
47
self .attrs = {}
47
- self .inputs = {'X' : np .full ([1024 , 1024 ], self .lam , dtype = self .dtype )}
48
- self .outputs = {'Out' : np .ones ([1024 , 1024 ], dtype = self .dtype )}
48
+ self .inputs = {'X' : np .full ([2048 , 1024 ], self .lam , dtype = self .dtype )}
49
+ self .outputs = {'Out' : np .ones ([2048 , 1024 ], dtype = self .dtype )}
49
50
50
51
def config (self ):
51
52
self .lam = 10
@@ -55,10 +56,8 @@ def config(self):
55
56
56
57
def verify_output (self , outs ):
57
58
hist , prob = output_hist (np .array (outs [0 ]), self .lam , self .a , self .b )
58
- self .assertTrue (
59
- np .allclose (
60
- hist , prob , rtol = 0.01 ),
61
- "actual: {}, expected: {}" .format (hist , prob ))
59
+ self .assertTrue (np .allclose (hist , prob , rtol = 0.01 ),
60
+ "actual: {}, expected: {}" .format (hist , prob ))
62
61
63
62
def test_check_output (self ):
64
63
self .check_output_customized (self .verify_output )
@@ -67,22 +66,23 @@ def test_check_grad_normal(self):
67
66
self .check_grad (
68
67
['X' ],
69
68
'Out' ,
70
- user_defined_grads = [np .zeros (
71
- [1024 , 1024 ], dtype = self .dtype )],
69
+ user_defined_grads = [np .zeros ([2048 , 1024 ], dtype = self .dtype )],
72
70
user_defined_grad_outputs = [
73
- np .random .rand (1024 , 1024 ).astype (self .dtype )
71
+ np .random .rand (2048 , 1024 ).astype (self .dtype )
74
72
])
75
73
76
74
77
75
class TestPoissonOp2 (TestPoissonOp1 ):
76
+
78
77
def config (self ):
79
78
self .lam = 5
80
79
self .a = 1
81
- self .b = 9
80
+ self .b = 8
82
81
self .dtype = "float32"
83
82
84
83
85
84
class TestPoissonAPI (unittest .TestCase ):
85
+
86
86
def test_static (self ):
87
87
with paddle .static .program_guard (paddle .static .Program (),
88
88
paddle .static .Program ()):
0 commit comments