@@ -210,6 +210,9 @@ def transpile(self,

         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
         self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.param_name_to_grad_name = dict()
+        for param_var, grad_var in self.params_grads:
+            self.param_name_to_grad_name[param_var.name] = grad_var.name

         # step 1: split and create vars, then put splited vars in dicts for later use.
         self._init_splited_vars()
@@ -229,34 +232,39 @@ def transpile(self,
             random.seed(self.origin_program.random_seed)
             random.shuffle(grad_var_mapping_items)

-        for orig_varname, splited_vars in grad_var_mapping_items:
+        grad_name_to_send_dummy_out = dict()
+        for grad_varname, splited_vars in grad_var_mapping_items:
             eplist = ps_dispatcher.dispatch(splited_vars)

             if not self.config.slice_var_up:
                 assert (len(splited_vars) == 1)

+            splited_grad_varname = grad_varname
             if len(splited_vars) == 1:
-                orig_varname = splited_vars[0].name
+                splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
             elif len(splited_vars) > 1:
-                orig_var = program.global_block().vars[orig_varname]
+                orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(program.global_block(),
-                                              orig_varname)
+                                              splited_grad_varname)
                 self._insert_split_op(program, orig_var, index, splited_vars)
                 index += 1
             else:
                 AssertionError("Can not insert the send op by original "
-                               "variable name :", orig_varname)
+                               "variable name :", splited_grad_varname)

+            dummy_output = program.global_block().create_var()
+            grad_name_to_send_dummy_out[grad_varname] = dummy_output
             program.global_block()._insert_op(
                 index=index + 1,
                 type="send",
                 inputs={"X": splited_vars},
-                outputs={},
+                outputs={"Out": dummy_output},
                 attrs={
                     "epmap": eplist,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode,
                 })
             for _, var in enumerate(splited_vars):
                 send_vars.append(var)
@@ -268,7 +276,6 @@ def transpile(self,
                 outputs={},
                 attrs={
                     "endpoints": pserver_endpoints,
-                    "sync_mode": self.sync_mode,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

@@ -284,19 +291,21 @@ def transpile(self,
             self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])

         # step4: Concat the parameters splits together after recv.
-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
-
+            grad_send_dummy_out = grad_name_to_send_dummy_out[
+                self.param_name_to_grad_name[param_varname]]
             program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": [grad_send_dummy_out]},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
-                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+                    RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
+                    "sync_mode": not self.sync_mode
                 })

         if self.sync_mode:
@@ -309,10 +318,10 @@ def transpile(self,
                     RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
                 })

-        for varname, splited_var in six.iteritems(self.param_var_mapping):
+        for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             if len(splited_var) <= 1:
                 continue
-            orig_param = program.global_block().vars[varname]
+            orig_param = program.global_block().vars[param_varname]
             program.global_block().append_op(
                 type="concat",
                 inputs={"X": splited_var},
@@ -380,7 +389,7 @@ def _get_trainer_startup_program(self,

             op = startup_program.global_block().append_op(
                 type="recv",
-                inputs={},
+                inputs={"X": []},
                 outputs={"Out": splited_var},
                 attrs={
                     "epmap": eps,
@@ -786,19 +795,21 @@ def _init_splited_vars(self):
                                       self.config.min_block_size)
        assert (len(grad_blocks) == len(param_blocks))

-        # origin_varname -> [splited_var]
+        # origin_param_name -> [splited_param_vars]
         self.param_var_mapping = self._create_vars_from_blocklist(
             self.origin_program, param_blocks)
+        # origin_grad_name -> [splited_grad_vars]
         self.grad_var_mapping = self._create_vars_from_blocklist(
             self.origin_program,
             grad_blocks,
             add_trainer_suffix=self.trainer_num > 1)
+        # dict(grad_splited_var -> param_splited_var)
         self.grad_param_mapping = collections.OrderedDict()
         for g, p in zip(grad_blocks, param_blocks):
             g_name, g_bid, _ = g.split(":")
             p_name, p_bid, _ = p.split(":")
             self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                self.param_var_mapping[p_name][int(p_bid)]
+                self.param_var_mapping[p_name][int(p_bid)]

         # create mapping of endpoint -> split var to create pserver side program
         self.param_grad_ep_mapping = collections.OrderedDict()
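
For reference, a minimal standalone sketch of how the mappings built in _init_splited_vars fit together. It assumes block descriptors of the form "name:block_id:size"; FakeVar, build_grad_param_mapping, and the sample names ("w", "w@GRAD", ".block0") are illustrative stand-ins, not the transpiler's own objects.

import collections

# Hypothetical stand-in for a framework Variable; only the name matters here.
class FakeVar(object):
    def __init__(self, name):
        self.name = name


def build_grad_param_mapping(grad_blocks, param_blocks, grad_var_mapping,
                             param_var_mapping):
    # Pair each splited gradient var with its splited parameter var,
    # mirroring the zip(grad_blocks, param_blocks) loop in the diff above.
    grad_param_mapping = collections.OrderedDict()
    for g, p in zip(grad_blocks, param_blocks):
        g_name, g_bid, _ = g.split(":")
        p_name, p_bid, _ = p.split(":")
        grad_param_mapping[grad_var_mapping[g_name][int(g_bid)]] = \
            param_var_mapping[p_name][int(p_bid)]
    return grad_param_mapping


# Example: one parameter "w" split into two blocks, so its gradient is too.
param_var_mapping = {"w": [FakeVar("w.block0"), FakeVar("w.block1")]}
grad_var_mapping = {"w@GRAD": [FakeVar("w@GRAD.block0"), FakeVar("w@GRAD.block1")]}
mapping = build_grad_param_mapping(
    ["w@GRAD:0:100", "w@GRAD:1:100"], ["w:0:100", "w:1:100"],
    grad_var_mapping, param_var_mapping)
for g, p in mapping.items():
    print(g.name, "->", p.name)  # w@GRAD.block0 -> w.block0, ...
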
@@ -919,7 +930,7 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
             index=op_index + 2,
             type="send",
             inputs={'X': self.trainer_side_table_grad_list},
-            outputs={},
+            outputs={'Out': []},
             attrs={
                 "sync_mode": True,
                 "epmap": pserver_endpoints,