Skip to content

Commit 5f89ce7

Browse files
authored
Merge pull request #15536 from jacquesqiao/fix-prefetch-one-parameter
Fix prefetch one parameter
2 parents d303270 + 806658d commit 5f89ce7

File tree

3 files changed

+34
-16
lines changed

3 files changed

+34
-16
lines changed

paddle/fluid/operators/distributed/request_handler_impl.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,11 @@ bool RequestSendHandler::Handle(const std::string& varname,
5454
// Async
5555
if (!sync_mode_) {
5656
VLOG(3) << "async process var: " << varname;
57+
if (varname == BATCH_BARRIER_MESSAGE) {
58+
PADDLE_THROW(
59+
"async mode should not recv BATCH_BARRIER_MESSAGE or "
60+
"COMPLETE_MESSAGE");
61+
}
5762
try {
5863
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
5964
scope);

paddle/fluid/operators/distributed/rpc_server.cc

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,33 @@ void RPCServer::SavePort() const {
3939
port_file.open(file_path);
4040
port_file << selected_port_;
4141
port_file.close();
42-
VLOG(4) << "selected port written to " << file_path;
42+
VLOG(3) << "selected port written to " << file_path;
4343
}
4444

4545
void RPCServer::WaitBarrier(const std::string& rpc_name) {
46+
VLOG(3) << "WaitBarrier in: " << rpc_name;
4647
std::unique_lock<std::mutex> lock(this->mutex_);
4748
barrier_cond_.wait(lock, [this, &rpc_name] {
4849
return ((barrier_counter_[rpc_name] == client_num_ && client_num_ != 0) ||
4950
exit_flag_.load());
5051
});
5152

52-
VLOG(3) << "batch_barrier_: " << rpc_name << " "
53-
<< barrier_counter_[rpc_name];
53+
VLOG(3) << "WaitBarrier out: " << rpc_name
54+
<< " counter: " << barrier_counter_[rpc_name];
5455
}
5556

5657
void RPCServer::IncreaseBatchBarrier(const std::string rpc_name) {
57-
VLOG(4) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
58+
VLOG(3) << "RPCServer begin IncreaseBatchBarrier " << rpc_name;
59+
// barrier msg should make sure that it's in the right cond(send|recv)
60+
WaitCond(rpc_name);
5861
int b = 0;
5962
std::unique_lock<std::mutex> lock(mutex_);
6063
b = ++barrier_counter_[rpc_name];
64+
VLOG(3) << rpc_name << " barrier_counter: " << b;
6165
if (b >= client_num_) {
6266
lock.unlock();
67+
VLOG(3) << "BatchBarrier counter reach " << client_num_ << " for "
68+
<< rpc_name;
6369
barrier_cond_.notify_all();
6470
lock.lock();
6571
}
@@ -71,7 +77,7 @@ void RPCServer::Complete() {
7177
client_num_--;
7278
need_reset_all_vars_ = true;
7379

74-
VLOG(4) << "decrease client_num to: " << client_num_;
80+
VLOG(3) << "decrease client_num to: " << client_num_;
7581
if (cur_cond_.load() == rpc_cond_map_[kRequestGet]) {
7682
barrier_counter_[kRequestGet]--;
7783
}
@@ -105,8 +111,8 @@ void RPCServer::RegisterRPC(const std::string& rpc_name,
105111

106112
static int cond = -1;
107113
rpc_cond_map_[rpc_name] = ++cond;
108-
VLOG(4) << "RegisterRPC rpc_name:" << rpc_name << ", handler:" << handler
109-
<< ", cond:" << rpc_cond_map_[rpc_name];
114+
VLOG(3) << "RegisterRPC rpc_name: " << rpc_name << ", handler: " << handler
115+
<< ", cond: " << rpc_cond_map_[rpc_name];
110116
}
111117

112118
void RPCServer::SetCond(const std::string& rpc_name) {
@@ -120,7 +126,7 @@ void RPCServer::SetCond(const std::string& rpc_name) {
120126
}
121127

122128
void RPCServer::WaitCond(const std::string& rpc_name) {
123-
VLOG(4) << "RPCServer WaitCond " << rpc_name;
129+
VLOG(3) << "RPCServer WaitCond in " << rpc_name;
124130
int cond = 0;
125131
{
126132
std::unique_lock<std::mutex> lock(mutex_);
@@ -130,6 +136,7 @@ void RPCServer::WaitCond(const std::string& rpc_name) {
130136
std::unique_lock<std::mutex> lock(mutex_);
131137
rpc_cond_.wait(
132138
lock, [=] { return (cur_cond_.load() == cond || exit_flag_.load()); });
139+
VLOG(3) << "RPCServer WaitCond out " << rpc_name;
133140
}
134141

135142
void RPCServer::RegisterVar(const std::string& var_name,
@@ -151,7 +158,7 @@ void RPCServer::RegisterVar(const std::string& var_name,
151158
}
152159

153160
rpc_cond_.notify_all();
154-
VLOG(4) << "RegisterVar context:" << h.String();
161+
VLOG(3) << "RegisterVar context:" << h.String();
155162
}
156163

157164
void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
@@ -167,23 +174,23 @@ void RPCServer::IncreaseVarBarrier(const std::string& var_name) {
167174
barrier_cond_.notify_all();
168175
}
169176

170-
VLOG(4) << "IncreaseVarBarrier context:" << h.String();
177+
VLOG(3) << "IncreaseVarBarrier context:" << h.String();
171178
}
172179

173180
void RPCServer::WaitVarBarrier(const std::string& var_name) {
174-
VLOG(4) << "WaitBarrier var_name:" << var_name;
181+
VLOG(3) << "WaitVarBarrier var_name:" << var_name;
175182

176183
std::unique_lock<std::mutex> lock(mutex_);
177184
barrier_cond_.wait(lock, [&]() {
178185
return ((var_map_[var_name].barrier_ >= client_num_ && client_num_ != 0) ||
179186
exit_flag_.load());
180187
});
181188

182-
VLOG(4) << "WaitBarrier context: " << var_map_[var_name].String();
189+
VLOG(3) << "WaitVarBarrier context: " << var_map_[var_name].String();
183190
}
184191

185192
void RPCServer::SetVarCond(const std::string& var_name) {
186-
VLOG(4) << "SetVarCond var_name:" << var_name;
193+
VLOG(3) << "SetVarCond var_name:" << var_name;
187194
{
188195
std::unique_lock<std::mutex> lock(mutex_);
189196
if (var_map_.find(var_name) != var_map_.end()) {
@@ -193,14 +200,14 @@ void RPCServer::SetVarCond(const std::string& var_name) {
193200
}
194201

195202
void RPCServer::WaitVarCond(const std::string& var_name) {
196-
VLOG(4) << "WaitVarCond var_name:" << var_name;
203+
VLOG(3) << "WaitVarCond var_name:" << var_name;
197204

198205
std::unique_lock<std::mutex> lock(mutex_);
199206
rpc_cond_.wait(lock, [=] {
200207
return (var_map_.find(var_name) != var_map_.end() || exit_flag_.load());
201208
});
202209

203-
VLOG(4) << "WaitVarCond var_name:" << var_name << " end";
210+
VLOG(3) << "WaitVarCond var_name:" << var_name << " end";
204211
}
205212

206213
MonomerHandle RPCServer::GetMonomer(const std::string& var_name) {

paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,9 @@ void ListenAndServOp::RunSyncLoop(
137137
while (true) {
138138
// Get from multiple trainers, we don't care about the order in which
139139
// the gradients arrives, just add suffix 0~n and merge the gradient.
140+
VLOG(3) << "wait all clients to send gradient";
140141
rpc_service_->SetCond(distributed::kRequestSend);
142+
VLOG(3) << "wait all clients to send send_barrier";
141143
rpc_service_->WaitBarrier(distributed::kRequestSend);
142144

143145
if (rpc_service_->IsExit()) {
@@ -168,12 +170,16 @@ void ListenAndServOp::RunSyncLoop(
168170
}
169171
ParallelExecuteBlocks(parallel_blkids, executor, optimize_prepared, program,
170172
recv_scope);
171-
VLOG(2) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
173+
VLOG(3) << "run all blocks spent " << GetTimestamp() - ts << "(ms)";
172174

175+
VLOG(3) << "ResetReceivedVars";
173176
ResetReceivedVars(recv_scope, dev_ctx, rpc_service_->NeedResetAllVars());
174177

178+
VLOG(3) << "wait all clients to get parameters back";
175179
rpc_service_->SetCond(distributed::kRequestGet);
180+
VLOG(3) << "wait all clients to send fetch_barrier";
176181
rpc_service_->WaitBarrier(distributed::kRequestGet);
182+
VLOG(3) << "ResetBarrierCounter";
177183
rpc_service_->ResetBarrierCounter();
178184
} // while(true)
179185
}

0 commit comments

Comments
 (0)