Skip to content

Commit 4abef50

Browse files
committed
code refine
1 parent 2aaa75e commit 4abef50

File tree

4 files changed

+69
-50
lines changed

4 files changed

+69
-50
lines changed

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -34,40 +34,21 @@ BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
3434
: local_scopes_(local_scopes), places_(places) {}
3535

3636
void BroadcastOpHandle::RunImpl() {
37-
// the input may have dummy var.
38-
std::vector<VarHandle *> in_var_handle;
39-
for (auto *in : inputs_) {
40-
auto *out_handle = dynamic_cast<VarHandle *>(in);
41-
if (out_handle) {
42-
in_var_handle.push_back(out_handle);
43-
}
44-
}
37+
// the input and output may have dummy var.
38+
std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
39+
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
40+
4541
PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
4642
"The number of input should be one.");
47-
48-
// the output may have dummy var.
49-
std::vector<VarHandle *> out_var_handles;
50-
for (auto *out : outputs_) {
51-
auto *out_handle = dynamic_cast<VarHandle *>(out);
52-
if (out_handle) {
53-
out_var_handles.push_back(out_handle);
54-
}
55-
}
56-
5743
PADDLE_ENFORCE_EQ(
5844
out_var_handles.size(), places_.size(),
5945
"The number of output should equal to the number of places.");
6046

61-
// Wait input done, this Wait is asynchronous operation
62-
auto &in_place = in_var_handle[0]->place_;
63-
if (in_var_handle[0]->generated_op_) {
64-
for (auto *out : out_var_handles) {
65-
auto &out_p = out->place_;
66-
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
67-
}
68-
}
47+
// Wait input done, this Wait is asynchronous operationplatform::Place
48+
// &in_place;
49+
WaitEvents(out_var_handles, in_var_handle);
6950

70-
//
51+
auto in_place = in_var_handle[0]->place_;
7152
auto in_scope_idx = in_var_handle[0]->scope_idx_;
7253
auto in_var =
7354
local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
@@ -107,6 +88,29 @@ void BroadcastOpHandle::RunImpl() {
10788
}
10889
}
10990

91+
void BroadcastOpHandle::WaitEvents(
92+
const std::vector<VarHandle *> &out_var_handles,
93+
const std::vector<VarHandle *> &in_var_handle) {
94+
if (in_var_handle[0]->generated_op_) {
95+
for (auto *out : out_var_handles) {
96+
auto &out_p = out->place_;
97+
in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
98+
}
99+
}
100+
}
101+
102+
std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
103+
const std::vector<VarHandleBase *> &inputs) {
104+
std::vector<VarHandle *> in_var_handle;
105+
for (auto *in : inputs) {
106+
auto *out_handle = dynamic_cast<VarHandle *>(in);
107+
if (out_handle) {
108+
in_var_handle.push_back(out_handle);
109+
}
110+
}
111+
return in_var_handle;
112+
}
113+
110114
std::string BroadcastOpHandle::Name() const { return "broadcast"; }
111115
} // namespace details
112116
} // namespace framework

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ struct BroadcastOpHandle : public OpHandleBase {
4141

4242
protected:
4343
void RunImpl() override;
44+
45+
std::vector<VarHandle *> GetValidVarHandles(
46+
const std::vector<VarHandleBase *> &inputs);
47+
48+
void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
49+
const std::vector<VarHandle *> &in_var_handle);
4450
};
4551

4652
} // namespace details

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,26 +23,13 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
2323
: local_scopes_(local_scopes), places_(places) {}
2424

2525
void GatherOpHandle::RunImpl() {
26-
// the input may have dummy var.
27-
std::vector<VarHandle *> in_var_handles;
28-
for (auto *in : inputs_) {
29-
auto *in_handle = dynamic_cast<VarHandle *>(in);
30-
if (in_handle) {
31-
in_var_handles.push_back(in_handle);
32-
}
33-
}
26+
// the input and output may have dummy var.
27+
std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
28+
std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
29+
3430
PADDLE_ENFORCE_EQ(
3531
in_var_handles.size(), places_.size(),
3632
"The number of output should equal to the number of places.");
37-
38-
// the output may have dummy var.
39-
std::vector<VarHandle *> out_var_handles;
40-
for (auto *out : outputs_) {
41-
auto *out_handle = dynamic_cast<VarHandle *>(out);
42-
if (out_handle) {
43-
out_var_handles.push_back(out_handle);
44-
}
45-
}
4633
PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
4734
"The number of output should be one.");
4835

@@ -58,11 +45,7 @@ void GatherOpHandle::RunImpl() {
5845
"The place of input and output should be the same.");
5946

6047
// Wait input done, this Wait is asynchronous operation
61-
for (auto *in : in_var_handles) {
62-
if (in->generated_op_) {
63-
in->generated_op_->Wait(dev_ctxes_[in->place_]);
64-
}
65-
}
48+
WaitEvents(in_var_handles);
6649

6750
std::vector<int64_t> out_rows;
6851
std::vector<Tensor> in_tensors;
@@ -111,7 +94,7 @@ void GatherOpHandle::RunImpl() {
11194

11295
// copy
11396
auto dev_ctx = dev_ctxes_[out_place];
114-
RunAndRecordEvent(out_place, [in_tensors, out_var, dev_ctx, out_place] {
97+
RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
11598
int s = 0, e = 0;
11699
for (size_t j = 0; j < in_tensors.size(); ++j) {
117100
e += in_tensors[j].dims()[0];
@@ -123,6 +106,27 @@ void GatherOpHandle::RunImpl() {
123106
});
124107
}
125108

109+
void GatherOpHandle::WaitEvents(
110+
const std::vector<VarHandle *> &in_var_handles) {
111+
for (auto *in : in_var_handles) {
112+
if (in->generated_op_) {
113+
in->generated_op_->Wait(dev_ctxes_[in->place_]);
114+
}
115+
}
116+
}
117+
118+
std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
119+
const std::vector<VarHandleBase *> &inputs) {
120+
std::vector<VarHandle *> in_var_handles;
121+
for (auto *in : inputs) {
122+
auto *in_handle = dynamic_cast<VarHandle *>(in);
123+
if (in_handle) {
124+
in_var_handles.push_back(in_handle);
125+
}
126+
}
127+
return in_var_handles;
128+
}
129+
126130
std::string GatherOpHandle::Name() const { return "gather"; }
127131
} // namespace details
128132
} // namespace framework

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ struct GatherOpHandle : public OpHandleBase {
4141

4242
protected:
4343
void RunImpl() override;
44+
45+
std::vector<VarHandle *> GetValidVarHandles(
46+
const std::vector<VarHandleBase *> &);
47+
48+
void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
4449
};
4550

4651
} // namespace details

0 commit comments

Comments
 (0)