 import os
 import sys
 import signal
+from test_dist_base import TestDistRunnerBase, runtime_main
 
 # Fix seed for test
 fluid.default_startup_program().random_seed = 1
@@ -196,161 +197,52 @@ def squeeze_excitation(self, input, num_channels, reduction_ratio):
     return scale
 
 
-def get_model(batch_size):
-    # Input data
-    image = fluid.layers.data(name="data", shape=[3, 224, 224], dtype='float32')
-    label = fluid.layers.data(name="int64", shape=[1], dtype='int64')
+class DistSeResneXt2x2(TestDistRunnerBase):
+    def get_model(self, batch_size=2):
+        # Input data
+        image = fluid.layers.data(
+            name="data", shape=[3, 224, 224], dtype='float32')
+        label = fluid.layers.data(name="int64", shape=[1], dtype='int64')
 
-    # Train program
-    model = SE_ResNeXt(layers=50)
-    out = model.net(input=image, class_dim=102)
-    cost = fluid.layers.cross_entropy(input=out, label=label)
+        # Train program
+        model = SE_ResNeXt(layers=50)
+        out = model.net(input=image, class_dim=102)
+        cost = fluid.layers.cross_entropy(input=out, label=label)
 
-    avg_cost = fluid.layers.mean(x=cost)
-    acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
-    acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+        avg_cost = fluid.layers.mean(x=cost)
+        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
 
-    # Evaluator
-    test_program = fluid.default_main_program().clone(for_test=True)
+        # Evaluator
+        test_program = fluid.default_main_program().clone(for_test=True)
 
-    # Optimization
-    total_images = 6149  # flowers
-    epochs = [30, 60, 90]
-    step = int(total_images / batch_size + 1)
+        # Optimization
+        total_images = 6149  # flowers
+        epochs = [30, 60, 90]
+        step = int(total_images / batch_size + 1)
 
-    bd = [step * e for e in epochs]
-    base_lr = 0.1
-    lr = []
-    lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
+        bd = [step * e for e in epochs]
+        base_lr = 0.1
+        lr = []
+        lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
 
-    optimizer = fluid.optimizer.Momentum(
-        # FIXME(typhoonzero): add back LR decay once ParallelExecutor fixed.
-        #learning_rate=fluid.layers.piecewise_decay(
-        #    boundaries=bd, values=lr),
-        learning_rate=base_lr,
-        momentum=0.9,
-        regularization=fluid.regularizer.L2Decay(1e-4))
-    optimizer.minimize(avg_cost)
+        optimizer = fluid.optimizer.Momentum(
+            # FIXME(typhoonzero): add back LR decay once ParallelExecutor fixed.
+            #learning_rate=fluid.layers.piecewise_decay(
+            #    boundaries=bd, values=lr),
+            learning_rate=base_lr,
+            momentum=0.9,
+            regularization=fluid.regularizer.L2Decay(1e-4))
+        optimizer.minimize(avg_cost)
 
-    # Reader
-    train_reader = paddle.batch(
-        paddle.dataset.flowers.train(), batch_size=batch_size)
-    test_reader = paddle.batch(
-        paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
+        # Reader
+        train_reader = paddle.batch(
+            paddle.dataset.flowers.train(), batch_size=batch_size)
+        test_reader = paddle.batch(
+            paddle.dataset.flowers.test(use_xmap=False), batch_size=batch_size)
 
-    return test_program, avg_cost, train_reader, test_reader, acc_top1, out
-
-
-def get_transpiler(trainer_id, main_program, pserver_endpoints, trainers):
-    t = fluid.DistributeTranspiler()
-    t.transpile(
-        trainer_id=trainer_id,
-        program=main_program,
-        pservers=pserver_endpoints,
-        trainers=trainers)
-    return t
-
-
-class DistSeResneXt2x2:
-    def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
-        get_model(batch_size=2)
-        t = get_transpiler(trainer_id,
-                           fluid.default_main_program(), pserver_endpoints,
-                           trainers)
-        pserver_prog = t.get_pserver_program(current_endpoint)
-        startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
-        place = fluid.CPUPlace()
-        exe = fluid.Executor(place)
-        exe.run(startup_prog)
-        exe.run(pserver_prog)
-
-    def _wait_ps_ready(self, pid):
-        retry_times = 20
-        while True:
-            assert retry_times >= 0, "wait ps ready failed"
-            time.sleep(3)
-            print("waiting ps ready: ", pid)
-            try:
-                # the listen_and_serv_op would touch a file which contains the listen port
-                # on the /tmp directory until it was ready to process all the RPC call.
-                os.stat("/tmp/paddle.%d.port" % pid)
-                return
-            except os.error:
-                retry_times -= 1
-
-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
-        test_program, avg_cost, train_reader, test_reader, batch_acc, predict = get_model(
-            batch_size=2)
-        if is_dist:
-            t = get_transpiler(trainer_id,
-                               fluid.default_main_program(), endpoints,
-                               trainers)
-            trainer_prog = t.get_trainer_program()
-        else:
-            trainer_prog = fluid.default_main_program()
-
-        startup_exe = fluid.Executor(place)
-        startup_exe.run(fluid.default_startup_program())
-
-        strategy = fluid.ExecutionStrategy()
-        strategy.num_threads = 1
-        strategy.allow_op_delay = False
-        exe = fluid.ParallelExecutor(
-            True, loss_name=avg_cost.name, exec_strategy=strategy)
-
-        feed_var_list = [
-            var for var in trainer_prog.global_block().vars.values()
-            if var.is_data
-        ]
-
-        feeder = fluid.DataFeeder(feed_var_list, place)
-        reader_generator = test_reader()
-
-        data = next(reader_generator)
-        first_loss, = exe.run(fetch_list=[avg_cost.name],
-                              feed=feeder.feed(data))
-        print(first_loss)
-
-        for i in six.moves.xrange(5):
-            data = next(reader_generator)
-            loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data))
-
-        data = next(reader_generator)
-        last_loss, = exe.run(fetch_list=[avg_cost.name], feed=feeder.feed(data))
-        print(last_loss)
-
-
-def main(role="pserver",
-         endpoints="127.0.0.1:9123",
-         trainer_id=0,
-         current_endpoint="127.0.0.1:9123",
-         trainers=1,
-         is_dist=True):
-    model = DistSeResneXt2x2()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
-    else:
-        p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+        return test_program, avg_cost, train_reader, test_reader, acc_top1, out
 
 
 if __name__ == "__main__":
-    if len(sys.argv) != 7:
-        print(
-            "Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]"
-        )
-    role = sys.argv[1]
-    endpoints = sys.argv[2]
-    trainer_id = int(sys.argv[3])
-    current_endpoint = sys.argv[4]
-    trainers = int(sys.argv[5])
-    is_dist = True if sys.argv[6] == "TRUE" else False
-    main(
-        role=role,
-        endpoints=endpoints,
-        trainer_id=trainer_id,
-        current_endpoint=current_endpoint,
-        trainers=trainers,
-        is_dist=is_dist)
+    runtime_main(DistSeResneXt2x2)
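The diff moves the per-role process setup out of this file and into runtime_main() imported from test_dist_base. For orientation, below is a minimal sketch of what that entry point presumably does, mirroring the command-line handling of the deleted main() above; the run_pserver/run_trainer methods are assumed to be supplied by TestDistRunnerBase, and their exact signatures are not confirmed by this diff.

# Hypothetical sketch of runtime_main(); the real implementation lives in
# test_dist_base.py and may differ in details.
import sys

def runtime_main(test_class):
    # Same positional arguments the removed main() parsed:
    # [pserver/trainer] [endpoints] [trainer_id] [current_endpoint] [trainers] [is_dist]
    if len(sys.argv) != 7:
        print("Usage: python dist_se_resnext.py [pserver/trainer] [endpoints] "
              "[trainer_id] [current_endpoint] [trainers] [is_dist]")
        return

    role = sys.argv[1]
    endpoints = sys.argv[2]
    trainer_id = int(sys.argv[3])
    current_endpoint = sys.argv[4]
    trainers = int(sys.argv[5])
    is_dist = sys.argv[6] == "TRUE"

    # The test class (e.g. DistSeResneXt2x2) only overrides get_model();
    # the base class is assumed to drive pserver/trainer execution from it.
    model = test_class()
    if role == "pserver":
        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
    else:
        model.run_trainer(endpoints, trainer_id, trainers, is_dist)

With this split, each distributed test only describes its network in get_model(), while argument parsing, transpilation, and the pserver/trainer loops stay in one shared base, which is what allows the __main__ block to shrink to a single runtime_main(DistSeResneXt2x2) call.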