19
19
20
20
import unittest
21
21
import os
22
+ import sys
22
23
import signal
23
24
import subprocess
24
25
@@ -56,7 +57,7 @@ def _wait_ps_ready(self, pid):
56
57
except os .error :
57
58
retry_times -= 1
58
59
59
- def no_test_with_place (self ):
60
+ def test_with_place (self ):
60
61
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
61
62
required_envs = {
62
63
"PATH" : os .getenv ("PATH" ),
@@ -70,9 +71,15 @@ def no_test_with_place(self):
70
71
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
71
72
(self ._python_interp , "127.0.0.1:1234" , "127.0.0.1:1234" , 1 )
72
73
local_proc = subprocess .Popen (
73
- local_cmd .split (" " ), stdout = subprocess .PIPE , env = env_local )
74
+ local_cmd .split (" " ),
75
+ stdout = subprocess .PIPE ,
76
+ stderr = subprocess .PIPE ,
77
+ env = env_local )
74
78
local_proc .wait ()
75
- local_ret = local_proc .stdout .read ()
79
+ out , err = local_proc .communicate ()
80
+ local_ret = out
81
+ sys .stderr .write ('local_loss: %s\n ' % local_ret )
82
+ sys .stderr .write ('local_stderr: %s\n ' % err )
76
83
77
84
# Run dist train to compare with local results
78
85
ps0 , ps1 = self .start_pserver ()
@@ -92,13 +99,22 @@ def no_test_with_place(self):
92
99
FNULL = open (os .devnull , 'w' )
93
100
94
101
tr0_proc = subprocess .Popen (
95
- tr0_cmd .split (" " ), stdout = subprocess .PIPE , stderr = FNULL , env = env0 )
102
+ tr0_cmd .split (" " ),
103
+ stdout = subprocess .PIPE ,
104
+ stderr = subprocess .PIPE ,
105
+ env = env0 )
96
106
tr1_proc = subprocess .Popen (
97
- tr1_cmd .split (" " ), stdout = subprocess .PIPE , stderr = FNULL , env = env1 )
107
+ tr1_cmd .split (" " ),
108
+ stdout = subprocess .PIPE ,
109
+ stderr = subprocess .PIPE ,
110
+ env = env1 )
98
111
99
112
tr0_proc .wait ()
100
113
tr1_proc .wait ()
101
- loss_data0 = tr0_proc .stdout .read ()
114
+ out , err = tr0_proc .communicate ()
115
+ sys .stderr .write ('dist_stderr: %s\n ' % err )
116
+ loss_data0 = out
117
+ sys .stderr .write ('dist_loss: %s\n ' % loss_data0 )
102
118
lines = loss_data0 .split ("\n " )
103
119
dist_first_loss = eval (lines [0 ].replace (" " , "," ))[0 ]
104
120
dist_last_loss = eval (lines [1 ].replace (" " , "," ))[0 ]
0 commit comments