Commit eeaf04d

[cherry-pick]Fix communicator slow bug & fix communicator stop bug (#20366) (#20646)
* Fix communicator slow bug & fix communicator stop bug (#20366)
* test=develop,Fix communicator slow bug
* test=develop, delete if() in stop_worker()
* test=develop
* fix UT, test=develop
* fix bug in fetch handler, test=develop
* fix bug in fetch handler, test=develop
* test=develop, fix fetch barrier bug
* test=develop, bug fix
* test=develop, bug fix
* test=develop, fix bug
* test=develop,test=release/1.6
1 parent 965b45e commit eeaf04d

File tree

11 files changed: +24 / -11 lines


paddle/fluid/framework/dist_multi_trainer.cc

Lines changed: 6 additions & 2 deletions
@@ -144,6 +144,10 @@ void DistMultiTrainer::Run() {
   }
 }
 
+Scope *DistMultiTrainer::GetWorkerScope(int thread_id) {
+  return workers_[thread_id]->GetThreadScope();
+}
+
 void DistMultiTrainer::Finalize() {
   for (auto &th : threads_) {
     th.join();
@@ -199,5 +203,5 @@ void DistMultiTrainer::MergeToRootScope(LoDTensor *root_tensor,
     root_data[i] += data[i];
   }
 }
-}  // end namespace framework
-}  // end namespace paddle
+}  // namespace framework
+}  // namespace paddle

paddle/fluid/framework/trainer.h

Lines changed: 1 addition & 1 deletion
@@ -93,8 +93,8 @@ class DistMultiTrainer : public MultiTrainer {
   void MergeToRootScope(LoDTensor* root_tensor, LoDTensor* thread_tensor);
   virtual void FinalizeDumpEnv();
   virtual void InitDumpEnv();
+  virtual Scope* GetWorkerScope(int thread_id);
   virtual void DumpWork(int tid);
-  virtual Scope* GetWorkerScope(int thread_id) { return root_scope_; }
 
  protected:
   std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;

paddle/fluid/operators/distributed/communicator.cc

Lines changed: 1 addition & 0 deletions
@@ -923,6 +923,7 @@ void GeoSgdCommunicator::RpcSend(const std::string &origin_var_name,
   auto &cpu_ctx_send = *pool.Get(platform::CPUPlace());
   distributed::RPCClient *rpc_client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(trainer_id);
+
   rpc_client->AsyncSendVar(endpoint, cpu_ctx_send, *delta_scope_.get(),
                            splited_var_name);
 }

paddle/fluid/operators/distributed/rpc_client.cc

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@
 
 // default to 3min to avoid temprary network failures.
 DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
-DEFINE_int32(rpc_retry_times, 3, "retry times for rpc");
+DEFINE_int32(rpc_retry_times, 0, "retry times for rpc");
 
 namespace paddle {
 namespace operators {

paddle/fluid/operators/distributed_ops/fetch_barrier_op.cc

Lines changed: 3 additions & 0 deletions
@@ -55,6 +55,9 @@ class FetchBarrierOp : public framework::OperatorBase {
 class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() {
+    AddInput("X", "(Any) Dummy inputs, used for control dependency")
+        .AsDispensable()
+        .AsDuplicable();
     AddOutput("Out", "(Any) Dummy outputs, used for control dependency")
         .AsDuplicable();
     AddComment(R"DOC(

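The new dispensable, duplicable "X" input lets the transpiler feed the dummy outputs of earlier send ops into fetch_barrier as a control dependency, so the barrier cannot be scheduled before the sends it is meant to wait on. A minimal sketch of appending such an op from Python; the helper function, the attribute names ("endpoints", "trainer_id"), and the dummy output variable are illustrative assumptions, not taken from this diff:

# Hypothetical sketch: wire send outputs into fetch_barrier via the new "X" input.
# `block` is a fluid Block, `send_output_vars` are dummy Variables produced by
# earlier send ops, `pserver_endpoints` is a list of "ip:port" strings.
def append_fetch_barrier(block, send_output_vars, pserver_endpoints, trainer_id=0):
    barrier_out = block.create_var(name="fetch_barrier_out")  # dummy output
    block.append_op(
        type="fetch_barrier",
        inputs={"X": send_output_vars},     # new control-dependency input
        outputs={"Out": [barrier_out]},
        attrs={
            "endpoints": pserver_endpoints,  # assumed attribute name
            "trainer_id": trainer_id,        # assumed attribute name
        })
    return barrier_out
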
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -129,8 +129,7 @@ def stop_worker(self):
         Returns:
             None
         """
-        if not self._transpile_config.sync_mode and self._communicator.is_running(
-        ):
+        if not self._transpile_config.sync_mode:
             self._communicator.stop()
         self._executor.close()
         if isinstance(self._role_maker, MPISymetricRoleMaker):

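Dropping the is_running() check means stop_worker() now always asks the communicator to stop whenever sync_mode is off, instead of silently skipping shutdown when the communicator happens to be idle. A rough sketch of the trainer-side lifecycle this touches, assuming the fleet API exposed by this module; the role maker and program-building details are illustrative, not taken from this diff:

# Sketch of the trainer-side lifecycle around the fixed stop_worker().
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.base import role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet

config = fluid.DistributeTranspilerConfig()
config.sync_mode = False                  # async mode: the communicator is used

fleet.init(role_maker.PaddleCloudRoleMaker())
# ... build the network, then wrap the optimizer:
# fleet.distributed_optimizer(optimizer, config).minimize(loss)

if fleet.is_worker():
    fleet.init_worker()                   # starts the async communicator
    # ... exe.train_from_dataset(...) ...
    fleet.stop_worker()                   # now always stops the communicator
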
python/paddle/fluid/tests/unittests/ctr_dataset_reader.py

Lines changed: 1 addition & 1 deletion
@@ -67,7 +67,7 @@ def get_rand(low=0.0, high=1.0):
             return random.random()
 
         def iter():
-            if get_rand() < 0.1:
+            if get_rand() < 0.05:
                 fs = line.strip().split('\t')
                 dnn_input = load_dnn_input_record(fs[0])
                 lr_input = load_lr_input_record(fs[1])

python/paddle/fluid/tests/unittests/dist_fleet_ctr.py

Lines changed: 2 additions & 2 deletions
@@ -139,7 +139,7 @@ def do_training(self, fleet):
         dataset.set_filelist(filelist)
         dataset.set_thread(thread_num)
 
-        for epoch_id in range(2):
+        for epoch_id in range(1):
             pass_start = time.time()
             dataset.set_filelist(filelist)
             exe.train_from_dataset(
@@ -157,7 +157,7 @@ def handler(self, fetch_target_vars):
                 print("{}: \n {}\n".format(self.fetch_target_names[0],
                                            fetch_target_vars[0]))
 
-        for epoch_id in range(2):
+        for epoch_id in range(1):
             pass_start = time.time()
             dataset.set_filelist(filelist)
             exe.train_from_dataset(

python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@ def __func__(*args, **kwargs):
     return __func__
 
 
-@skip_ci
 class TestDistMnist2x2(TestFleetBase):
     def _setup_config(self):
         self._sync_mode = False

python/paddle/fluid/trainer_factory.py

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,9 @@ def handler_decorator(self, fetch_scope, fetch_handler):
                 for varname in fetch_target_names
             ]
 
+            if None in fetch_vars:
+                continue
+
             fetch_tensors = [var.get_tensor() for var in fetch_vars]
 
             if self.fetch_instance.return_np:

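The added None check skips a monitoring round whenever some fetch variables are not yet present in the scope handed to handler_decorator, which can now be a per-thread worker scope rather than the root scope. For reference, a handler like the one exercised in dist_fleet_ctr.py above might look as follows; the FetchHandler import path and the fetch_handler wiring are assumptions based on the Paddle 1.6 era API, not taken from this diff:

# Sketch of a user-side fetch handler driven by trainer_factory.py's
# monitor loop. Import path and constructor are assumed (Paddle 1.6 era).
import paddle.fluid as fluid

class LossFetchHandler(fluid.executor.FetchHandler):
    def handler(self, fetch_target_vars):
        # Called periodically with the fetched values, in the order of
        # self.fetch_target_names; values may be numpy arrays when
        # return_np is True (see fetch_instance.return_np above).
        print("{}: {}".format(self.fetch_target_names[0],
                              fetch_target_vars[0]))

# Assumed wiring: pass the handler to train_from_dataset so the trainer
# factory polls the worker scope with it.
# exe.train_from_dataset(program=main_prog, dataset=dataset,
#                        fetch_handler=LossFetchHandler([loss.name]))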