@@ -273,15 +273,25 @@ def transpile(self,
                 if param_grad[0].name == self.table_name
             ][0]
             table_grad_var = self.table_param_grad[1]
-            self.table_grad_list = [
-                program.global_block().create_var(
-                    name="%s.trainer_%d.pserver_%d" %
-                    (table_grad_var.name, trainer_id, index),
-                    type=table_grad_var.type,
-                    shape=table_grad_var.shape,
-                    dtype=table_grad_var.dtype)
-                for index in range(len(self.pserver_endpoints))
-            ]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
 
         grad_blocks = split_dense_variable(grad_list, len(pserver_endpoints))
         param_blocks = split_dense_variable(param_list, len(pserver_endpoints))
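Note on the sync/async split above: in sync mode each trainer gets its own gradient variable per pserver (suffixed `.trainer_%d.pserver_%d`) so the pserver can sum one gradient from every trainer, while in async mode the trainer suffix is dropped and every trainer sends into the same per-pserver variable. A minimal sketch of the naming scheme (the gradient name and counts are made-up examples, not from this PR):

```python
def table_grad_var_names(grad_name, trainer_id, num_pservers, sync_mode):
    """Mirror the two list comprehensions above as plain strings."""
    if sync_mode:
        # one variable per (trainer, pserver) pair; the pserver sums them
        return ["%s.trainer_%d.pserver_%d" % (grad_name, trainer_id, i)
                for i in range(num_pservers)]
    # async mode: no trainer suffix, all trainers write into the same var
    return ["%s.pserver_%d" % (grad_name, i) for i in range(num_pservers)]

print(table_grad_var_names("emb.w@GRAD", 0, 2, sync_mode=True))
# ['emb.w@GRAD.trainer_0.pserver_0', 'emb.w@GRAD.trainer_0.pserver_1']
print(table_grad_var_names("emb.w@GRAD", 0, 2, sync_mode=False))
# ['emb.w@GRAD.pserver_0', 'emb.w@GRAD.pserver_1']
```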
@@ -400,7 +410,8 @@ def transpile(self,
                 attrs={"axis": 0})
 
         if self.has_distributed_lookup_table:
-            self._replace_lookup_table_op_with_prefetch(program, eplist)
+            self._replace_lookup_table_op_with_prefetch(program,
+                                                        pserver_endpoints)
             self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
 
     def get_trainer_program(self):
@@ -537,7 +548,7 @@ def __append_optimize_op__(op, block, grad_to_block_id):
         if self.has_distributed_lookup_table:
             pserver_index = self.pserver_endpoints.index(endpoint)
             table_opt_block = self._create_table_optimize_block(
-                pserver_index, pserver_program, pre_block_idx)
+                pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
             prefetch_block = self._create_prefetch_block(
                 pserver_index, pserver_program, table_opt_block)
 
@@ -621,7 +632,8 @@ def _get_splited_name_and_shape(varname):
         return s_prog
 
     # transpiler function for dis lookup_table
-    def _replace_lookup_table_op_with_prefetch(self, program, eplist):
+    def _replace_lookup_table_op_with_prefetch(self, program,
+                                               pserver_endpoints):
         # 1. replace lookup_table_op with split_ids_op -> prefetch_op -> sum_op
         self.prefetch_input_vars = None
         self.prefetch_output_vars = None
@@ -670,7 +682,7 @@ def _replace_lookup_table_op_with_prefetch(self, program, eplist):
             inputs={'X': self.prefetch_input_vars},
             outputs={"Out": self.prefetch_output_vars},
             attrs={
-                "epmap": eplist,
+                "epmap": pserver_endpoints,
                 RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
             })
 
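The `epmap` attribute tells the prefetch op which endpoint serves which shard of the lookup table, so it needs the full `pserver_endpoints` list rather than the per-variable `eplist`. A hedged sketch of the routing this enables, assuming ids are sharded by `id % num_pservers` (the endpoints are placeholder values, not from this PR):

```python
pserver_endpoints = ["127.0.0.1:6170", "127.0.0.1:6171"]

def route_ids(ids, endpoints):
    # bucket each id to the endpoint assumed to hold its table rows
    buckets = {ep: [] for ep in endpoints}
    for i in ids:
        buckets[endpoints[i % len(endpoints)]].append(i)
    return buckets

print(route_ids([0, 1, 2, 5, 8], pserver_endpoints))
# {'127.0.0.1:6170': [0, 2, 8], '127.0.0.1:6171': [1, 5]}
```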
@@ -707,11 +719,11 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
             inputs={
                 'Ids': [program.global_block().vars[table_grad_name]]
             },
-            outputs={"Out": self.table_grad_list})
+            outputs={"Out": self.trainer_side_table_grad_list})
         program.global_block().insert_op(
             index=op_index + 2,
             type="send_vars",
-            inputs={'X': self.table_grad_list},
+            inputs={'X': self.trainer_side_table_grad_list},
             outputs={},
             attrs={
                 "sync_send": True,
@@ -750,16 +762,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
         return prefetch_block
 
     def _create_table_optimize_block(self, pserver_index, pserver_program,
-                                     pre_block_idx):
-        def _clone_var(block, var, persistable=True):
-            assert isinstance(var, Variable)
-            return block.create_var(
-                name=var.name,
-                shape=var.shape,
-                dtype=var.dtype,
-                type=var.type,
-                persistable=persistable)
-
+                                     pre_block_idx, grad_to_block_id):
         # STEP: create table optimize block
         # create table param and grad var in pserver program
         origin_param_var = self.origin_program.global_block().vars[
@@ -770,11 +773,11 @@ def _clone_var(block, var, persistable=True):
             dtype=origin_param_var.dtype,
             type=core.VarDesc.VarType.SELECTED_ROWS,
             persistable=True)
-        grad_var = _clone_var(
-            pserver_program.global_block(),
+        # parameter must be selected rows
+        param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
+        grad_var = pserver_program.global_block().clone_variable(
             self.origin_program.global_block().vars[grad_var_name(
-                self.table_name)],
-            persistable=False)
+                self.table_name)])
 
         # create table optimize block in pserver program
         table_opt_op = [
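Dropping the local `_clone_var` helper in favor of the block's own `clone_variable` also stops forcing `persistable=False` on the gradient clone and keeps the variable type intact, which matters because the table parameter and gradient on the pserver are `SELECTED_ROWS` (sparse rows), not dense tensors. Roughly what the block-level helper is expected to do (an illustrative sketch, not the Fluid source):

```python
def clone_variable_sketch(block, var):
    # copy every descriptor field verbatim, SELECTED_ROWS type included
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        persistable=var.persistable)
```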
@@ -788,7 +791,7 @@ def _clone_var(block, var, persistable=True):
         if self.sync_mode:
             # create grad vars in pserver program
             table_grad_var = self.table_param_grad[1]
-            table_grad_list = [
+            pserver_side_table_grad_list = [
                 pserver_program.global_block().create_var(
                     name="%s.trainer_%d.pserver_%d" %
                     (table_grad_var.name, index, pserver_index),
@@ -798,11 +801,21 @@ def _clone_var(block, var, persistable=True):
                 for index in range(self.trainer_num)
             ]
 
-            # append sum op for table_grad_list
+            # append sum op for pserver_side_table_grad_list
             table_opt_block.append_op(
                 type="sum",
                 inputs={"X": pserver_side_table_grad_list},
                 outputs={"Out": [grad_var]})
+        else:
+            # in async mode, the table gradient also needs to be split to each parameter server
+            origin_grad_name = grad_var.name
+            splited_grad_name = self.trainer_side_table_grad_list[
+                pserver_index].name
+            if not splited_grad_name.startswith(origin_grad_name):
+                raise ValueError("origin_grad_var: " + splited_grad_name +
+                                 " grad_var:" + grad_var.name)
+            grad_var = pserver_program.global_block().rename_var(
+                origin_grad_name, splited_grad_name)
 
         lr_var = pserver_program.global_block().vars[table_opt_op.input(
             "LearningRate")[0]]
@@ -818,6 +831,9 @@ def _clone_var(block, var, persistable=True):
             outputs=outputs,
             attrs=table_opt_op.attrs)
 
+        # add the table parameter gradient and its block id to grad_to_block_id
+        grad_to_block_id.append(grad_var.name + ":" + str(table_opt_block.idx))
+
         return table_opt_block
 
     # ====================== private transpiler functions =====================
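`grad_to_block_id` entries follow a `"<grad_var_name>:<block_idx>"` convention; registering the table gradient here lets the serving side dispatch a received table gradient to its dedicated optimize block, as the dense gradients already do. A small sketch of parsing that convention (the entry values are examples):

```python
grad_to_block_id = ["fc_0.w_0@GRAD:1", "emb.w@GRAD.pserver_0:2"]

# split on the last ':' so gradient names containing ':' would still parse
dispatch = dict(entry.rsplit(":", 1) for entry in grad_to_block_id)
print(dispatch["emb.w@GRAD.pserver_0"])  # -> '2', the table optimize block
```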