Skip to content

Commit 020d13c

Browse files
authored
fix dist table send hang problem (#13259)
* fix dist table send hang problem
* revert sync_mode config
* fix async send table
1 parent 2c31ea9 commit 020d13c

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def transpile(self,
247247
np.random.seed(self.origin_program.random_seed)
248248
np.random.shuffle(grad_var_mapping_items)
249249

250-
grad_name_to_send_dummy_out = dict()
250+
self.grad_name_to_send_dummy_out = dict()
251251
for grad_varname, splited_vars in grad_var_mapping_items:
252252
eplist = ps_dispatcher.dispatch(splited_vars)
253253

@@ -271,7 +271,7 @@ def transpile(self,
271271

272272
dummy_output = program.global_block().create_var(
273273
name=framework.generate_control_dev_var_name())
274-
grad_name_to_send_dummy_out[grad_varname] = dummy_output
274+
self.grad_name_to_send_dummy_out[grad_varname] = dummy_output
275275

276276
# get send op_role_var, if not splited, the grad should have .trainer suffix
277277
# if splited, grad should be the original grad var name (split_by_ref and send
@@ -297,7 +297,12 @@ def transpile(self,
297297
if self.sync_mode:
298298
send_barrier_out = program.global_block().create_var(
299299
name=framework.generate_control_dev_var_name())
300-
input_deps = grad_name_to_send_dummy_out.values()
300+
if self.has_distributed_lookup_table:
301+
self.grad_name_to_send_dummy_out[
302+
self.table_name] = program.global_block().create_var(
303+
name=framework.generate_control_dev_var_name())
304+
input_deps = self.grad_name_to_send_dummy_out.values()
305+
301306
program.global_block().append_op(
302307
type="send_barrier",
303308
inputs={"X": list(input_deps)},
@@ -329,7 +334,7 @@ def transpile(self,
329334
recv_dep_in = send_barrier_out
330335
else:
331336
# connect deps to send op in async mode
332-
recv_dep_in = grad_name_to_send_dummy_out[
337+
recv_dep_in = self.grad_name_to_send_dummy_out[
333338
self.param_name_to_grad_name[param_varname]]
334339
all_recv_outputs.extend(splited_var)
335340
# get recv op_role_var, if not splited, the grad should have .trainer suffix
@@ -1046,9 +1051,13 @@ def _split_table_grad_and_add_send_vars(self, program, pserver_endpoints):
10461051
index=op_index + 2,
10471052
type="send",
10481053
inputs={'X': self.trainer_side_table_grad_list},
1049-
outputs={'Out': []},
1054+
outputs={
1055+
'Out':
1056+
[self.grad_name_to_send_dummy_out[self.table_name]]
1057+
if self.sync_mode else []
1058+
},
10501059
attrs={
1051-
"sync_mode": True,
1060+
"sync_mode": False,
10521061
"epmap": pserver_endpoints,
10531062
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE,
10541063
OP_ROLE_VAR_ATTR_NAME: [

0 commit comments

Comments (0)