@@ -371,22 +371,18 @@ def get_pserver_program(self, endpoint):
371
371
recv_inputs .append (single_trainer_var )
372
372
373
373
# step 3
374
- # each optimization op will has a optimize block
375
- optimize_block = None
376
-
377
- # step 4
378
374
# Create a union-find data structure from optimize ops.
379
375
# If two ops are connected, we could add these two ops
380
376
# into one set.
381
377
ufind = self ._create_ufind (self .optimize_ops )
382
- # step 4 .2
378
+ # step 3 .2
383
379
# Iterate through the ops and append optimize op which
384
380
# is located on the current pserver
385
381
opt_op_on_pserver = []
386
382
for _ , op in enumerate (self .optimize_ops ):
387
383
if self ._is_opt_op (op ) and self ._is_opt_op_on_pserver (endpoint , op ):
388
384
opt_op_on_pserver .append (op )
389
- # step 4 .3
385
+ # step 3 .3
390
386
# Iterate through the ops, and if an op and the optimize ops
391
387
# which are located on the current pserver are in one set, then
392
388
# append it into the sub program.
@@ -420,23 +416,17 @@ def __append_optimize_op__(op, block):
420
416
self ._append_pserver_non_opt_ops (block , op )
421
417
422
418
# append lr decay ops to the child block if exists
423
- lr_decay_block = None
424
419
lr_ops = self ._get_lr_ops ()
425
420
if len (lr_ops ) > 0 :
426
- lr_decay_block = pserver_program .create_block (GLOBAL_BLOCK_IDX )
421
+ lr_decay_block = pserver_program .create_block (
422
+ pserver_program .num_blocks - 1 )
427
423
for _ , op in enumerate (lr_ops ):
428
424
self ._append_pserver_non_opt_ops (lr_decay_block , op )
429
425
430
426
# append op to the current block
431
- per_opt_block = None
432
- pre_block_idx = GLOBAL_BLOCK_IDX
433
- if lr_decay_block is not None :
434
- pre_block_idx = lr_decay_block .idx
427
+ pre_block_idx = pserver_program .num_blocks - 1
435
428
for idx , opt_op in enumerate (opt_op_on_pserver ):
436
429
per_opt_block = pserver_program .create_block (pre_block_idx )
437
- if optimize_block is None :
438
- # first optimize block
439
- optimize_block = per_opt_block
440
430
for _ , op in enumerate (self .optimize_ops ):
441
431
# optimizer is connected to itself
442
432
if ufind .is_connected (op , opt_op ) and op not in global_ops :
@@ -445,9 +435,10 @@ def __append_optimize_op__(op, block):
445
435
# append global ops
446
436
opt_state_block = None
447
437
if global_ops :
448
- opt_state_block = pserver_program .create_block (per_opt_block .idx )
449
- for glb_op in global_ops :
450
- __append_optimize_op__ (glb_op , opt_state_block )
438
+ opt_state_block = pserver_program .create_block (
439
+ pserver_program .num_blocks - 1 )
440
+ for glb_op in global_ops :
441
+ __append_optimize_op__ (glb_op , opt_state_block )
451
442
452
443
# NOT USED: single block version:
453
444
#
@@ -481,7 +472,7 @@ def __append_optimize_op__(op, block):
481
472
inputs = {'X' : recv_inputs },
482
473
outputs = {},
483
474
attrs = {
484
- "OptimizeBlock" : optimize_block ,
475
+ "OptimizeBlock" : pserver_program . block ( 1 ) ,
485
476
"endpoint" : endpoint ,
486
477
"Fanin" : self .trainer_num ,
487
478
"PrefetchBlock" : prefetch_block
0 commit comments