@@ -515,21 +515,20 @@ def __op_have_grad_input__(op):
515
515
grad_to_block_id , None )
516
516
517
517
# process distributed lookup_table
518
- prefetch_block = None
518
+ prefetch_var_name_to_block_id = []
519
519
if self .has_distributed_lookup_table :
520
520
pserver_index = self .pserver_endpoints .index (endpoint )
521
521
table_opt_block = self ._create_table_optimize_block (
522
522
pserver_index , pserver_program , pre_block_idx , grad_to_block_id )
523
- prefetch_block = self ._create_prefetch_block (
523
+ prefetch_var_name_to_block_id = self ._create_prefetch_block (
524
524
pserver_index , pserver_program , table_opt_block )
525
525
526
526
# NOTE: if has_distributed_lookup_table is False, no prefetch block is
527
527
# created, so prefetch_var_name_to_block_id must stay an empty list
528
528
if self .has_distributed_lookup_table :
529
- assert prefetch_block is not None
529
+ assert len ( prefetch_var_name_to_block_id ) > 0
530
530
else :
531
- assert prefetch_block is None
532
- prefetch_block = pserver_program .global_block ()
531
+ assert len (prefetch_var_name_to_block_id ) == 0
533
532
534
533
# step5 append the listen_and_serv op
535
534
pserver_program .global_block ().append_op (
@@ -540,7 +539,7 @@ def __op_have_grad_input__(op):
540
539
"OptimizeBlock" : pserver_program .block (1 ),
541
540
"endpoint" : endpoint ,
542
541
"Fanin" : self .trainer_num ,
543
- "PrefetchBlock " : prefetch_block ,
542
+ "prefetch_var_name_to_block_id " : prefetch_var_name_to_block_id ,
544
543
"sync_mode" : self .sync_mode ,
545
544
"grad_to_block_id" : grad_to_block_id
546
545
})
@@ -608,8 +607,15 @@ def _get_splited_name_and_shape(varname):
608
607
def _replace_lookup_table_op_with_prefetch (self , program ,
609
608
pserver_endpoints ):
610
609
# 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
611
- self .prefetch_input_vars = None
612
- self .prefetch_output_vars = None
610
+ # self.all_prefetch_input_vars =
611
+ # [[var0_prefetch_in_pserver0, var0_prefetch_in_pserver1]
612
+ # [var1_prefetch_in_pserver0, var1_prefetch_in_pserver1]]
613
+ self .all_prefetch_input_vars = []
614
+
615
+ # self.all_prefetch_output_vars =
616
+ # [[var0_prefetch_out_pserver0, var0_prefetch_out_pserver1]
617
+ # [var1_prefetch_out_pserver0, var1_prefetch_out_pserver1]]
618
+ self .all_prefetch_output_vars = []
613
619
614
620
continue_search_lookup_table_op = True
615
621
while continue_search_lookup_table_op :
@@ -623,18 +629,19 @@ def _replace_lookup_table_op_with_prefetch(self, program,
623
629
ids_name = op .input ("Ids" )
624
630
out_name = op .output ("Out" )
625
631
626
- if self .prefetch_input_vars is None :
627
- ids_var = program .global_block ().vars [ids_name [0 ]]
628
- self .prefetch_input_vars = self .create_splited_vars (
629
- source_var = ids_var ,
630
- block = program .global_block (),
631
- tag = "_prefetch_in_" )
632
- if self .prefetch_output_vars is None :
633
- out_var = program .global_block ().vars [out_name [0 ]]
634
- self .prefetch_output_vars = self .create_splited_vars (
635
- source_var = out_var ,
636
- block = program .global_block (),
637
- tag = "_prefetch_out_" )
632
+ ids_var = program .global_block ().vars [ids_name [0 ]]
633
+ prefetch_input_vars = self .create_splited_vars (
634
+ source_var = ids_var ,
635
+ block = program .global_block (),
636
+ tag = "_prefetch_in_" )
637
+ self .all_prefetch_input_vars .append (prefetch_input_vars )
638
+
639
+ out_var = program .global_block ().vars [out_name [0 ]]
640
+ prefetch_output_vars = self .create_splited_vars (
641
+ source_var = out_var ,
642
+ block = program .global_block (),
643
+ tag = "_prefetch_out_" )
644
+ self .all_prefetch_output_vars .append (prefetch_output_vars )
638
645
639
646
# insert split_ids_op
640
647
program .global_block ().insert_op (
@@ -646,14 +653,14 @@ def _replace_lookup_table_op_with_prefetch(self, program,
646
653
for varname in ids_name
647
654
]
648
655
},
649
- outputs = {"Out" : self . prefetch_input_vars })
656
+ outputs = {"Out" : prefetch_input_vars })
650
657
651
658
# insert prefetch_op
652
659
program .global_block ().insert_op (
653
660
index = op_index + 1 ,
654
661
type = "prefetch" ,
655
- inputs = {'X' : self . prefetch_input_vars },
656
- outputs = {"Out" : self . prefetch_output_vars },
662
+ inputs = {'X' : prefetch_input_vars },
663
+ outputs = {"Out" : prefetch_output_vars },
657
664
attrs = {
658
665
"epmap" : pserver_endpoints ,
659
666
RPC_OP_ROLE_ATTR_NAME : RPC_OP_ROLE_ATTR_VALUE
@@ -663,7 +670,7 @@ def _replace_lookup_table_op_with_prefetch(self, program,
663
670
program .global_block ().insert_op (
664
671
index = op_index + 2 ,
665
672
type = "concat" ,
666
- inputs = {'X' : self . prefetch_output_vars },
673
+ inputs = {'X' : prefetch_output_vars },
667
674
outputs = {
668
675
"Out" : [
669
676
program .global_block ().vars [varname ]
@@ -709,30 +716,34 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
709
716
optimize_block ):
710
717
# STEP: create prefetch block
711
718
table_var = pserver_program .global_block ().vars [self .table_name ]
712
- prefetch_block = pserver_program .create_block (optimize_block .idx )
713
- trainer_ids = self .prefetch_input_vars [pserver_index ]
714
- pserver_ids = pserver_program .global_block ().create_var (
715
- name = trainer_ids .name ,
716
- type = trainer_ids .type ,
717
- shape = trainer_ids .shape ,
718
- dtype = trainer_ids .dtype )
719
- trainer_out = self .prefetch_output_vars [pserver_index ]
720
- pserver_out = pserver_program .global_block ().create_var (
721
- name = trainer_out .name ,
722
- type = trainer_out .type ,
723
- shape = trainer_out .shape ,
724
- dtype = trainer_out .dtype )
725
- prefetch_block .append_op (
726
- type = "lookup_sparse_table" ,
727
- inputs = {'Ids' : pserver_ids ,
728
- "W" : table_var },
729
- outputs = {"Out" : pserver_out },
730
- attrs = {
731
- "is_sparse" : True , # has no effect on lookup_table op
732
- "is_distributed" : True ,
733
- "padding_idx" : - 1
734
- })
735
- return prefetch_block
719
+ prefetch_var_name_to_block_id = []
720
+ for index in range (len (self .all_prefetch_input_vars )):
721
+ prefetch_block = pserver_program .create_block (optimize_block .idx )
722
+ trainer_ids = self .all_prefetch_input_vars [index ][pserver_index ]
723
+ pserver_ids = pserver_program .global_block ().create_var (
724
+ name = trainer_ids .name ,
725
+ type = trainer_ids .type ,
726
+ shape = trainer_ids .shape ,
727
+ dtype = trainer_ids .dtype )
728
+ trainer_out = self .all_prefetch_output_vars [index ][pserver_index ]
729
+ pserver_out = pserver_program .global_block ().create_var (
730
+ name = trainer_out .name ,
731
+ type = trainer_out .type ,
732
+ shape = trainer_out .shape ,
733
+ dtype = trainer_out .dtype )
734
+ prefetch_block .append_op (
735
+ type = "lookup_sparse_table" ,
736
+ inputs = {'Ids' : pserver_ids ,
737
+ "W" : table_var },
738
+ outputs = {"Out" : pserver_out },
739
+ attrs = {
740
+ "is_sparse" : True , # has no effect on lookup_table op
741
+ "is_distributed" : True ,
742
+ "padding_idx" : - 1
743
+ })
744
+ prefetch_var_name_to_block_id .append (trainer_ids .name + ":" + str (
745
+ prefetch_block .idx ))
746
+ return prefetch_var_name_to_block_id
736
747
737
748
def _create_table_optimize_block (self , pserver_index , pserver_program ,
738
749
pre_block_idx , grad_to_block_id ):
0 commit comments