Skip to content

Commit 5a83776

Browse files
committed
Port test_desc_clone
1 parent 50d66a0 commit 5a83776

File tree

2 files changed

+22
-44
lines changed

2 files changed

+22
-44
lines changed

python/paddle/fluid/tests/unittests/test_desc_clone.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
110110

111111

112112
def operator_equal(a, b):
113-
for k, v in a.__dict__.iteritems():
113+
for k, v in six.iteritems(a.__dict__):
114114
if isinstance(v, fluid.framework.Program) or \
115115
isinstance(v, fluid.framework.Block):
116116
continue
@@ -120,8 +120,8 @@ def operator_equal(a, b):
120120
raise ValueError("In operator_equal not equal:{0}\n".format(k))
121121

122122
elif isinstance(v, collections.OrderedDict):
123-
v0 = sorted(v.iteritems(), key=lambda x: x[0])
124-
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
123+
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
124+
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
125125

126126
if v0 != v1:
127127
raise ValueError("In operator_equal not equal:{0}\n".format(k))
@@ -133,7 +133,7 @@ def operator_equal(a, b):
133133

134134

135135
def block_equal(a, b):
136-
for k, v in a.__dict__.iteritems():
136+
for k, v in six.iteritems(a.__dict__):
137137
if isinstance(v, core.ProgramDesc) or isinstance(
138138
v, fluid.framework.Program) or isinstance(v, core.BlockDesc):
139139
continue
@@ -145,8 +145,8 @@ def block_equal(a, b):
145145
assert (len(a.ops) == len(b.ops))
146146

147147
elif isinstance(v, collections.OrderedDict):
148-
v0 = sorted(v.iteritems(), key=lambda x: x[0])
149-
v1 = sorted(b.__dict__[k].iteritems(), key=lambda x: x[0])
148+
v0 = sorted(six.iteritems(v), key=lambda x: x[0])
149+
v1 = sorted(six.iteritems(b.__dict__[k]), key=lambda x: x[0])
150150

151151
if v0 != v1:
152152
raise ValueError("In block_equal not equal:{0}\n".format(k))
@@ -158,7 +158,7 @@ def block_equal(a, b):
158158

159159

160160
def program_equal(a, b):
161-
for k, v in a.__dict__.iteritems():
161+
for k, v in six.iteritems(a.__dict__):
162162
if isinstance(v, core.ProgramDesc):
163163
continue
164164

python/paddle/fluid/tests/unittests/test_prelu_op.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,6 @@
2121

2222
class PReluTest(OpTest):
2323
def setUp(self):
24-
print('setUp')
25-
import sys
26-
sys.stdout.flush()
2724
self.op_type = "prelu"
2825
self.initTestCase()
2926
x_np = np.random.normal(size=(3, 5, 5, 10)).astype("float32")
@@ -48,39 +45,19 @@ def setUp(self):
4845
assert out_np is not self.inputs['X']
4946
self.outputs = {'Out': out_np}
5047

51-
def tearDown(self):
52-
print('tearDown')
53-
import sys
54-
sys.stdout.flush()
55-
del self.outputs
56-
del self.inputs
57-
5848
def initTestCase(self):
5949
self.attrs = {'mode': "channel"}
6050

61-
def test_check_4_output(self):
62-
print('test_check_0_output')
63-
import sys
64-
sys.stdout.flush()
51+
def test_check_output(self):
6552
self.check_output()
6653

67-
def test_check_0_grad_2_ignore_x(self):
68-
print('test_check_2_grad_2_ignore_x')
69-
import sys
70-
sys.stdout.flush()
71-
self.check_grad(['Alpha'], 'Out', no_grad_set=set('X'))
72-
73-
# TODO(minqiyang): remove the order of tests
74-
def test_check_1_grad_1(self):
75-
print('test_check_1_grad_1')
76-
import sys
77-
sys.stdout.flush()
54+
def test_check_grad(self):
7855
self.check_grad(['X', 'Alpha'], 'Out')
7956

80-
def test_check_3_grad_3_ignore_alpha(self):
81-
print('test_check_3_grad_3_ignore_alpha')
82-
import sys
83-
sys.stdout.flush()
57+
def test_check_grad_ignore_x(self):
58+
self.check_grad(['Alpha'], 'Out', no_grad_set=set('X'))
59+
60+
def test_check_grad_ignore_alpha(self):
8461
self.check_grad(['X'], 'Out', no_grad_set=set('Alpha'))
8562

8663

@@ -89,14 +66,15 @@ def initTestCase(self):
8966
self.attrs = {'mode': "all"}
9067

9168

92-
#class TestCase2(PReluTest):
93-
# def initTestCase(self):
94-
# self.attrs = {'mode': "channel"}
95-
#
96-
#
97-
#class TestCase3(PReluTest):
98-
# def initTestCase(self):
99-
# self.attrs = {'mode': "element"}
69+
class TestCase2(PReluTest):
70+
def initTestCase(self):
71+
self.attrs = {'mode': "channel"}
72+
73+
74+
class TestCase3(PReluTest):
75+
def initTestCase(self):
76+
self.attrs = {'mode': "element"}
77+
10078

10179
if __name__ == "__main__":
10280
unittest.main()

0 commit comments

Comments
 (0)