Skip to content

Commit c838fa3

Browse files
committed
Port dist_transpiler to Python3.5
Resume prelu_op_test in python2
1 parent 90b5be8 commit c838fa3

File tree

3 files changed

+19
-15
lines changed

3 files changed

+19
-15
lines changed

python/paddle/fluid/tests/unittests/test_prelu_op.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import unittest
1818
import numpy as np
19+
import six
1920
from op_test import OpTest
2021

2122

@@ -62,17 +63,20 @@ def test_check_grad_3_ignore_alpha(self):
6263

6364

6465
# TODO(minqiyang): Resume these test cases after fixing Python3 CI job issues
65-
# class TestCase1(PReluTest):
66-
# def initTestCase(self):
67-
# self.attrs = {'mode': "all"}
66+
if six.PY2:
6867

69-
# class TestCase2(PReluTest):
70-
# def initTestCase(self):
71-
# self.attrs = {'mode': "channel"}
68+
class TestCase1(PReluTest):
69+
def initTestCase(self):
70+
self.attrs = {'mode': "all"}
71+
72+
class TestCase2(PReluTest):
73+
def initTestCase(self):
74+
self.attrs = {'mode': "channel"}
75+
76+
class TestCase3(PReluTest):
77+
def initTestCase(self):
78+
self.attrs = {'mode': "element"}
7279

73-
# class TestCase3(PReluTest):
74-
# def initTestCase(self):
75-
# self.attrs = {'mode': "element"}
7680

7781
if __name__ == "__main__":
7882
unittest.main()

python/paddle/fluid/transpiler/details/program_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def block_to_code(block, block_idx):
153153

154154
indent += 1
155155
# sort all vars
156-
all_vars = sorted(block.vars.iteritems(), key=lambda x: x[0])
156+
all_vars = sorted(six.iteritems(block.vars), key=lambda x: x[0])
157157
for var in all_vars:
158158
print("{}{}".format(get_indent_space(indent), variable_to_code(var[1])))
159159

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def transpile(self,
300300
input_deps = grad_name_to_send_dummy_out.values()
301301
program.global_block().append_op(
302302
type="send_barrier",
303-
inputs={"X": input_deps},
303+
inputs={"X": list(input_deps)},
304304
outputs={"Out": send_barrier_out},
305305
attrs={
306306
"endpoints": pserver_endpoints,
@@ -401,7 +401,7 @@ def _get_trainer_startup_program(self, recv_vars, eplist):
401401
402402
Args:
403403
recv_vars (list): Variable list to recv for current trainer_id
404-
eplist (list): A list of strings indicating
404+
eplist (list): A list of strings indicating
405405
406406
Returns:
407407
Program: trainer side startup program.
@@ -455,7 +455,7 @@ def _get_trainer_startup_program(self, recv_vars, eplist):
455455
if len(splited_var) <= 1:
456456
continue
457457
# NOTE: if enable memory optimization, origin vars maybe removed.
458-
if startup_program.global_block().vars.has_key(varname):
458+
if varname in startup_program.global_block().vars:
459459
orig_param = startup_program.global_block().vars[varname]
460460
else:
461461
origin_param_var = self.origin_program.global_block().vars[
@@ -690,7 +690,7 @@ def get_pserver_programs(self, endpoint):
690690
691691
Args:
692692
endpoint (str): current pserver endpoint.
693-
693+
694694
Returns:
695695
tuple: (main_program, startup_program), of type "Program"
696696
"""
@@ -713,7 +713,7 @@ def get_startup_program(self,
713713
endpoint (str): current pserver endpoint.
714714
pserver_program (Program): deprecated, call get_pserver_program first.
715715
startup_program (Program): deprecated, should pass startup_program
716-
when initalizing
716+
when initalizing
717717
718718
Returns:
719719
Program: parameter server side startup program.

0 commit comments

Comments
 (0)