Skip to content

Commit 9eec2c7

Browse files
committed
refine pe
1 parent f4851f1 commit 9eec2c7

15 files changed

+65
-70
lines changed

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ void BroadcastOpHandle::RunImpl() {
3838
out_var_handles.size(), places_.size(),
3939
"The number of output should equal to the number of places.");
4040

41-
// Wait input done, this Wait is asynchronous operation platform::Place
42-
// &in_place;
43-
WaitInputVarGenerated(*in_var_handle);
41+
WaitInputVarGenerated();
4442

4543
std::vector<const Scope *> var_scopes;
4644
for (auto *s : local_scopes_) {
@@ -147,14 +145,6 @@ void BroadcastOpHandle::RunImpl() {
147145
}
148146
}
149147

150-
void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
151-
if (in_var.generated_op_) {
152-
for (auto &pair : dev_ctxes_) {
153-
in_var.generated_op_->Wait(pair.second);
154-
}
155-
}
156-
}
157-
158148
std::string BroadcastOpHandle::Name() const { return "broadcast"; }
159149
} // namespace details
160150
} // namespace framework

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ struct BroadcastOpHandle : public OpHandleBase {
5757

5858
protected:
5959
void RunImpl() override;
60-
void WaitInputVarGenerated(const VarHandle &in_var);
6160

6261
private:
6362
const std::vector<Scope *> &local_scopes_;

paddle/fluid/framework/details/computation_op_handle.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,20 @@ ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
2626
place_(place) {}
2727

2828
void ComputationOpHandle::RunImpl() {
29-
auto *cur_ctx = dev_ctxes_[place_];
30-
for (auto *in : inputs_) {
31-
bool need_wait = in->generated_op_ &&
32-
in->generated_op_->DeviceContext(place_) != cur_ctx;
33-
if (need_wait) {
34-
in->generated_op_->Wait(cur_ctx);
35-
}
36-
}
29+
WaitInputVarGenerated(place_);
3730

3831
this->RunAndRecordEvent([this] {
3932
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
4033
});
4134
}
4235

36+
bool ComputationOpHandle::NeedWait(VarHandleBase *in_var) {
37+
bool need_wait =
38+
dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_ &&
39+
in_var->generated_op_->DeviceContext(place_) != dev_ctxes_[place_];
40+
return need_wait;
41+
}
42+
4343
std::string ComputationOpHandle::Name() const { return op_->Type(); }
4444
} // namespace details
4545
} // namespace framework

paddle/fluid/framework/details/computation_op_handle.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ struct ComputationOpHandle : public OpHandleBase {
3636
protected:
3737
void RunImpl() override;
3838

39+
virtual bool NeedWait(VarHandleBase *in_var);
40+
3941
private:
4042
std::unique_ptr<OperatorBase> op_;
4143
Scope *scope_;

paddle/fluid/framework/details/fetch_op_handle.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ FetchOpHandle::~FetchOpHandle() {
3131
}
3232
}
3333

34-
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
34+
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
3535
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
3636
}
3737

@@ -45,12 +45,8 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
4545
}
4646

4747
void FetchOpHandle::RunImpl() {
48-
auto cpu_ctx =
49-
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
50-
for (auto *input : inputs_) {
51-
auto *var = static_cast<VarHandle *>(input);
52-
var->generated_op_->Wait(cpu_ctx);
53-
}
48+
WaitInputVarGenerated(platform::CPUPlace());
49+
5450
tensors_.resize(inputs_.size());
5551
auto *var_handle = static_cast<VarHandle *>(inputs_[0]);
5652
auto &var_name = var_handle->name_;
@@ -77,6 +73,15 @@ void FetchOpHandle::RunImpl() {
7773
this->WaitAndMergeCPUTensors();
7874
}
7975

76+
void FetchOpHandle::WaitInputVarGenerated(const platform::Place &place) {
77+
auto cpu_ctx = platform::DeviceContextPool::Instance().Get(place);
78+
for (auto *input : inputs_) {
79+
if (input->generated_op_) {
80+
input->generated_op_->RecordWaitEventOnCtx(cpu_ctx);
81+
}
82+
}
83+
}
84+
8085
std::string FetchOpHandle::Name() const { return "Fetch"; }
8186

8287
} // namespace details

