Commit ce72c3f

Author: chengduo

Merge pull request #10476 from chengduoZH/refine_parallel_exe

Clean Parallel exe

2 parents 61343fb + a89cd46, commit ce72c3f

17 files changed: +191 -142 lines

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 32 additions & 28 deletions

@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
 
-  // Wait input done, this Wait is asynchronous operation platform::Place
-  // &in_place;
-  WaitInputVarGenerated(*in_var_handle);
+  WaitInputVarGenerated();
 
   std::vector<const Scope *> var_scopes;
   for (auto *s : local_scopes_) {
@@ -50,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
   auto *in_var =
       var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
-
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  // NOTE: The tensors' Place of input and output must be all on GPU or all on
-  // CPU.
-  for (auto *out_var_handle : out_var_handles) {
-    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
-      continue;
-    }
-    auto t_out_p = out_var_handle->place_;
-    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
-                        ->FindVar(out_var_handle->name_);
-    PADDLE_ENFORCE_NOT_NULL(out_var);
-    if (platform::is_gpu_place(in_tensor.place())) {
-      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
-                     "Places of input and output must be all on GPU.");
-    } else {
-      t_out_p = platform::CPUPlace();
-    }
-    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
-                                                            in_tensor.type());
-  }
+  InitOutputValue(*in_var_handle, out_var_handles);
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
@@ -147,11 +125,37 @@ void BroadcastOpHandle::RunImpl() {
   }
 }
 
-void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
-  if (in_var.generated_op_) {
-    for (auto &pair : dev_ctxes_) {
-      in_var.generated_op_->Wait(pair.second);
+void BroadcastOpHandle::InitOutputValue(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles) const {
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+  auto *in_var =
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
+
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
+
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(in_var_handle)) {
+      continue;
     }
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
+    if (is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
+                                                            in_tensor.type());
   }
 }
 
paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 3 additions & 1 deletion

@@ -57,14 +57,16 @@ struct BroadcastOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
-  void WaitInputVarGenerated(const VarHandle &in_var);
 
  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
 #ifdef PADDLE_WITH_CUDA
   const platform::NCCLContextMap *nccl_ctxs_;
 #endif
+
+  void InitOutputValue(const VarHandle &in_var_handle,
+                       const std::vector<VarHandle *> &out_var_handles) const;
 };
 }  // namespace details
 }  // namespace framework
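
The deleted GPU/CPU consistency loop reappears inside the new private InitOutputValue helper, so RunImpl now only orchestrates. Below is a reduced sketch of the place rule that helper enforces: a GPU input requires every output place to be GPU, while a CPU input forces outputs onto CPU. The Place enum and ResolveOutputPlace are illustrative stand-ins, not paddle::platform types.

#include <stdexcept>

enum class Place { kCPU, kGPU };

// Sketch of InitOutputValue's place rule for a single output.
Place ResolveOutputPlace(Place in_place, Place requested_out) {
  if (in_place == Place::kGPU) {
    // Mirrors the PADDLE_ENFORCE in the diff above.
    if (requested_out != Place::kGPU) {
      throw std::runtime_error(
          "Places of input and output must be all on GPU.");
    }
    return requested_out;
  }
  return Place::kCPU;  // CPU input: outputs fall back to CPU
}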

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 8 additions & 8 deletions

@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
       place_(place) {}
 
 void ComputationOpHandle::RunImpl() {
-  auto *cur_ctx = dev_ctxes_[place_];
-  for (auto *in : inputs_) {
-    bool need_wait = in->generated_op_ &&
-                     in->generated_op_->DeviceContext(place_) != cur_ctx;
-    if (need_wait) {
-      in->generated_op_->Wait(cur_ctx);
-    }
-  }
+  WaitInputVarGenerated(place_);
 
   this->RunAndRecordEvent([this] {
     op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
   });
 }
 
+bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
+  bool need_wait =
+      in_var && in_var->generated_op_ &&
+      in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
+  return need_wait;
+}
+
 std::string ComputationOpHandle::Name() const { return op_->Type(); }
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  virtual bool NeedWait(VarHandleBase *in_var);
+
  private:
   std::unique_ptr<OperatorBase> op_;
   Scope *scope_;
paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
3131
}
3232
}
3333

34-
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
34+
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
3535
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
3636
}
3737

@@ -45,14 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
4545
}
4646

4747
void FetchOpHandle::RunImpl() {
48-
auto cpu_ctx =
49-
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
50-
for (auto *input : inputs_) {
51-
auto *var = static_cast<VarHandle *>(input);
52-
if (var->generated_op_) {
53-
var->generated_op_->Wait(cpu_ctx);
54-
}
55-
}
48+
WaitInputVarGenerated(platform::CPUPlace());
49+
5650
tensors_.resize(inputs_.size());
5751
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
5852
auto &var_name = var_handle->name_;
@@ -79,6 +73,15 @@ void FetchOpHandle::RunImpl() {
7973
this->WaitAndMergeCPUTensors();
8074
}
8175

76+
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
77+
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
78+
for (auto *input : inputs_) {
79+
if (input->generated_op_) {
80+
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
81+
}
82+
}
83+
}
84+
8285
std::string FetchOpHandle::Name() const { return "Fetch"; }
8386

8487
} // namespace details

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 3 additions & 1 deletion

@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
 
   ~FetchOpHandle();
 
-  void Wait(platform::DeviceContext *waited_dev) override;
+  void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
 
   void WaitAndMergeCPUTensors() const;
 
@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
+  virtual void WaitInputVarGenerated(const platform::Place &place);
+
  private:
   FeedFetchList *data_;
   size_t offset_;
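
FetchOpHandle is the one op that customizes both halves of the protocol: it waits for its producers on the CPU context (fetched tensors are merged on the host), while RecordWaitEventOnCtx stays a hard failure because no downstream op should ever wait on a fetch. A minimal sketch of that sink-op contract, with illustrative stand-in types rather than the Paddle classes:

#include <stdexcept>
#include <vector>

struct OpHandle {
  // Called by consumers that want to wait on this op's outputs.
  virtual void RecordWaitEventOnCtx() { /* record event / sync */ }
  virtual ~OpHandle() = default;
};

struct Input {
  OpHandle *generated_op_ = nullptr;
};

struct FetchOp : OpHandle {
  std::vector<Input *> inputs_;

  // Wait on every producer, conceptually on the CPU device context.
  void WaitInputVarGenerated() {
    for (auto *in : inputs_) {
      if (in->generated_op_) in->generated_op_->RecordWaitEventOnCtx();
    }
  }

  // A fetch is a sink: nothing may wait on it, so this always throws.
  void RecordWaitEventOnCtx() override {
    throw std::runtime_error("Nobody should wait FetchOp.");
  }
};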

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 1 addition & 12 deletions

@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
                  "Currently, gather_op only can gather SelectedRows.");
 
   // Wait input done, this Wait is asynchronous operation
-  WaitInputVarGenerated(in_var_handles);
+  WaitInputVarGenerated();
 
   auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
   std::vector<int64_t> out_rows;
@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
   });
 }
 
-void GatherOpHandle::WaitInputVarGenerated(
-    const std::vector<VarHandle *> &in_var_handles) {
-  for (auto *in : in_var_handles) {
-    if (in->generated_op_) {
-      for (auto pair : dev_ctxes_) {
-        in->generated_op_->Wait(pair.second);
-      }
-    }
-  }
-}
-
 std::string GatherOpHandle::Name() const { return "gather"; }
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 0 additions & 1 deletion

@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
 
  protected:
   void RunImpl() override;
-  void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
 
  private:
   const std::vector<Scope *> &local_scopes_;

paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc

Lines changed: 1 addition & 6 deletions

@@ -34,12 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
     return;  // No need to all reduce when GPU count = 1;
   } else {
     // Wait input done
-    for (auto *in : inputs_) {
-      auto &p = static_cast<VarHandle *>(in)->place_;
-      if (in->generated_op_) {
-        in->generated_op_->Wait(dev_ctxes_[p]);
-      }
-    }
+    WaitInputVarGenerated();
 
     auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
     int dtype = -1;

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 25 additions & 3 deletions

@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
   RunImpl();
 }
 
-void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
+void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
 #ifdef PADDLE_WITH_CUDA
-  if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
+  if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
     for (auto &dev_ctx : dev_ctxes_) {
       dev_ctx.second->Wait();
     }
   } else {
     auto stream =
-        static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
+        static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
     for (auto &ev : events_) {
      PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
     }
@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
   out->generated_op_ = this;
 }
 
+void OpHandleBase::WaitInputVarGenerated() {
+  for (auto in_var : inputs_) {
+    if (NeedWait(in_var)) {
+      for (auto &pair : dev_ctxes_) {
+        in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
+      }
+    }
+  }
+}
+
+void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
+  for (auto *in : inputs_) {
+    if (NeedWait(in)) {
+      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
+    }
+  }
+}
+
+bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
+  return in_var && in_var->generated_op_;
+}
+
 void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA
   if (!events_.empty()) {  // Use event
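
The rename from Wait to RecordWaitEventOnCtx matches what the CUDA branch above actually does: rather than blocking the host, it makes the waiting stream wait on events recorded by the producing op. A minimal sketch of that idiom, assuming a CUDA build (the helper name is hypothetical, not a Paddle API):

#include <cuda_runtime.h>

// Make `waiting` wait for all work queued on `producing` so far, without
// blocking the host thread. Error checking elided for brevity.
void MakeStreamWait(cudaStream_t waiting, cudaStream_t producing) {
  cudaEvent_t ev;
  cudaEventCreateWithFlags(&ev, cudaEventDisableTiming);
  cudaEventRecord(ev, producing);       // mark the producer's progress
  cudaStreamWaitEvent(waiting, ev, 0);  // device-side wait, host continues
  cudaEventDestroy(ev);                 // safe: freed once the wait resolves
}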
