Skip to content

Commit d9c3123

Browse files
authored
Merge pull request #13181 from tensor-tang/refine/fusion/ut
Refine fusion LSTM and GRU op tests
2 parents ec5204b + ef71b8e commit d9c3123

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

python/paddle/fluid/tests/unittests/test_fusion_gru_op.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def fusion_gru(
3737
h0,
3838
wh,
3939
np.zeros(
40-
(1, wh.shape[1]), dtype='float64'),
40+
(1, wh.shape[1]), dtype='float32'),
4141
is_reverse,
4242
act_state,
4343
act_gate)
@@ -62,15 +62,15 @@ def setUp(self):
6262
T = sum(self.lod[0])
6363
N = len(self.lod[0])
6464

65-
x = np.random.rand(T, self.M).astype('float64')
66-
wx = np.random.rand(self.M, 3 * self.D).astype('float64')
67-
wh = np.random.rand(self.D, 3 * self.D).astype('float64')
65+
x = np.random.rand(T, self.M).astype('float32')
66+
wx = np.random.rand(self.M, 3 * self.D).astype('float32')
67+
wh = np.random.rand(self.D, 3 * self.D).astype('float32')
6868
bias = np.random.rand(
69-
1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
70-
(1, 3 * self.D), dtype='float64')
69+
1, 3 * self.D).astype('float32') if self.with_bias else np.zeros(
70+
(1, 3 * self.D), dtype='float32')
7171
h0 = np.random.rand(
72-
N, self.D).astype('float64') if self.with_h0 else np.zeros(
73-
(N, self.D), dtype='float64')
72+
N, self.D).astype('float32') if self.with_h0 else np.zeros(
73+
(N, self.D), dtype='float32')
7474

7575
_, _, _, hidden = fusion_gru(
7676
x, self.lod, h0, wx, wh, bias, self.is_reverse,
@@ -93,7 +93,9 @@ def setUp(self):
9393
}
9494

9595
def test_check_output(self):
96-
self.check_output(atol=1e-8)
96+
for use_seq in {True, False}:
97+
self.attrs['use_seq'] = use_seq
98+
self.check_output()
9799

98100

99101
class TestFusionGRUOpNoInitial(TestFusionGRUOp):

python/paddle/fluid/tests/unittests/test_fusion_lstm_op.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def setUp(self):
114114
}
115115

116116
def test_check_output(self):
117-
self.check_output()
117+
for use_seq in {True, False}:
118+
self.attrs['use_seq'] = use_seq
119+
self.check_output()
118120

119121

120122
class TestFusionLSTMOpInit(TestFusionLSTMOp):

0 commit comments

Comments (0)