25
25
LOOKUP_TABLE_GRAD_TYPE = "lookup_table_grad"
26
26
RPC_CLIENT_VAR_NAME = "RPC_CLIENT_VAR"
27
27
28
+ GLOBAL_BLOCK_IDX = 0
29
+
28
30
29
31
class VarBlock :
30
32
def __init__ (self , varname , offset , size ):
@@ -368,8 +370,8 @@ def get_pserver_program(self, endpoint):
368
370
else :
369
371
recv_inputs .append (single_trainer_var )
370
372
371
- # step3
372
- optimize_block = pserver_program . create_block ( 0 )
373
+ optimize_block = None
374
+
373
375
# step 4
374
376
# Create a union-find data structure from optimize ops,
375
377
# If two ops are connected, we could add these two ops
@@ -415,29 +417,34 @@ def __append_optimize_op__(op, block):
415
417
else :
416
418
self ._append_pserver_non_opt_ops (block , op )
417
419
418
- append_block = optimize_block
419
420
# append lr decay ops to the child block if exists
421
+ lr_decay_block = None
420
422
lr_ops = self ._get_lr_ops ()
421
423
if len (lr_ops ) > 0 :
424
+ lr_decay_block = pserver_program .create_block (GLOBAL_BLOCK_IDX )
422
425
for _ , op in enumerate (lr_ops ):
423
- self ._append_pserver_non_opt_ops (append_block , op )
424
-
425
- append_block = pserver_program .create_block (append_block .idx )
426
+ self ._append_pserver_non_opt_ops (lr_decay_block , op )
426
427
427
428
# append op to the current block
428
- per_opt_block = append_block
429
+ per_opt_block = None
430
+ pre_block_idx = GLOBAL_BLOCK_IDX
431
+ if lr_decay_block is not None :
432
+ pre_block_idx = lr_decay_block .idx
429
433
for idx , opt_op in enumerate (opt_op_on_pserver ):
434
+ per_opt_block = pserver_program .create_block (pre_block_idx )
435
+ if optimize_block is None :
436
+ optimize_block = per_opt_block
430
437
for _ , op in enumerate (self .optimize_ops ):
431
438
# optimizer is connected to itself
432
- if ufind .is_connected (op , opt_op ) and \
433
- op not in global_ops :
439
+ if ufind .is_connected (op , opt_op ) and op not in global_ops :
434
440
__append_optimize_op__ (op , per_opt_block )
435
- if idx == len (opt_op_on_pserver ) - 1 and global_ops :
436
- per_opt_block = pserver_program .create_block (append_block .idx )
437
441
438
442
# append global ops
443
+ opt_state_block = None
444
+ if global_ops :
445
+ opt_state_block = pserver_program .create_block (per_opt_block .idx )
439
446
for glb_op in global_ops :
440
- __append_optimize_op__ (glb_op , per_opt_block )
447
+ __append_optimize_op__ (glb_op , opt_state_block )
441
448
442
449
# NOT USED: single block version:
443
450
#
@@ -451,10 +458,11 @@ def __append_optimize_op__(op, block):
451
458
prefetch_block = None
452
459
if self .has_distributed_lookup_table :
453
460
pserver_index = self .pserver_endpoints .index (endpoint )
454
- self ._create_table_optimize_block (pserver_index , pserver_program ,
455
- append_block )
461
+ table_opt_block = self ._create_table_optimize_block (
462
+ pserver_index , pserver_program , opt_state_block or
463
+ pserver_program .global_block ())
456
464
prefetch_block = self ._create_prefetch_block (
457
- pserver_index , pserver_program , optimize_block )
465
+ pserver_index , pserver_program , table_opt_block )
458
466
459
467
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
460
468
# not be executed, so it's safe to use optimize_block to hold the place
@@ -724,6 +732,8 @@ def _clone_var(block, var, persistable=True):
724
732
outputs = outputs ,
725
733
attrs = table_opt_op .attrs )
726
734
735
+ return table_opt_block
736
+
727
737
# ====================== private transpiler functions =====================
728
738
def _create_vars_from_blocklist (self ,
729
739
program ,
0 commit comments