Skip to content

Commit 06155ab

Browse files
authored
[0-size Tensor Job2 No.13、54] Add 0-size Tensor support for gammainc (#73437)
* Fix * Fix * Fix
1 parent b6b5d60 commit 06155ab

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

test/legacy_test/test_bce_loss.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,16 @@ def test_fp16(self):
336336
paddle.disable_static()
337337

338338

339+
class TestBceLossOp_ZeroSize(TestBceLossOp):
340+
def init_test_cast(self):
341+
self.shape = [0, 1, 2]
342+
343+
344+
class TestBceLossOp_ZeroSize2(TestBceLossOp):
345+
def init_test_cast(self):
346+
self.shape = [0]
347+
348+
339349
if __name__ == "__main__":
340350
paddle.enable_static()
341351
unittest.main()

test/legacy_test/test_gammainc.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import unittest
1616

1717
import numpy as np
18+
from op_test import OpTest
1819
from scipy import special
1920

2021
import paddle
@@ -70,5 +71,36 @@ def init_dtype_type(self):
7071
self.dtype = "float32"
7172

7273

74+
class TestGammaincOp_ZeroSize(OpTest):
75+
def setUp(self):
76+
self.op_type = 'gammaincc'
77+
self.python_api = paddle.gammainc
78+
self.init_dtype_type()
79+
self.init_shape()
80+
self.x = np.random.random(self.shape).astype(self.dtype) + 1
81+
self.y = np.random.random(self.shape).astype(self.dtype) + 1
82+
self.inputs = {'x': self.x, 'y': self.y}
83+
out = ref_gammainc(self.x, self.y)
84+
self.outputs = {'out': out}
85+
86+
def init_shape(self):
87+
self.shape = (0, 40)
88+
89+
def init_dtype_type(self):
90+
self.dtype = np.float64
91+
92+
def test_check_output(self):
93+
self.check_output()
94+
95+
def test_check_grad(self):
96+
self.check_grad(['y'], 'out')
97+
98+
99+
class TestGammaincOp_ZeroSize2(TestGammaincOp_ZeroSize):
100+
101+
def init_shape(self):
102+
self.shape = (0,)
103+
104+
73105
if __name__ == "__main__":
74106
unittest.main()

0 commit comments

Comments
 (0)