Skip to content

Commit 509c839

Browse files
committed
address comments
1 parent 182da95 commit 509c839

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

python/paddle/fluid/tests/unittests/test_dropout_op.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,36 +83,36 @@ def test_check_output(self):
8383
self.check_output()
8484

8585

86-
class TestFP16DropoutOp1(OpTest):
86+
class TestFP16DropoutOp(OpTest):
8787
def setUp(self):
88-
x = np.random.random((32, 64)).astype("float16")
89-
prob = 0.35
90-
out = x * (1.0 - prob)
91-
9288
self.op_type = "dropout"
89+
self.init_test_case()
90+
91+
x = np.random.random(self.input_size).astype("float16")
92+
out = x * (1.0 - self.prob)
9393
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
94-
self.attrs = {'dropout_prob': prob, 'fix_seed': True, 'is_test': True}
94+
self.attrs = {
95+
'dropout_prob': self.prob,
96+
'fix_seed': self.fix_seed,
97+
'is_test': True
98+
}
9599
self.outputs = {'Out': out}
96100

101+
def init_test_case(self):
102+
self.input_size = [32, 64]
103+
self.prob = 0.35
104+
self.fix_seed = True
105+
97106
def test_check_output(self):
98107
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
99108
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
100109

101110

102-
class TestFP16DropoutOp2(OpTest):
103-
def setUp(self):
104-
x = np.random.random((32, 64, 3)).astype("float16")
105-
prob = 0.75
106-
out = x * (1.0 - prob)
107-
108-
self.op_type = "dropout"
109-
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
110-
self.attrs = {'dropout_prob': prob, 'is_test': True}
111-
self.outputs = {'Out': out}
112-
113-
def test_check_output(self):
114-
if core.is_compiled_with_cuda() and core.op_support_gpu("dropout"):
115-
self.check_output_with_place(core.CUDAPlace(0), atol=1e-3)
111+
class TestFP16DropoutOp2(TestFP16DropoutOp):
112+
def init_test_case(self):
113+
self.input_size = [32, 64, 3]
114+
self.prob = 0.75
115+
self.fix_seed = False
116116

117117

118118
if __name__ == '__main__':

0 commit comments

Comments
 (0)