@@ -247,7 +247,7 @@ def transpile(self,
247
247
np .random .seed (self .origin_program .random_seed )
248
248
np .random .shuffle (grad_var_mapping_items )
249
249
250
- grad_name_to_send_dummy_out = dict ()
250
+ self . grad_name_to_send_dummy_out = dict ()
251
251
for grad_varname , splited_vars in grad_var_mapping_items :
252
252
eplist = ps_dispatcher .dispatch (splited_vars )
253
253
@@ -271,7 +271,7 @@ def transpile(self,
271
271
272
272
dummy_output = program .global_block ().create_var (
273
273
name = framework .generate_control_dev_var_name ())
274
- grad_name_to_send_dummy_out [grad_varname ] = dummy_output
274
+ self . grad_name_to_send_dummy_out [grad_varname ] = dummy_output
275
275
276
276
# get send op_role_var, if not splited, the grad should have .trainer suffix
277
277
# if splited, grad should be the original grad var name (split_by_ref and send
@@ -297,7 +297,12 @@ def transpile(self,
297
297
if self .sync_mode :
298
298
send_barrier_out = program .global_block ().create_var (
299
299
name = framework .generate_control_dev_var_name ())
300
- input_deps = grad_name_to_send_dummy_out .values ()
300
+ if self .has_distributed_lookup_table :
301
+ self .grad_name_to_send_dummy_out [
302
+ self .table_name ] = program .global_block ().create_var (
303
+ name = framework .generate_control_dev_var_name ())
304
+ input_deps = self .grad_name_to_send_dummy_out .values ()
305
+
301
306
program .global_block ().append_op (
302
307
type = "send_barrier" ,
303
308
inputs = {"X" : list (input_deps )},
@@ -329,7 +334,7 @@ def transpile(self,
329
334
recv_dep_in = send_barrier_out
330
335
else :
331
336
# connect deps to send op in async mode
332
- recv_dep_in = grad_name_to_send_dummy_out [
337
+ recv_dep_in = self . grad_name_to_send_dummy_out [
333
338
self .param_name_to_grad_name [param_varname ]]
334
339
all_recv_outputs .extend (splited_var )
335
340
# get recv op_role_var, if not splited, the grad should have .trainer suffix
@@ -1046,9 +1051,13 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
1046
1051
index = op_index + 2 ,
1047
1052
type = "send" ,
1048
1053
inputs = {'X' : self .trainer_side_table_grad_list },
1049
- outputs = {'Out' : []},
1054
+ outputs = {
1055
+ 'Out' :
1056
+ [self .grad_name_to_send_dummy_out [self .table_name ]]
1057
+ if self .sync_mode else []
1058
+ },
1050
1059
attrs = {
1051
- "sync_mode" : True ,
1060
+ "sync_mode" : False ,
1052
1061
"epmap" : pserver_endpoints ,
1053
1062
RPC_OP_ROLE_ATTR_NAME : RPC_OP_ROLE_ATTR_VALUE ,
1054
1063
OP_ROLE_VAR_ATTR_NAME : [
0 commit comments