
Commit d24ef93

Clean Code
1 parent 4abef50 commit d24ef93

9 files changed: +234 additions, -103 deletions


paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,9 @@ cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS
 cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
 cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
 
+cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+
+
 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)
 cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
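Note: the new variable_visitor target packages the LoDTensor/SelectedRows type dispatch that this commit deletes from broadcast_op_handle.cc (see the sketch after that diff below); its lod_tensor and selected_rows dependencies reflect exactly those two variable types.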

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 33 additions & 65 deletions
@@ -13,102 +13,70 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
-
-Tensor *GetTensorFromVar(Variable *in_var) {
-  if (in_var->IsType<LoDTensor>()) {
-    return in_var->GetMutable<LoDTensor>();
-  } else if (in_var->IsType<SelectedRows>()) {
-    return in_var->GetMutable<SelectedRows>()->mutable_value();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-  return nullptr;
-}
-
 BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}
 
 void BroadcastOpHandle::RunImpl() {
   // the input and output may have dummy var.
-  std::vector<VarHandle *> in_var_handle = GetValidVarHandles(inputs_);
-  std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
+  VarHandle *in_var_handle;
+
+  {
+    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
+                      "The number of input should be one.");
+    in_var_handle = in_var_handles[0];
+  }
+
+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
 
-  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
-                    "The number of input should be one.");
   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
      "The number of output should equal to the number of places.");
 
-  // Wait input done, this Wait is asynchronous operationplatform::Place
+  // Wait input done, this Wait is asynchronous operation platform::Place
   // &in_place;
-  WaitEvents(out_var_handles, in_var_handle);
+  WaitInputVarGenerated(*in_var_handle);
 
-  auto in_place = in_var_handle[0]->place_;
-  auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  auto in_var =
-      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
-  Tensor *in_tensor = GetTensorFromVar(in_var);
+  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
+                     ->FindVar(in_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(in_var);
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);
 
   for (auto *out : out_var_handles) {
+    if (*out == *in_var_handle) {
+      continue;
+    }
+
     auto &out_p = out->place_;
-    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
 
-    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");
 
-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto &in_sr = in_var->Get<framework::SelectedRows>();
-      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
-      if (&in_sr == out_sr) continue;
-      out_sr->set_height(in_sr.height());
-      out_sr->set_rows(in_sr.rows());
-      out_sr->mutable_value()->Resize(in_sr.value().dims());
-      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
-    } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->Get<framework::LoDTensor>();
-      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
-      if (&in_lod == out_lod) continue;
-      out_lod->set_lod(in_lod.lod());
-      out_lod->Resize(in_lod.dims());
-      out_lod->mutable_data(out_p, in_lod.type());
-    } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
-    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var).mutable_data(out_p,
+                                                            in_tensor.type());
 
     auto dev_ctx = dev_ctxes_[out_p];
     RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
-      Tensor *out_tensor = GetTensorFromVar(out_var);
-      paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctx), out_tensor);
+      paddle::framework::TensorCopy(
+          in_tensor, out_p, *(dev_ctx),
+          &VariableVisitor::GetMutableTensor(out_var));
     });
   }
 }
 
-void BroadcastOpHandle::WaitEvents(
-    const std::vector<VarHandle *> &out_var_handles,
-    const std::vector<VarHandle *> &in_var_handle) {
-  if (in_var_handle[0]->generated_op_) {
-    for (auto *out : out_var_handles) {
-      auto &out_p = out->place_;
-      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
-    }
-  }
-}
-
-std::vector<VarHandle *> BroadcastOpHandle::GetValidVarHandles(
-    const std::vector<VarHandleBase *> &inputs) {
-  std::vector<VarHandle *> in_var_handle;
-  for (auto *in : inputs) {
-    auto *out_handle = dynamic_cast<VarHandle *>(in);
-    if (out_handle) {
-      in_var_handle.push_back(out_handle);
-    }
+void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  for (auto &pair : dev_ctxes_) {
+    in_var.generated_op_->Wait(pair.second);
   }
-  return in_var_handle;
 }
 
 std::string BroadcastOpHandle::Name() const { return "broadcast"; }
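The VariableVisitor helpers called above (GetMutableTensor, ShareDimsAndLoD) are declared in variable_visitor.h, which this commit adds but which is not shown in this excerpt. As a hedged sketch only, here is what that interface plausibly looks like, reconstructed from the GetTensorFromVar function and the per-type branch deleted above; the names match the call sites, but the bodies are an assumption, not the committed implementation:

// Hypothetical reconstruction, not the committed variable_visitor.h.
struct VariableVisitor {
  // Return the tensor held by var, whether var wraps a LoDTensor directly
  // or a SelectedRows whose value() is the tensor (mirrors the deleted
  // GetTensorFromVar, but returns a reference instead of a pointer).
  static Tensor &GetMutableTensor(Variable *var) {
    if (var->IsType<LoDTensor>()) {
      return *var->GetMutable<LoDTensor>();
    } else if (var->IsType<SelectedRows>()) {
      return *var->GetMutable<SelectedRows>()->mutable_value();
    }
    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
  }

  // Copy shape metadata from src to trg without copying data: LoD and dims
  // for LoDTensor, height/rows and value dims for SelectedRows (mirrors the
  // per-type branch deleted from RunImpl above).
  static void ShareDimsAndLoD(const Variable &src, Variable *trg);
};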

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 1 addition & 5 deletions
@@ -42,11 +42,7 @@ struct BroadcastOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
-  std::vector<VarHandle *> GetValidVarHandles(
-      const std::vector<VarHandleBase *> &inputs);
-
-  void WaitEvents(const std::vector<VarHandle *> &out_var_handles,
-                  const std::vector<VarHandle *> &in_var_handle);
+  void WaitInputVarGenerated(const VarHandle &in_var);
 };
 
 }  // namespace details
paddle/fluid/framework/details/container_cast.h

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename ResultType, typename ElemType>
+std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
+  static_assert(std::is_base_of<ElemType, ResultType>::value,
+                "ElementType must be a base class of ResultType");
+  std::vector<ResultType*> res;
+  for (auto* ptr : container) {
+    auto* derived = dynamic_cast<ResultType*>(ptr);
+    if (derived) {
+      res.emplace_back(derived);
+    }
+  }
+  return res;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
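A minimal usage sketch of DynamicCast, with illustrative types (Base, Derived, and Other are made-up names, not Paddle classes); this is how the op handles above filter the real VarHandles out of a mixed list that may also contain dummy vars:

// Usage sketch; assumes the container_cast.h header above is included.
#include <cassert>
#include <vector>

struct Base {
  virtual ~Base() = default;  // polymorphic base, required by dynamic_cast
};
struct Derived : Base {};
struct Other : Base {};

int main() {
  Derived d1, d2;
  Other o;
  std::vector<Base*> mixed{&d1, &o, &d2};

  // Only the Derived elements survive; &o is silently dropped, just as
  // dummy vars are dropped from inputs_/outputs_ in RunImpl.
  auto ds = paddle::framework::details::DynamicCast<Derived>(mixed);
  assert(ds.size() == 2);
  return 0;
}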

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 25 additions & 29 deletions
@@ -13,6 +13,8 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/gather_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
 
 namespace paddle {
 namespace framework {
@@ -24,42 +26,47 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
 
 void GatherOpHandle::RunImpl() {
   // the input and output may have dummy var.
-  std::vector<VarHandle *> in_var_handles = GetValidVarHandles(inputs_);
-  std::vector<VarHandle *> out_var_handles = GetValidVarHandles(outputs_);
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
 
   PADDLE_ENFORCE_EQ(
       in_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");
-  PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
-                    "The number of output should be one.");
 
-  auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
+  VarHandle *out_var_handle;
+  {
+    auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
+    PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
+                      "The number of output should be one.");
+    out_var_handle = out_var_handles.front();
+  }
+
+  auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
       local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
   auto pre_place = in_0_handle->place_;
 
   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                 "Currently, gather_op only can gather SelectedRows.");
 
-  PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
+  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
                     "The place of input and output should be the same.");
 
   // Wait input done, this Wait is asynchronous operation
-  WaitEvents(in_var_handles);
+  WaitInputVarGenerated(in_var_handles);
 
   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
   std::vector<platform::Place> in_places;
 
   auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
   // gather the inputs
-  for (auto *in : in_var_handles) {
-    auto in_handle = static_cast<VarHandle *>(in);
+  for (auto *in_handle : in_var_handles) {
     auto in_p = in_handle->place_;
     in_places.push_back(in_p);
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
                       "Places must be all on CPU or all on CUDA.");
-    auto in_var =
+    auto *in_var =
        local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
    auto &in_sr = in_var->Get<framework::SelectedRows>();
@@ -70,17 +77,16 @@ void GatherOpHandle::RunImpl() {
     PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
                       "The dims of inputs is not consistent.");
 
-    auto in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());
 
     in_tensors.emplace_back(in_sr.value());
   }
 
   // write the output
-  auto &out_place = out_var_handles[0]->place_;
-  auto out_scope_idx = out_var_handles[0]->scope_idx_;
-  auto out_var =
-      local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
+  auto &out_place = out_var_handle->place_;
+  auto out_scope_idx = out_var_handle->scope_idx_;
+  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);
 
   auto out = out_var->GetMutable<framework::SelectedRows>();
   out->set_height(pre_in.height());
@@ -106,25 +112,15 @@ void GatherOpHandle::RunImpl() {
   });
 }
 
-void GatherOpHandle::WaitEvents(
+void GatherOpHandle::WaitInputVarGenerated(
     const std::vector<VarHandle *> &in_var_handles) {
   for (auto *in : in_var_handles) {
     if (in->generated_op_) {
-      in->generated_op_->Wait(dev_ctxes_[in->place_]);
-    }
-  }
-}
-
-std::vector<VarHandle *> GatherOpHandle::GetValidVarHandles(
-    const std::vector<VarHandleBase *> &inputs) {
-  std::vector<VarHandle *> in_var_handles;
-  for (auto *in : inputs) {
-    auto *in_handle = dynamic_cast<VarHandle *>(in);
-    if (in_handle) {
-      in_var_handles.push_back(in_handle);
+      for (auto pair : dev_ctxes_) {
+        in->generated_op_->Wait(pair.second);
+      }
     }
   }
-  return in_var_handles;
 }
 
 std::string GatherOpHandle::Name() const { return "gather"; }
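Note the semantic change behind this rename: the old WaitEvents waited each input's generating op only on that input's own device context (dev_ctxes_[in->place_]), while WaitInputVarGenerated now waits it on every device context in dev_ctxes_. BroadcastOpHandle::WaitInputVarGenerated above makes the same change, and additionally drops the old null check on generated_op_ before dereferencing it.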

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 1 addition & 4 deletions
@@ -42,10 +42,7 @@ struct GatherOpHandle : public OpHandleBase {
  protected:
   void RunImpl() override;
 
-  std::vector<VarHandle *> GetValidVarHandles(
-      const std::vector<VarHandleBase *> &);
-
-  void WaitEvents(const std::vector<VarHandle *> &in_var_handles);
+  void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);
 };
 
 }  // namespace details

paddle/fluid/framework/details/var_handle.h

Lines changed: 5 additions & 0 deletions
@@ -53,6 +53,11 @@ struct VarHandle : public VarHandleBase {
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
+
+  bool operator==(const VarHandle &o) const {
+    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
+           o.scope_idx_ == scope_idx_;
+  }
 };
 
 // Dummy Variable. It is used to represent dependencies between operators
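The comparison covers generated_op_, name_, and scope_idx_ but deliberately omits place_; BroadcastOpHandle::RunImpl uses it (if (*out == *in_var_handle) continue;) to skip copying a variable onto itself. A minimal illustration, assuming VarHandle is default-constructible with publicly assignable fields as the struct suggests (the field values are made up):

// Illustrative only; not from the commit.
VarHandle a, b;
a.name_ = b.name_ = "w@GRAD";            // made-up variable name
a.scope_idx_ = b.scope_idx_ = 0;
a.generated_op_ = b.generated_op_ = nullptr;
a.place_ = platform::CPUPlace();
b.place_ = platform::CUDAPlace(0);
assert(a == b);  // equal: place_ is not part of the comparison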
