@@ -358,7 +358,7 @@ def get_pserver_program(self, endpoint):
358
358
type = v .type ,
359
359
dtype = v .dtype ,
360
360
shape = v .shape )
361
- if self .trainer_num > 1 :
361
+ if self .sync_mode and self . trainer_num > 1 :
362
362
for trainer_id in xrange (self .trainer_num ):
363
363
var = pserver_program .global_block ().create_var (
364
364
name = "%s.trainer_%d" % (orig_var_name , trainer_id ),
@@ -688,17 +688,6 @@ def _clone_var(block, var, persistable=True):
688
688
self .table_name )],
689
689
persistable = False )
690
690
691
- # create grad vars in pserver program
692
- table_grad_var = self .table_param_grad [1 ]
693
- table_grad_list = [
694
- pserver_program .global_block ().create_var (
695
- name = "%s.trainer_%d.pserver_%d" %
696
- (table_grad_var .name , index , pserver_index ),
697
- type = table_grad_var .type ,
698
- shape = table_grad_var .shape ,
699
- dtype = table_grad_var .dtype ) for index in range (self .trainer_num )
700
- ]
701
-
702
691
# create table optimize block in pserver program
703
692
table_opt_op = [
704
693
op for op in self .optimize_ops
@@ -708,11 +697,24 @@ def _clone_var(block, var, persistable=True):
708
697
# only support sgd now
709
698
assert table_opt_op .type == "sgd"
710
699
711
- # append sum op for table_grad_list
712
- table_opt_block .append_op (
713
- type = "sum" ,
714
- inputs = {"X" : table_grad_list },
715
- outputs = {"Out" : [grad_var ]})
700
+ if self .sync_mode :
701
+ # create grad vars in pserver program
702
+ table_grad_var = self .table_param_grad [1 ]
703
+ table_grad_list = [
704
+ pserver_program .global_block ().create_var (
705
+ name = "%s.trainer_%d.pserver_%d" %
706
+ (table_grad_var .name , index , pserver_index ),
707
+ type = table_grad_var .type ,
708
+ shape = table_grad_var .shape ,
709
+ dtype = table_grad_var .dtype )
710
+ for index in range (self .trainer_num )
711
+ ]
712
+
713
+ # append sum op for table_grad_list
714
+ table_opt_block .append_op (
715
+ type = "sum" ,
716
+ inputs = {"X" : table_grad_list },
717
+ outputs = {"Out" : [grad_var ]})
716
718
717
719
lr_var = pserver_program .global_block ().vars [table_opt_op .input (
718
720
"LearningRate" )[0 ]]
@@ -751,7 +753,7 @@ def _create_vars_from_blocklist(self,
751
753
for varname , splited in block_map .iteritems ():
752
754
orig_var = program .global_block ().var (varname )
753
755
if len (splited ) == 1 :
754
- if add_trainer_suffix :
756
+ if self . sync_mode and add_trainer_suffix :
755
757
new_var_name = "%s.trainer_%d" % \
756
758
(orig_var .name , self .trainer_id )
757
759
program .global_block ().rename_var (varname , new_var_name )
@@ -775,7 +777,7 @@ def _create_vars_from_blocklist(self,
775
777
if len (orig_shape ) >= 2 :
776
778
splited_shape .extend (orig_shape [1 :])
777
779
new_var_name = ""
778
- if add_trainer_suffix :
780
+ if self . sync_mode and add_trainer_suffix :
779
781
new_var_name = "%s.block%d.trainer_%d" % \
780
782
(varname , i , self .trainer_id )
781
783
else :
@@ -907,7 +909,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
907
909
pserver_block .vars [self ._orig_varname (grad_block .name )]
908
910
grad_to_block_id .append (merged_var .name + ":" + str (
909
911
optimize_block .idx ))
910
- if self .trainer_num > 1 :
912
+ if self .sync_mode and self . trainer_num > 1 :
911
913
vars2merge = []
912
914
for i in xrange (self .trainer_num ):
913
915
per_trainer_name = "%s.trainer_%d" % \
@@ -925,6 +927,7 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint,
925
927
inputs = {"X" : merged_var },
926
928
outputs = {"Out" : merged_var },
927
929
attrs = {"scale" : 1.0 / float (self .trainer_num )})
930
+
928
931
new_inputs [key ] = merged_var
929
932
elif key == "Param" :
930
933
# param is already created on global program
0 commit comments