@@ -276,20 +276,25 @@ def get_pserver_program(self, endpoint):
276
276
suff_idx = v .name .find (".trainer_" )
277
277
if suff_idx >= 0 :
278
278
orig_var_name = v .name [:suff_idx ]
279
- pserver_program .global_block ().create_var (
279
+ else :
280
+ orig_var_name = v .name
281
+ single_trainer_var = pserver_program .global_block ().create_var (
280
282
name = orig_var_name ,
281
283
persistable = True ,
282
284
type = v .type ,
283
285
dtype = v .dtype ,
284
286
shape = v .shape )
285
- for trainer_id in xrange (self .trainers ):
286
- var = pserver_program .global_block ().create_var (
287
- name = "%s.trainer_%d" % (orig_var_name , trainer_id ),
288
- persistable = False ,
289
- type = v .type ,
290
- dtype = v .dtype ,
291
- shape = v .shape )
292
- recv_inputs .append (var )
287
+ if self .trainers > 1 :
288
+ for trainer_id in xrange (self .trainers ):
289
+ var = pserver_program .global_block ().create_var (
290
+ name = "%s.trainer_%d" % (orig_var_name , trainer_id ),
291
+ persistable = False ,
292
+ type = v .type ,
293
+ dtype = v .dtype ,
294
+ shape = v .shape )
295
+ recv_inputs .append (var )
296
+ else :
297
+ recv_inputs .append (single_trainer_var )
293
298
294
299
# step3
295
300
optimize_block = pserver_program .create_block (0 )
@@ -511,8 +516,11 @@ def _clone_var(self, block, var):
511
516
512
517
def _append_split_op (self , program , gradblocks ):
513
518
# Split variables that need to be split and append respective ops
519
+ add_suffix = False
520
+ if self .trainers > 1 :
521
+ add_suffix = True
514
522
var_mapping = self ._create_vars_from_blocklist (
515
- program , gradblocks , add_trainer_suffix = True )
523
+ program , gradblocks , add_trainer_suffix = add_suffix )
516
524
for varname , splited_vars in var_mapping .iteritems ():
517
525
# variable that don't need to split have empty splited_vars
518
526
if len (splited_vars ) <= 1 :
0 commit comments