@@ -368,21 +368,19 @@ def get_pserver_program(self, endpoint):
368
368
else :
369
369
recv_inputs .append (single_trainer_var )
370
370
371
- # step3
372
- optimize_block = pserver_program .create_block (0 )
373
- # step 4
371
+ # step 3
374
372
# Create a union-find data structure from optimize ops,
375
373
# If two ops are connected, we could add these two ops
376
374
# into one set.
377
375
ufind = self ._create_ufind (self .optimize_ops )
378
- # step 4 .2
376
+ # step 3 .2
379
377
# Iterate through the ops and append optimize op which
380
378
# is located on the current pserver
381
379
opt_op_on_pserver = []
382
380
for _ , op in enumerate (self .optimize_ops ):
383
381
if self ._is_opt_op (op ) and self ._is_opt_op_on_pserver (endpoint , op ):
384
382
opt_op_on_pserver .append (op )
385
- # step 4 .3
383
+ # step 3 .3
386
384
# Iterate through the ops, and if an op and the optimize ops
387
385
# which are located on the current pserver are in one set, then
388
386
# append it into the sub program.
@@ -415,29 +413,30 @@ def __append_optimize_op__(op, block):
415
413
else :
416
414
self ._append_pserver_non_opt_ops (block , op )
417
415
418
- append_block = optimize_block
419
416
# append lr decay ops to a child block if any exist
420
417
lr_ops = self ._get_lr_ops ()
421
418
if len (lr_ops ) > 0 :
419
+ lr_decay_block = pserver_program .create_block (
420
+ pserver_program .num_blocks - 1 )
422
421
for _ , op in enumerate (lr_ops ):
423
- self ._append_pserver_non_opt_ops (append_block , op )
424
-
425
- append_block = pserver_program .create_block (append_block .idx )
422
+ self ._append_pserver_non_opt_ops (lr_decay_block , op )
426
423
427
424
# append op to the current block
428
- per_opt_block = append_block
425
+ pre_block_idx = pserver_program . num_blocks - 1
429
426
for idx , opt_op in enumerate (opt_op_on_pserver ):
427
+ per_opt_block = pserver_program .create_block (pre_block_idx )
430
428
for _ , op in enumerate (self .optimize_ops ):
431
429
# optimizer is connected to itself
432
- if ufind .is_connected (op , opt_op ) and \
433
- op not in global_ops :
430
+ if ufind .is_connected (op , opt_op ) and op not in global_ops :
434
431
__append_optimize_op__ (op , per_opt_block )
435
- if idx == len (opt_op_on_pserver ) - 1 and global_ops :
436
- per_opt_block = pserver_program .create_block (append_block .idx )
437
432
438
433
# append global ops
439
- for glb_op in global_ops :
440
- __append_optimize_op__ (glb_op , per_opt_block )
434
+ opt_state_block = None
435
+ if global_ops :
436
+ opt_state_block = pserver_program .create_block (
437
+ pserver_program .num_blocks - 1 )
438
+ for glb_op in global_ops :
439
+ __append_optimize_op__ (glb_op , opt_state_block )
441
440
442
441
# NOT USED: single block version:
443
442
#
@@ -451,10 +450,10 @@ def __append_optimize_op__(op, block):
451
450
prefetch_block = None
452
451
if self .has_distributed_lookup_table :
453
452
pserver_index = self .pserver_endpoints .index (endpoint )
454
- self ._create_table_optimize_block (pserver_index , pserver_program ,
455
- append_block )
453
+ table_opt_block = self ._create_table_optimize_block (
454
+ pserver_index , pserver_program , pre_block_idx )
456
455
prefetch_block = self ._create_prefetch_block (
457
- pserver_index , pserver_program , optimize_block )
456
+ pserver_index , pserver_program , table_opt_block )
458
457
459
458
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
460
459
# not be executed, so it's safe to use optimize_block to hold the place
@@ -470,7 +469,7 @@ def __append_optimize_op__(op, block):
470
469
inputs = {'X' : recv_inputs },
471
470
outputs = {},
472
471
attrs = {
473
- "OptimizeBlock" : optimize_block ,
472
+ "OptimizeBlock" : pserver_program . block ( 1 ) ,
474
473
"endpoint" : endpoint ,
475
474
"Fanin" : self .trainer_num ,
476
475
"PrefetchBlock" : prefetch_block
@@ -663,7 +662,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
663
662
return prefetch_block
664
663
665
664
def _create_table_optimize_block (self , pserver_index , pserver_program ,
666
- append_block ):
665
+ pre_block_idx ):
667
666
def _clone_var (block , var , persistable = True ):
668
667
assert isinstance (var , Variable )
669
668
return block .create_var (
@@ -700,7 +699,7 @@ def _clone_var(block, var, persistable=True):
700
699
op for op in self .optimize_ops
701
700
if op .input ("Param" )[0 ] == self .table_name
702
701
][0 ]
703
- table_opt_block = pserver_program .create_block (append_block . idx )
702
+ table_opt_block = pserver_program .create_block (pre_block_idx )
704
703
# only support sgd now
705
704
assert table_opt_op .type == "sgd"
706
705
@@ -724,6 +723,8 @@ def _clone_var(block, var, persistable=True):
724
723
outputs = outputs ,
725
724
attrs = table_opt_op .attrs )
726
725
726
+ return table_opt_block
727
+
727
728
# ====================== private transpiler functions =====================
728
729
def _create_vars_from_blocklist (self ,
729
730
program ,
0 commit comments