paddle/fluid/framework/details/fetch_op_handle.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct FetchOpHandle : public OpHandleBase {
3333

3434
~FetchOpHandle();
3535

36-
void Wait(platform::DeviceContext *waited_dev) override;
36+
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
3737

3838
void WaitAndMergeCPUTensors() const;
3939

@@ -42,6 +42,8 @@ struct FetchOpHandle : public OpHandleBase {
4242
protected:
4343
void RunImpl() override;
4444

45+
virtual void WaitInputVarGenerated(const platform::Place &place);
46+
4547
private:
4648
FeedFetchList *data_;
4749
size_t offset_;

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void GatherOpHandle::RunImpl() {
5555
"Currently, gather_op only can gather SelectedRows.");
5656

5757
// Wait input done, this Wait is asynchronous operation
58-
WaitInputVarGenerated(in_var_handles);
58+
WaitInputVarGenerated();
5959

6060
auto &pre_in_value = pre_in_var->Get<framework::SelectedRows>();
6161
std::vector<int64_t> out_rows;
@@ -111,17 +111,6 @@ void GatherOpHandle::RunImpl() {
111111
});
112112
}
113113

114-
void GatherOpHandle::WaitInputVarGenerated(
115-
const std::vector<VarHandle *> &in_var_handles) {
116-
for (auto *in : in_var_handles) {
117-
if (in->generated_op_) {
118-
for (auto pair : dev_ctxes_) {
119-
in->generated_op_->Wait(pair.second);
120-
}
121-
}
122-
}
123-
}
124-
125114
std::string GatherOpHandle::Name() const { return "gather"; }
126115
} // namespace details
127116
} // namespace framework

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ struct GatherOpHandle : public OpHandleBase {
3939

4040
protected:
4141
void RunImpl() override;
42-
void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
4342

4443
private:
4544
const std::vector<Scope *> &local_scopes_;

paddle/fluid/framework/details/nccl_all_reduce_op_handle.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,7 @@ void NCCLAllReduceOpHandle::RunImpl() {
3434
return; // No need to all reduce when GPU count = 1;
3535
} else {
3636
// Wait input done
37-
for (auto *in : inputs_) {
38-
auto &p = static_cast<VarHandle *>(in)->place_;
39-
in->generated_op_->Wait(dev_ctxes_[p]);
40-
}
37+
WaitInputVarGenerated();
4138

4239
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
4340
int dtype = -1;

paddle/fluid/framework/details/op_handle_base.cc

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
5656
RunImpl();
5757
}
5858

59-
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
59+
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
6060
#ifdef PADDLE_WITH_CUDA
61-
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
61+
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
6262
for (auto &dev_ctx : dev_ctxes_) {
6363
dev_ctx.second->Wait();
6464
}
6565
} else {
6666
auto stream =
67-
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
67+
static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
6868
for (auto &ev : events_) {
6969
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
7070
}
@@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
8686
out->generated_op_ = this;
8787
}
8888

89+
void OpHandleBase::WaitInputVarGenerated() {
90+
for (auto in_var : inputs_) {
91+
if (NeedWait(in_var)) {
92+
for (auto &pair : dev_ctxes_) {
93+
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
94+
}
95+
}
96+
}
97+
}
98+
99+
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
100+
for (auto *in : inputs_) {
101+
if (NeedWait(in)) {
102+
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[place]);
103+
}
104+
}
105+
}
106+
107+
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
108+
return dynamic_cast<VarHandle *>(in_var) && in_var->generated_op_;
109+
}
110+
89111
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
90112
#ifdef PADDLE_WITH_CUDA
91113
if (!events_.empty()) { // Use event

0 commit comments

Comments
 (0)