
Commit 23a21c8

Merge pull request #9922 from chengduoZH/feature/refine_gather_reduce

Refine gather and broadcast

Authored by chengduo · 2 parents d114d2b + 88f8183 · commit 23a21c8

File tree

9 files changed: +262 −108 lines

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
@@ -21,8 +21,10 @@ cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framewor
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)

-cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
-cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
+cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)
+
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base variable_visitor scope ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope variable_visitor ddim memory)

 cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
         device_context broadcast_op_handle)

paddle/fluid/framework/details/broadcast_op_handle.cc

Lines changed: 42 additions & 65 deletions
@@ -13,95 +13,72 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"

 namespace paddle {
 namespace framework {
 namespace details {
-
-Tensor *GetTensorFromVar(Variable *in_var) {
-  if (in_var->IsType<LoDTensor>()) {
-    return in_var->GetMutable<LoDTensor>();
-  } else if (in_var->IsType<SelectedRows>()) {
-    return in_var->GetMutable<SelectedRows>()->mutable_value();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-  return nullptr;
-}
-
 BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
                                      const std::vector<platform::Place> &places)
     : local_scopes_(local_scopes), places_(places) {}

 void BroadcastOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handle;
-  for (auto *in : inputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(in);
-    if (out_handle) {
-      in_var_handle.push_back(out_handle);
-    }
-  }
-  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
-                    "The number of input should be one.");
+  // the input and output may have dummy var.
+  VarHandle *in_var_handle;

-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
+  {
+    auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+    PADDLE_ENFORCE_EQ(in_var_handles.size(), 1,
+                      "The number of input should be one.");
+    in_var_handle = in_var_handles[0];
   }

+  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
   PADDLE_ENFORCE_EQ(
       out_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");

-  // Wait input done, this Wait is asynchronous operation
-  auto &in_place = in_var_handle[0]->place_;
-  if (in_var_handle[0]->generated_op_) {
-    for (auto *out : out_var_handles) {
-      auto &out_p = out->place_;
-      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
-    }
-  }
+  // Wait input done, this Wait is asynchronous operation platform::Place
+  // &in_place;
+  WaitInputVarGenerated(*in_var_handle);

-  //
-  auto in_scope_idx = in_var_handle[0]->scope_idx_;
-  auto in_var =
-      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
-  Tensor *in_tensor = GetTensorFromVar(in_var);
+  auto *in_var = local_scopes_.at(in_var_handle->scope_idx_)
+                     ->FindVar(in_var_handle->name_);
+  PADDLE_ENFORCE_NOT_NULL(in_var);
+  Tensor &in_tensor = VariableVisitor::GetMutableTensor(in_var);

   for (auto *out : out_var_handles) {
+    if (*out == *in_var_handle) {
+      continue;
+    }
+
     auto &out_p = out->place_;
-    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+    auto *out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);

-    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+    PADDLE_ENFORCE_EQ(out_p.which(), in_var_handle->place_.which(),
                       "Places must be all on CPU or all on CUDA.");

-    if (in_var->IsType<framework::SelectedRows>()) {
-      auto &in_sr = in_var->Get<framework::SelectedRows>();
-      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
-      if (&in_sr == out_sr) continue;
-      out_sr->set_height(in_sr.height());
-      out_sr->set_rows(in_sr.rows());
-      out_sr->mutable_value()->Resize(in_sr.value().dims());
-      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
-    } else if (in_var->IsType<framework::LoDTensor>()) {
-      auto in_lod = in_var->Get<framework::LoDTensor>();
-      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
-      if (&in_lod == out_lod) continue;
-      out_lod->set_lod(in_lod.lod());
-      out_lod->Resize(in_lod.dims());
-      out_lod->mutable_data(out_p, in_lod.type());
-    } else {
-      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
-    }
+    VariableVisitor::ShareDimsAndLoD(*in_var, out_var);
+    VariableVisitor::GetMutableTensor(out_var)
+        .Resize(in_tensor.dims())
+        .mutable_data(out_p, in_tensor.type());

-    Tensor *out_tensor = GetTensorFromVar(out_var);
-    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
-                                  out_tensor);
+    auto dev_ctx = dev_ctxes_[out_p];
+    RunAndRecordEvent(out_p, [in_tensor, out_var, dev_ctx, out_p] {
+      paddle::framework::TensorCopy(
+          in_tensor, out_p, *(dev_ctx),
+          &VariableVisitor::GetMutableTensor(out_var));
+    });
+  }
+}
+
+void BroadcastOpHandle::WaitInputVarGenerated(const VarHandle &in_var) {
+  if (in_var.generated_op_) {
+    for (auto &pair : dev_ctxes_) {
+      in_var.generated_op_->Wait(pair.second);
+    }
   }
 }
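
Because the copy is now queued through RunAndRecordEvent instead of executed inline, the lambda can fire after RunImpl has returned; that is why in_tensor, out_var, dev_ctx, and out_p are captured by value. A minimal, self-contained sketch of this capture-by-value pattern, assuming a plain callback queue (TensorRef and the queue are hypothetical stand-ins, not Paddle APIs):

// Deferred work queued as callbacks; each callback owns its inputs by value
// so it stays valid after the enqueueing scope is gone.
#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct TensorRef {  // stand-in for the shallow Tensor handle captured by value
  std::string name;
};

int main() {
  std::vector<std::function<void()>> queue;  // stand-in for the event stream

  TensorRef in_tensor{"w@scope0"};
  for (int dev = 1; dev < 3; ++dev) {
    // Capture by value: in_tensor and dev remain valid even though the
    // enclosing loop iteration has ended when the callback finally runs.
    queue.emplace_back([in_tensor, dev] {
      std::cout << "copy " << in_tensor.name << " -> device " << dev << "\n";
    });
  }

  for (auto &fn : queue) fn();  // later: the runtime drains the queue
}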

paddle/fluid/framework/details/broadcast_op_handle.h

Lines changed: 1 addition & 1 deletion
@@ -39,12 +39,12 @@ struct BroadcastOpHandle : public OpHandleBase {

  protected:
   void RunImpl() override;
+  void WaitInputVarGenerated(const VarHandle &in_var);

  private:
   const std::vector<Scope *> &local_scopes_;
   const std::vector<platform::Place> &places_;
 };
-
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle

paddle/fluid/framework/details/container_cast.h

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <type_traits>
+#include <vector>
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename ResultType, typename ElemType>
+std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
+  static_assert(std::is_base_of<ElemType, ResultType>::value,
+                "ElementType must be a base class of ResultType");
+  std::vector<ResultType*> res;
+  for (auto* ptr : container) {
+    auto* derived = dynamic_cast<ResultType*>(ptr);
+    if (derived) {
+      res.emplace_back(derived);
+    }
+  }
+  return res;
+}
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
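
DynamicCast is what lets both op handles drop their hand-rolled filtering loops: it walks a vector of base-class pointers and keeps only the elements that dynamic_cast to the requested type, silently discarding the rest (the dummy vars). A minimal usage sketch; Base and Derived are hypothetical stand-ins for VarHandleBase and VarHandle:

#include <iostream>
#include <type_traits>
#include <vector>

struct Base {
  virtual ~Base() = default;  // polymorphic base so dynamic_cast works
};
struct Derived : Base {};

// Same shape as the template above, reproduced for a self-contained demo.
template <typename ResultType, typename ElemType>
std::vector<ResultType*> DynamicCast(const std::vector<ElemType*>& container) {
  static_assert(std::is_base_of<ElemType, ResultType>::value,
                "ElemType must be a base class of ResultType");
  std::vector<ResultType*> res;
  for (auto* ptr : container) {
    if (auto* derived = dynamic_cast<ResultType*>(ptr)) {
      res.emplace_back(derived);
    }
  }
  return res;
}

int main() {
  Derived d1, d2;
  Base dummy;  // plays the role of a DummyVarHandle
  std::vector<Base*> mixed{&d1, &dummy, &d2};
  auto only_derived = DynamicCast<Derived>(mixed);
  std::cout << only_derived.size() << "\n";  // prints 2: the dummy is dropped
}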

paddle/fluid/framework/details/gather_op_handle.cc

Lines changed: 43 additions & 40 deletions
@@ -13,6 +13,8 @@
 // limitations under the License.

 #include "paddle/fluid/framework/details/gather_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"

 namespace paddle {
 namespace framework {
@@ -23,81 +25,68 @@ GatherOpHandle::GatherOpHandle(const std::vector<Scope *> &local_scopes,
     : local_scopes_(local_scopes), places_(places) {}

 void GatherOpHandle::RunImpl() {
-  // the input may have dummy var.
-  std::vector<VarHandle *> in_var_handles;
-  for (auto *in : inputs_) {
-    auto *in_handle = dynamic_cast<VarHandle *>(in);
-    if (in_handle) {
-      in_var_handles.push_back(in_handle);
-    }
-  }
+  // the input and output may have dummy var.
+  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
+
   PADDLE_ENFORCE_EQ(
       in_var_handles.size(), places_.size(),
       "The number of output should equal to the number of places.");

-  // the output may have dummy var.
-  std::vector<VarHandle *> out_var_handles;
-  for (auto *out : outputs_) {
-    auto *out_handle = dynamic_cast<VarHandle *>(out);
-    if (out_handle) {
-      out_var_handles.push_back(out_handle);
-    }
+  VarHandle *out_var_handle;
+  {
+    auto out_var_handles = DynamicCast<VarHandle>(outputs_);
+
+    PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
+                      "The number of output should be one.");
+    out_var_handle = out_var_handles.front();
   }
-  PADDLE_ENFORCE_EQ(out_var_handles.size(), 1,
-                    "The number of output should be one.");

-  auto in_0_handle = static_cast<VarHandle *>(in_var_handles[0]);
+  auto in_0_handle = in_var_handles[0];
   auto pre_in_var =
       local_scopes_[in_0_handle->scope_idx_]->FindVar(in_0_handle->name_);
   auto pre_place = in_0_handle->place_;

   PADDLE_ENFORCE(pre_in_var->IsType<framework::SelectedRows>(),
                  "Currently, gather_op only can gather SelectedRows.");

-  PADDLE_ENFORCE_EQ(out_var_handles[0]->place_.which(), pre_place.which(),
+  PADDLE_ENFORCE_EQ(out_var_handle->place_.which(), pre_place.which(),
                     "The place of input and output should be the same.");

   // Wait input done, this Wait is asynchronous operation
-  for (auto *in : in_var_handles) {
-    if (in->generated_op_) {
-      in->generated_op_->Wait(dev_ctxes_[in->place_]);
-    }
-  }
+  WaitInputVarGenerated(in_var_handles);

   std::vector<int64_t> out_rows;
   std::vector<Tensor> in_tensors;
   std::vector<platform::Place> in_places;

   auto &pre_in = pre_in_var->Get<framework::SelectedRows>();
   // gather the inputs
-  for (auto *in : in_var_handles) {
-    auto in_handle = static_cast<VarHandle *>(in);
+  for (auto *in_handle : in_var_handles) {
     auto in_p = in_handle->place_;
     in_places.push_back(in_p);
     PADDLE_ENFORCE_EQ(in_p.which(), pre_place.which(),
                       "Places must be all on CPU or all on CUDA.");
-    auto in_var =
+    auto *in_var =
         local_scopes_.at(in_handle->scope_idx_)->FindVar(in_handle->name_);
     auto &in_sr = in_var->Get<framework::SelectedRows>();

     PADDLE_ENFORCE_EQ(in_sr.value().type(), pre_in.value().type(),
                       "The type of input is not consistent.");
     PADDLE_ENFORCE_EQ(pre_in.height(), in_sr.height(),
                       "The height of inputs is not consistent.");
-    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(), ,
+    PADDLE_ENFORCE_EQ(pre_in.GetCompleteDims(), in_sr.GetCompleteDims(),
                       "The dims of inputs is not consistent.");

-    auto in_sr_rows = in_sr.rows();
+    auto &in_sr_rows = in_sr.rows();
     out_rows.insert(out_rows.end(), in_sr_rows.begin(), in_sr_rows.end());

     in_tensors.emplace_back(in_sr.value());
   }

   // write the output
-  auto &out_place = out_var_handles[0]->place_;
-  auto out_scope_idx = out_var_handles[0]->scope_idx_;
-  auto out_var =
-      local_scopes_[out_scope_idx]->FindVar(out_var_handles[0]->name_);
+  auto &out_place = out_var_handle->place_;
+  auto out_scope_idx = out_var_handle->scope_idx_;
+  auto out_var = local_scopes_[out_scope_idx]->FindVar(out_var_handle->name_);

   auto out = out_var->GetMutable<framework::SelectedRows>();
   out->set_height(pre_in.height());
@@ -110,13 +99,27 @@ void GatherOpHandle::RunImpl() {
   Tensor *out_tensor = out->mutable_value();

   // copy
-  int s = 0, e = 0;
-  for (size_t j = 0; j < in_tensors.size(); ++j) {
-    e += in_tensors[j].dims()[0];
-    auto sub_out = out_tensor->Slice(s, e);
-    paddle::framework::TensorCopy(in_tensors[j], out_place,
-                                  *(dev_ctxes_[in_places[j]]), &sub_out);
-    s = e;
+  auto dev_ctx = dev_ctxes_[out_place];
+  RunAndRecordEvent(out_place, [in_tensors, out_tensor, dev_ctx, out_place] {
+    int s = 0, e = 0;
+    for (size_t j = 0; j < in_tensors.size(); ++j) {
+      e += in_tensors[j].dims()[0];
+      auto sub_out = out_tensor->Slice(s, e);
+      paddle::framework::TensorCopy(in_tensors[j], out_place, *(dev_ctx),
+                                    &sub_out);
+      s = e;
+    }
+  });
+}
+
+void GatherOpHandle::WaitInputVarGenerated(
+    const std::vector<VarHandle *> &in_var_handles) {
+  for (auto *in : in_var_handles) {
+    if (in->generated_op_) {
+      for (auto pair : dev_ctxes_) {
+        in->generated_op_->Wait(pair.second);
+      }
+    }
   }
 }
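
The copy loop concatenates the gathered SelectedRows values row-wise: e advances by each input's row count, the input is copied into the output window [s, e), and s catches up. A minimal sketch of that offset arithmetic with plain buffers (row width and values invented for illustration; no Paddle types):

#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Three "inputs", each a flat row-major buffer with 2 columns.
  std::vector<std::vector<float>> in_tensors = {
      {1, 1, 2, 2},         // 2 rows
      {3, 3},               // 1 row
      {4, 4, 5, 5, 6, 6}};  // 3 rows
  const int width = 2;

  std::vector<float> out_tensor(12);  // 6 rows in total
  int s = 0, e = 0;  // running [s, e) row window, exactly as in RunImpl
  for (size_t j = 0; j < in_tensors.size(); ++j) {
    e += static_cast<int>(in_tensors[j].size()) / width;
    // Stand-in for out_tensor->Slice(s, e) plus TensorCopy.
    std::memcpy(out_tensor.data() + s * width, in_tensors[j].data(),
                in_tensors[j].size() * sizeof(float));
    s = e;
  }
  for (float v : out_tensor) std::cout << v << " ";  // 1 1 2 2 3 3 4 4 5 5 6 6
  std::cout << "\n";
}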

paddle/fluid/framework/details/gather_op_handle.h

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ struct GatherOpHandle : public OpHandleBase {

  protected:
   void RunImpl() override;
+  void WaitInputVarGenerated(const std::vector<VarHandle *> &in_var_handles);

  private:
   const std::vector<Scope *> &local_scopes_;

paddle/fluid/framework/details/var_handle.h

Lines changed: 5 additions & 0 deletions
@@ -61,6 +61,11 @@ struct VarHandle : public VarHandleBase {
   size_t scope_idx_;
   std::string name_;
   platform::Place place_;
+
+  bool operator==(const VarHandle& o) const {
+    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
+           o.scope_idx_ == scope_idx_;
+  }
 };

 // Dummy Variable. It is used to represent dependencies between operators
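
This operator== is what BroadcastOpHandle::RunImpl relies on to skip the output that aliases its input (the "if (*out == *in_var_handle) continue;" above): handles compare equal when they share the producing op, name, and scope index; place_ is not part of the comparison. A minimal sketch with stubbed types, where only the comparison body mirrors the commit:

#include <cstddef>
#include <iostream>
#include <string>

struct OpHandleBase;  // opaque producer op, compared by pointer identity

struct VarHandle {
  OpHandleBase* generated_op_;
  std::size_t scope_idx_;
  std::string name_;

  bool operator==(const VarHandle& o) const {
    return o.generated_op_ == generated_op_ && o.name_ == name_ &&
           o.scope_idx_ == scope_idx_;
  }
};

int main() {
  VarHandle in{nullptr, 0, "w"};
  VarHandle same{nullptr, 0, "w"};   // same var in the same scope: skipped
  VarHandle other{nullptr, 1, "w"};  // same name, different scope: copied
  std::cout << (in == same) << " " << (in == other) << "\n";  // prints 1 0
}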
