Skip to content

Commit 19ef4ba

Browse files
authored
Merge pull request #12418 from panyx0718/fix_dist_seed
Fix dist seed
2 parents ccf9d3a + 412ad81 commit 19ef4ba

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

python/paddle/fluid/tests/unittests/test_dist_se_resnext.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
import unittest
2121
import os
22+
import sys
2223
import signal
2324
import subprocess
2425

@@ -56,7 +57,7 @@ def _wait_ps_ready(self, pid):
5657
except os.error:
5758
retry_times -= 1
5859

59-
def no_test_with_place(self):
60+
def test_with_place(self):
6061
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
6162
required_envs = {
6263
"PATH": os.getenv("PATH"),
@@ -70,9 +71,15 @@ def no_test_with_place(self):
7071
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
7172
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
7273
local_proc = subprocess.Popen(
73-
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local)
74+
local_cmd.split(" "),
75+
stdout=subprocess.PIPE,
76+
stderr=subprocess.PIPE,
77+
env=env_local)
7478
local_proc.wait()
75-
local_ret = local_proc.stdout.read()
79+
out, err = local_proc.communicate()
80+
local_ret = out
81+
sys.stderr.write('local_loss: %s\n' % local_ret)
82+
sys.stderr.write('local_stderr: %s\n' % err)
7683

7784
# Run dist train to compare with local results
7885
ps0, ps1 = self.start_pserver()
@@ -92,13 +99,22 @@ def no_test_with_place(self):
9299
FNULL = open(os.devnull, 'w')
93100

94101
tr0_proc = subprocess.Popen(
95-
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0)
102+
tr0_cmd.split(" "),
103+
stdout=subprocess.PIPE,
104+
stderr=subprocess.PIPE,
105+
env=env0)
96106
tr1_proc = subprocess.Popen(
97-
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1)
107+
tr1_cmd.split(" "),
108+
stdout=subprocess.PIPE,
109+
stderr=subprocess.PIPE,
110+
env=env1)
98111

99112
tr0_proc.wait()
100113
tr1_proc.wait()
101-
loss_data0 = tr0_proc.stdout.read()
114+
out, err = tr0_proc.communicate()
115+
sys.stderr.write('dist_stderr: %s\n' % err)
116+
loss_data0 = out
117+
sys.stderr.write('dist_loss: %s\n' % loss_data0)
102118
lines = loss_data0.split("\n")
103119
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
104120
dist_last_loss = eval(lines[1].replace(" ", ","))[0]

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -347,6 +347,7 @@ def get_pserver_program(self, endpoint):
347347

348348
# step1
349349
pserver_program = Program()
350+
pserver_program.random_seed = self.origin_program.random_seed
350351
# step2: Create vars to receive vars at parameter servers.
351352
recv_inputs = []
352353
for v in self.param_grad_ep_mapping[endpoint]["params"]:
@@ -544,6 +545,7 @@ def get_startup_program(self, endpoint, pserver_program):
544545
"""
545546
s_prog = Program()
546547
orig_s_prog = default_startup_program()
548+
s_prog.random_seed = orig_s_prog.random_seed
547549
params = self.param_grad_ep_mapping[endpoint]["params"]
548550

549551
def _get_splited_name_and_shape(varname):

0 commit comments

Comments
 (0)