
Commit 8305322

extract method from broadcast::RunImpl
1 parent 93368aa commit 8305322

File tree

5 files changed: +56 -33 lines changed


paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 35 additions & 21 deletions
@@ -48,29 +48,9 @@ void BroadcastOpHandle::RunImpl() {
   auto *in_var =
       var_scopes.at(in_var_handle->scope_idx_)->FindVar(in_var_handle->name_);
   PADDLE_ENFORCE_NOT_NULL(in_var);
-
   Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
-  // NOTE: The tensors' Place of input and output must be all on GPU or all on
-  // CPU.
-  for (auto *out_var_handle : out_var_handles) {
-    if (out_var_handle->IsTheSameVar(*in_var_handle)) {
-      continue;
-    }
-    auto t_out_p = out_var_handle->place_;
-    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
-                        ->FindVar(out_var_handle->name_);
-    PADDLE_ENFORCE_NOT_NULL(out_var);
-    if (platform::is_gpu_place(in_tensor.place())) {
-      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
-                     "Places of input and output must be all on GPU.");
-    } else {
-      t_out_p = platform::CPUPlace();
-    }
-    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
-    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
-                                                            in_tensor.type());
-  }
+  InitOutputValue(*in_var_handle, out_var_handles);
 
   if (platform::is_cpu_place(in_tensor.place())) {
     for (auto *out_var_handle : out_var_handles) {
@@ -145,6 +125,40 @@ void BroadcastOpHandle::RunImpl() {
   }
 }
 
+void BroadcastOpHandle::InitOutputValue(
+    const VarHandle &in_var_handle,
+    const std::vector<VarHandle *> &out_var_handles) const {
+  std::vector<const Scope *> var_scopes;
+  for (auto *s : local_scopes_) {
+    var_scopes.emplace_back(s->FindVar(kLocalExecScopeName)->Get<Scope *>());
+  }
+  auto *in_var =
+      var_scopes.at(in_var_handle.scope_idx_)->FindVar(in_var_handle.name_);
+
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
+
+  // NOTE: The tensors' Place of input and output must be all on GPU or all on
+  // CPU.
+  for (auto *out_var_handle : out_var_handles) {
+    if (out_var_handle->IsTheSameVar(in_var_handle)) {
+      continue;
+    }
+    auto t_out_p = out_var_handle->place_;
+    auto *out_var = var_scopes.at(out_var_handle->scope_idx_)
+                        ->FindVar(out_var_handle->name_);
+    PADDLE_ENFORCE_NOT_NULL(out_var);
+    if (is_gpu_place(in_tensor.place())) {
+      PADDLE_ENFORCE(platform::is_gpu_place(t_out_p),
+                     "Places of input and output must be all on GPU.");
+    } else {
+      t_out_p = platform::CPUPlace();
+    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(t_out_p,
+                                                            in_tensor.type());
+  }
+}
+
 std::string BroadcastOpHandle::Name() const { return "broadcast"; }
 }  // namespace details
 }  // namespace framework
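
The change above is a straightforward extract-method refactoring: the output-initialization loop that was inlined in RunImpl() becomes the named const helper InitOutputValue(), so RunImpl() reads as a sequence of named steps. A minimal standalone sketch of the same pattern, using hypothetical types rather than the Paddle API:

#include <vector>

// Hypothetical stand-ins for the op handle and its tensors; a sketch of
// the extract-method pattern, not the Paddle implementation.
class Broadcaster {
 public:
  void Run(const std::vector<float> &input,
           std::vector<std::vector<float>> *outputs) {
    InitOutputValue(input, outputs);         // was an inline loop before
    for (auto &out : *outputs) out = input;  // the broadcast proper
  }

 private:
  // Extracted helper: shapes each output like the input, mirroring the
  // per-output ShareDimsAndLoD()/mutable_data() calls above. It mutates
  // only its arguments, so it can be declared const like InitOutputValue().
  void InitOutputValue(const std::vector<float> &input,
                       std::vector<std::vector<float>> *outputs) const {
    for (auto &out : *outputs) out.resize(input.size());
  }
};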

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 3 additions & 0 deletions
@@ -64,6 +64,9 @@ struct BroadcastOpHandle : public OpHandleBase {
 #ifdef PADDLE_WITH_CUDA
   const platform::NCCLContextMap *nccl_ctxs_;
 #endif
+
+  void InitOutputValue(const VarHandle &in_var_handle,
+                       const std::vector<VarHandle *> &out_var_handles) const;
 };
 }  // namespace details
 }  // namespace framework

paddle/fluid/framework/details/op_handle_base.h

Lines changed: 6 additions & 2 deletions
@@ -44,10 +44,14 @@ class OpHandleBase {
 
   void AddOutput(VarHandleBase *out);
 
-  // Wait inputs are generated, this Wait is asynchronous operation.
+  // This method adds the wait events of all the inputs on all the device
+  // contexts.
+  // NOTE: This Wait is an asynchronous operation.
   virtual void WaitInputVarGenerated();
 
-  // Wait inputs are generated, this Wait is asynchronous operation.
+  // This method adds the wait events of all the inputs on the specified
+  // device context.
+  // NOTE: This Wait is an asynchronous operation.
   virtual void WaitInputVarGenerated(const platform::Place &place);
 
   virtual bool NeedWait(VarHandleBase *in_var);
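
The reworded comments distinguish the two overloads: the no-argument form enqueues wait events for the inputs on every device context the op touches, while the Place form restricts the wait to one device's context. A sketch of that pattern with hypothetical event and context types (not the Paddle API):

#include <map>
#include <vector>

// Hypothetical device/event types, for illustration only.
struct Place {
  int device_id;
  bool operator<(const Place &o) const { return device_id < o.device_id; }
};
struct Event {};
struct DeviceContext {
  void StreamWaitOn(const Event &e) { (void)e; }  // enqueue, don't block
};

class OpHandle {
 public:
  // No-argument form: enqueue waits on every device context.
  void WaitInputVarGenerated() {
    for (auto &ctx : dev_ctxes_) WaitInputVarGenerated(ctx.first);
  }

  // Place form: enqueue waits only on that device's context. As the new
  // comments stress, this is asynchronous: the host thread never blocks.
  void WaitInputVarGenerated(const Place &place) {
    for (const Event &e : input_ready_events_)
      dev_ctxes_[place].StreamWaitOn(e);
  }

 private:
  std::map<Place, DeviceContext> dev_ctxes_;
  std::vector<Event> input_ready_events_;
};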

paddle/fluid/framework/details/reduce_op_handle.cc

Lines changed: 11 additions & 9 deletions
@@ -80,19 +80,21 @@ void ReduceOpHandle::RunImpl() {
   }
 
   if (pre_in_var->IsType<framework::SelectedRows>()) {
-    std::vector<const SelectedRows *> in_selected_rows =
-        GetInputValues<SelectedRows>(in_var_handles, var_scopes);
-
-    GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
-                       out_var->GetMutable<framework::SelectedRows>());
+    this->RunAndRecordEvent([&] {
+      std::vector<const SelectedRows *> in_selected_rows =
+          GetInputValues<SelectedRows>(in_var_handles, var_scopes);
+      GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
+                         out_var->GetMutable<framework::SelectedRows>());
+    });
   } else {
     std::vector<const LoDTensor *> lod_tensors =
         GetInputValues<LoDTensor>(in_var_handles, var_scopes);
-
     if (paddle::platform::is_cpu_place(lod_tensors[0]->place())) {
-      ReduceLoDTensor func(lod_tensors,
-                           out_var->GetMutable<framework::LoDTensor>());
-      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      this->RunAndRecordEvent([&] {
+        ReduceLoDTensor func(lod_tensors,
+                             out_var->GetMutable<framework::LoDTensor>());
+        VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+      });
     } else if (paddle::platform::is_gpu_place(lod_tensors[0]->place())) {
 #ifdef PADDLE_WITH_CUDA
       auto pre_in = pre_in_var->Get<framework::LoDTensor>();
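
Both branches now route their work through RunAndRecordEvent, which takes the work as a callable so an event can be recorded once the gather or reduce finishes. The lambdas capture by reference ([&]); assuming the callback runs synchronously inside RunAndRecordEvent, as the capture pattern suggests, those references stay valid. A minimal sketch of the wrap-and-record pattern (hypothetical event-recording body, not the Paddle implementation):

#include <functional>

class OpHandle {
 public:
  void RunAndRecordEvent(const std::function<void()> &callback) {
    callback();     // run the reduce work itself
    RecordEvent();  // then mark completion so downstream ops can wait
  }

 private:
  void RecordEvent() { /* e.g. record a per-device completion event */ }
};

// Usage mirroring the diff: locals are captured by reference, which is
// safe because the callback is invoked before RunAndRecordEvent returns.
void Example(OpHandle &op) {
  int partial_sum = 0;
  op.RunAndRecordEvent([&] { partial_sum += 42; });
}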

paddle/fluid/framework/details/send_op_handle.cc

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
 
 void SendOpHandle::RunImpl() {
   // Wait input done
-  WaitInputVarGenerated();
+  WaitInputVarGenerated(place_);
   auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
   // FIXME(wuyi): can not use RunAndRecordEvent here, for it will cause dead
   // lock.
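
Passing place_ selects the Place-specific overload documented in op_handle_base.h above, so the send op only enqueues waits on its own device's context rather than on every device context.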
