20
20
from paddle .fluid .incubate .fleet .base .role_maker import UserDefinedRoleMaker
21
21
from paddle .fluid .incubate .fleet .base .role_maker import UserDefinedCollectiveRoleMaker
22
22
from paddle .fluid .incubate .fleet .base .role_maker import Role
23
+ import paddle .fluid .incubate .fleet .base .role_maker as role_maker
23
24
from paddle .fluid .incubate .fleet .parameter_server .distribute_transpiler import fleet
24
25
from paddle .fluid .incubate .fleet .parameter_server .distribute_transpiler import TranspilerOptimizer
25
26
from paddle .fluid .incubate .fleet .collective import CollectiveOptimizer
27
+ from dist_simnet_bow import train_network
26
28
27
29
28
30
class DistributeTranspilerConfigTest (unittest .TestCase ):
@@ -97,6 +99,30 @@ def testInvalidInputs(self):
97
99
main_program = compiled_prog )
98
100
self .assertRaises (Exception , fleet ._transpile , "config" )
99
101
102
+ def set_program (self , avg_cost , strategy ):
103
+ optimizer = fluid .optimizer .SGD (0.1 )
104
+ optimizer = fleet .distributed_optimizer (optimizer , strategy )
105
+ optimizer .minimize (avg_cost )
106
+
107
+ def test_init_role (self ):
108
+ role = role_maker .UserDefinedRoleMaker (
109
+ current_id = 0 ,
110
+ role = role_maker .Role .SERVER ,
111
+ worker_num = 2 ,
112
+ server_endpoints = ["127.0.0.1:36011" , "127.0.0.1:36012" ])
113
+ # for test optimizer without init(role)
114
+ # fleet.init(role)
115
+ batch_size = 128
116
+ is_sparse = True
117
+ is_distribute = False
118
+ strategy = DistributeTranspilerConfig ()
119
+ strategy .sync_mode = False
120
+ strategy .geo_sgd_mode = True
121
+ strategy .geo_sgd_need_push_nums = 5
122
+ avg_cost , _ , _ = train_network (batch_size , is_distribute , is_sparse )
123
+
124
+ self .assertRaises (Exception , self .set_program , avg_cost , strategy )
125
+
100
126
101
127
class TranspilerOptimizerTest (unittest .TestCase ):
102
128
def testInvalidInputs (self ):
@@ -124,7 +150,7 @@ def createRoleMaker(self,
124
150
125
151
def testRoleMaker (self ):
126
152
self .createRoleMaker ()
127
- ## test all invalid server_endpoints
153
+ # test all invalid server_endpoints
128
154
self .assertRaises (
129
155
Exception , self .createRoleMaker ,
130
156
server_endpoints = None ) # server_endpoints must be as list
@@ -140,7 +166,7 @@ def testRoleMaker(self):
140
166
self .createRoleMaker ,
141
167
server_endpoints = ["127.0.0.1:8080" , "127.0.0.1:8080" ]
142
168
) # element in server_endpoints can't be duplicate
143
- ## test all invalid current_id
169
+ # test all invalid current_id
144
170
self .assertRaises (
145
171
Exception , self .createRoleMaker ,
146
172
current_id = "0" ) # current_id must be as int
@@ -154,14 +180,14 @@ def testRoleMaker(self):
154
180
role = Role .SERVER ,
155
181
server_endpoints = ["127.0.0.1:8080" ]
156
182
) # if role is server, current_id must be less than len(server_endpoints)
157
- ## test all invalid worker_num
183
+ # test all invalid worker_num
158
184
self .assertRaises (
159
185
Exception , self .createRoleMaker ,
160
186
worker_num = "1" ) # worker_num must be as int
161
187
self .assertRaises (
162
188
Exception , self .createRoleMaker ,
163
189
worker_num = 0 ) # worker_num must be greater than 0
164
- ## test all invalid role
190
+ # test all invalid role
165
191
self .assertRaises (
166
192
Exception , self .createRoleMaker ,
167
193
role = 3 ) # role must be as Role(Role.WORKER=1, Role.SERVER=2)
@@ -174,7 +200,7 @@ def createRoleMaker(self, current_id=0,
174
200
175
201
def testRoleMaker (self ):
176
202
self .createRoleMaker ()
177
- ## test all invalid worker_endpoints
203
+ # test all invalid worker_endpoints
178
204
self .assertRaises (
179
205
Exception , self .createRoleMaker ,
180
206
worker_endpoints = None ) # worker_endpoints must be as list
@@ -190,7 +216,7 @@ def testRoleMaker(self):
190
216
self .createRoleMaker ,
191
217
worker_endpoints = ["127.0.0.1:8080" , "127.0.0.1:8080" ]
192
218
) # element in worker_endpoints can't be duplicate
193
- ## test all invalid current_id
219
+ # test all invalid current_id
194
220
self .assertRaises (
195
221
Exception , self .createRoleMaker ,
196
222
current_id = "0" ) # current_id must be as int
0 commit comments