
Commit b43d87c

Author: chengduo
Merge pull request #9825 from chengduoZH/feature/add_gather_and_BCast_op_handle
feature/Add Broadcast and Gather op handle
2 parents: e4cfe47 + 384d6ee

File tree: 9 files changed, +769 −4 lines changed


paddle/fluid/framework/details/CMakeLists.txt

9 additions, 1 deletion
@@ -1,5 +1,5 @@
 cc_library(var_handle SRCS var_handle.cc DEPS place)
-cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
+cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context lod_tensor)
 cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
 nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
@@ -20,3 +20,11 @@ cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS
 cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
 cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
         simple_threadpool device_context)
+
+cc_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory)
+cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory)
+
+cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
+        device_context broadcast_op_handle)
+cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
+        device_context gather_op_handle)
paddle/fluid/framework/details/broadcast_op_handle.cc
111 additions, 0 deletions

@@ -0,0 +1,111 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/details/broadcast_op_handle.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+Tensor *GetTensorFromVar(Variable *in_var) {
+  if (in_var->IsType<LoDTensor>()) {
+    return in_var->GetMutable<LoDTensor>();
+  } else if (in_var->IsType<SelectedRows>()) {
+    return in_var->GetMutable<SelectedRows>()->mutable_value();
+  } else {
+    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
+  }
+  return nullptr;
+}
+
+BroadcastOpHandle::BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places)
+    : local_scopes_(local_scopes), places_(places) {}
+
+void BroadcastOpHandle::RunImpl() {
+  // The inputs may contain dummy vars; keep only the real VarHandles.
+  std::vector<VarHandle *> in_var_handle;
+  for (auto *in : inputs_) {
+    auto *in_handle = dynamic_cast<VarHandle *>(in);
+    if (in_handle) {
+      in_var_handle.push_back(in_handle);
+    }
+  }
+  PADDLE_ENFORCE_EQ(in_var_handle.size(), 1,
+                    "The number of inputs should be one.");
+
+  // The outputs may contain dummy vars as well.
+  std::vector<VarHandle *> out_var_handles;
+  for (auto *out : outputs_) {
+    auto *out_handle = dynamic_cast<VarHandle *>(out);
+    if (out_handle) {
+      out_var_handles.push_back(out_handle);
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(
+      out_var_handles.size(), places_.size(),
+      "The number of outputs should equal the number of places.");
+
+  // Wait until the input is generated; this Wait is an asynchronous operation.
+  auto &in_place = in_var_handle[0]->place_;
+  if (in_var_handle[0]->generated_op_) {
+    for (auto *out : out_var_handles) {
+      auto &out_p = out->place_;
+      in_var_handle[0]->generated_op_->Wait(dev_ctxes_[out_p]);
+    }
+  }
+
+  // Locate the input tensor in its local scope.
+  auto in_scope_idx = in_var_handle[0]->scope_idx_;
+  auto in_var =
+      local_scopes_.at(in_scope_idx)->FindVar(in_var_handle[0]->name_);
+  Tensor *in_tensor = GetTensorFromVar(in_var);
+
+  for (auto *out : out_var_handles) {
+    auto &out_p = out->place_;
+    auto out_var = local_scopes_.at(out->scope_idx_)->FindVar(out->name_);
+
+    PADDLE_ENFORCE_EQ(out_p.which(), in_place.which(),
+                      "Places must be all on CPU or all on CUDA.");
+
+    if (in_var->IsType<framework::SelectedRows>()) {
+      auto &in_sr = in_var->Get<framework::SelectedRows>();
+      auto out_sr = out_var->GetMutable<framework::SelectedRows>();
+      if (&in_sr == out_sr) continue;
+      out_sr->set_height(in_sr.height());
+      out_sr->set_rows(in_sr.rows());
+      out_sr->mutable_value()->Resize(in_sr.value().dims());
+      out_sr->mutable_value()->mutable_data(out_p, in_sr.value().type());
+    } else if (in_var->IsType<framework::LoDTensor>()) {
+      auto &in_lod = in_var->Get<framework::LoDTensor>();
+      auto out_lod = out_var->GetMutable<framework::LoDTensor>();
+      if (&in_lod == out_lod) continue;
+      out_lod->set_lod(in_lod.lod());
+      out_lod->Resize(in_lod.dims());
+      out_lod->mutable_data(out_p, in_lod.type());
+    } else {
+      PADDLE_THROW("Var should be LoDTensor or SelectedRows.");
+    }
+
+    Tensor *out_tensor = GetTensorFromVar(out_var);
+    paddle::framework::TensorCopy(*in_tensor, out_p, *(dev_ctxes_[in_place]),
+                                  out_tensor);
+  }
+}
+
+std::string BroadcastOpHandle::Name() const { return "broadcast"; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
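
A note on the filtering idiom above: RunImpl receives inputs_ and outputs_ as base-class (VarHandleBase) pointers, and dependency-only edges are dummy vars, so a dynamic_cast to VarHandle that returns nullptr is what screens them out. The following is a minimal self-contained sketch of the same pattern, not part of the commit; the three struct names are stand-ins that mirror Paddle's classes only far enough to compile on their own.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Stand-ins for VarHandleBase / VarHandle / DummyVarHandle (illustrative only).
struct VarHandleBase {
  virtual ~VarHandleBase() = default;
};
struct VarHandle : VarHandleBase {
  std::string name_;
};
struct DummyVarHandle : VarHandleBase {};  // dependency-only edge, carries no data

int main() {
  std::vector<std::unique_ptr<VarHandleBase>> inputs;
  inputs.push_back(std::make_unique<VarHandle>());
  inputs.push_back(std::make_unique<DummyVarHandle>());

  // dynamic_cast yields nullptr for dummy vars, so only real VarHandles
  // survive -- the same idiom RunImpl applies to inputs_ and outputs_.
  std::vector<VarHandle *> real_vars;
  for (auto &in : inputs) {
    if (auto *vh = dynamic_cast<VarHandle *>(in.get())) {
      real_vars.push_back(vh);
    }
  }
  std::cout << real_vars.size() << " real var(s)\n";  // prints "1 real var(s)"
}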
paddle/fluid/framework/details/broadcast_op_handle.h
48 additions, 0 deletions

@@ -0,0 +1,48 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/details/op_handle_base.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/selected_rows.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+struct BroadcastOpHandle : public OpHandleBase {
+  const std::vector<Scope *> &local_scopes_;
+  const std::vector<platform::Place> &places_;
+
+  BroadcastOpHandle(const std::vector<Scope *> &local_scopes,
+                    const std::vector<platform::Place> &places);
+
+  std::string Name() const override;
+
+  bool IsMultiDeviceTransfer() override { return false; }
+
+ protected:
+  void RunImpl() override;
+};
+
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
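
A minimal construction sketch follows, assuming a Paddle build context; the loop bound and CPU places are illustrative, and the real input/output wiring is done by the SSA graph builder and exercised by the broadcast_op_handle_test.cc added in this commit. Note the handle holds local_scopes_ and places_ as const references, which avoids copying the per-device lists into every op handle.

#include "paddle/fluid/framework/details/broadcast_op_handle.h"

namespace f = paddle::framework;
namespace d = paddle::framework::details;

int main() {
  // One local scope per device; CPU places for simplicity (scopes are leaked
  // here for brevity).
  std::vector<f::Scope *> local_scopes;
  std::vector<paddle::platform::Place> places;
  for (int i = 0; i < 2; ++i) {
    local_scopes.push_back(new f::Scope());
    places.emplace_back(paddle::platform::CPUPlace());
  }

  // The handle stores only references, so scopes and places must outlive it.
  d::BroadcastOpHandle op_handle(local_scopes, places);

  // Before an executor calls RunImpl(), exactly one input VarHandle and one
  // output VarHandle per place must be attached (see the PADDLE_ENFORCE_EQ
  // checks in broadcast_op_handle.cc).
  return 0;
}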
