File tree Expand file tree Collapse file tree 1 file changed +10
-6
lines changed Expand file tree Collapse file tree 1 file changed +10
-6
lines changed Original file line number Diff line number Diff line change @@ -278,11 +278,21 @@ def get_pserver_program(self, endpoint):
278
278
# we don't need to create them when grad arrives.
279
279
# change client side var name to origin name by
280
280
# removing ".trainer_%d" suffix
281
+
281
282
suff_idx = v .name .find (".trainer_" )
282
283
if suff_idx >= 0 :
283
284
orig_var_name = v .name [:suff_idx ]
284
285
else :
285
286
orig_var_name = v .name
287
+ # NOTE: single_trainer_var must be created for multi-trainer
288
+ # case to merge grads from multiple trainers
289
+ single_trainer_var = \
290
+ pserver_program .global_block ().create_var (
291
+ name = orig_var_name ,
292
+ persistable = True ,
293
+ type = v .type ,
294
+ dtype = v .dtype ,
295
+ shape = v .shape )
286
296
if self .trainers > 1 :
287
297
for trainer_id in xrange (self .trainers ):
288
298
var = pserver_program .global_block ().create_var (
@@ -293,12 +303,6 @@ def get_pserver_program(self, endpoint):
293
303
shape = v .shape )
294
304
recv_inputs .append (var )
295
305
else :
296
- single_trainer_var = pserver_program .global_block ().create_var (
297
- name = orig_var_name ,
298
- persistable = True ,
299
- type = v .type ,
300
- dtype = v .dtype ,
301
- shape = v .shape )
302
306
recv_inputs .append (single_trainer_var )
303
307
304
308
# step3
You can’t perform that action at this time.
0 commit comments