Skip to content

Commit 652cf43

Browse files
authored
Merge pull request #9746 from typhoonzero/multigpumultinode
[Feature] Enable multi gpu distributed training of fluid
2 parents 2a3d490 + dfc6025 commit 652cf43

File tree

11 files changed

+153
-23
lines changed

11 files changed

+153
-23
lines changed

paddle/fluid/framework/details/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod
55
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
66
dynload_cuda)
77
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
8+
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
89

910
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
1011
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -15,7 +16,7 @@ else()
1516
set(multi_devices_graph_builder_deps)
1617
endif()
1718
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
18-
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
19+
scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps})
1920
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
2021
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
2122
simple_threadpool device_context)

paddle/fluid/framework/details/multi_devices_graph_builder.cc

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
1616
#include "paddle/fluid/framework/details/computation_op_handle.h"
1717
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
18+
#include "paddle/fluid/framework/details/send_op_handle.h"
1819
#include "paddle/fluid/framework/scope.h"
1920

2021
#ifdef PADDLE_WITH_CUDA
@@ -54,6 +55,27 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
5455
}
5556
}
5657

58+
// Wires up the op handle most recently appended to result->ops_:
// binds it to the device context of place `p`, connects every input
// argument to the latest SSA version of that variable, and registers a
// new SSA version for every output argument on device index `i`.
void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, OpDesc *op,
                                                const platform::Place &p,
                                                const size_t &i) const {
  auto *handle = result->ops_.back().get();
  handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
      platform::DeviceContextPool::Instance().Get(p));

  // Inputs: read the latest version of each variable on this place.
  for (auto &in_name : op->InputArgumentNames()) {
    handle->AddInput(CreateOrGetLatestVarHandle(result, in_name, p, i));
  }

  // Outputs: each write creates a fresh SSA version of the variable.
  for (auto &out_name : op->OutputArgumentNames()) {
    CreateOpOutput(result, handle, out_name, p, i);
  }
}
78+
5779
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
5880
const ProgramDesc &program) const {
5981
auto graph = new SSAGraph();
@@ -76,27 +98,28 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
7698
}
7799
}
78100

101+
// append send op if program is distributed trainer main program.
102+
// always use the first device
103+
if (!is_forwarding && op->Type() == "send") {
104+
auto &p = places_[0];
105+
auto *s = local_scopes_[0];
106+
// FIXME(wuyi): send op always copy from GPU 0
107+
result.ops_.emplace_back(new SendOpHandle(*op, s, p));
108+
// Create inputs for output on original place and no ssa output
109+
// is created for send op.
110+
CreateOpHandleIOs(&result, op, p, 0);
111+
continue;
112+
}
113+
79114
for (size_t i = 0; i < places_.size(); ++i) {
80115
auto &p = places_[i];
81116
auto *s = local_scopes_[i];
82117

83118
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
84119
auto *op_handle = result.ops_.back().get();
85-
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
86-
platform::DeviceContextPool::Instance().Get(p));
120+
CreateOpHandleIOs(&result, op, p, i);
87121

88-
auto var_names = op->InputArgumentNames();
89-
90-
for (auto &each_var_name : var_names) {
91-
VarHandle *var =
92-
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
93-
op_handle->AddInput(var);
94-
}
95-
var_names = op->OutputArgumentNames();
96-
97-
for (auto &each_var_name : var_names) {
98-
CreateOpOutput(&result, op_handle, each_var_name, p, i);
99-
}
122+
auto var_names = op->OutputArgumentNames();
100123

101124
if (is_forwarding) {
102125
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {

paddle/fluid/framework/details/multi_devices_graph_builder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
#pragma once
1616

17+
#include <string>
18+
#include <vector>
19+
1720
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
1821

1922
namespace paddle {
@@ -41,6 +44,10 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
4144

4245
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
4346

47+
private:
48+
void CreateOpHandleIOs(SSAGraph *result, OpDesc *op, const platform::Place &p,
49+
const size_t &i) const;
50+
4451
private:
4552
std::string loss_var_name_;
4653
const std::vector<platform::Place> &places_;
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/details/send_op_handle.h"
16+
17+
namespace paddle {
18+
namespace framework {
19+
namespace details {
20+
21+
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
22+
const Scope *local_scope,
23+
const platform::Place &place)
24+
: op_(framework::OpRegistry::CreateOp(op_desc)),
25+
local_scope_(local_scope),
26+
place_(place) {}
27+
28+
void SendOpHandle::RunImpl() {
29+
// Wait input done
30+
for (auto *in : inputs_) {
31+
auto &p = static_cast<VarHandle *>(in)->place_;
32+
if (in->DebugString() == "dummy") { // HACK
33+
continue;
34+
}
35+
in->generated_op_->Wait(dev_ctxes_[p]);
36+
}
37+
op_->Run(*local_scope_, place_);
38+
}
39+
40+
std::string SendOpHandle::Name() const { return "send"; }
41+
} // namespace details
42+
} // namespace framework
43+
} // namespace paddle
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include <memory>  // std::unique_ptr (was relied on transitively)
#include <string>
#include <vector>

#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"

namespace paddle {
namespace framework {
namespace details {

// Op handle that runs a distributed "send" operator inside the SSA graph
// executor. Created by MultiDevSSAGraphBuilder for trainer main programs;
// the builder always places it on the first device.
struct SendOpHandle : public OpHandleBase {
  std::unique_ptr<OperatorBase> op_;  // the wrapped "send" operator
  const Scope* local_scope_;          // borrowed; must outlive this handle
  // NOTE(review): reference member — the referenced Place must outlive this
  // handle (the builder passes places_[0], owned elsewhere). TODO confirm
  // lifetime, or store by value to remove the risk.
  const platform::Place& place_;

  SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
               const platform::Place& place);

  std::string Name() const override;

  // A send runs on exactly the place it was created for and is never a
  // multi-device transfer, so the executor must not delay/batch it.
  // (Original comment was copy-pasted from the nccl_all_reduce handle.)
  bool IsMultiDeviceTransfer() override { return false; }

 protected:
  void RunImpl() override;
};

}  // namespace details
}  // namespace framework
}  // namespace paddle

paddle/fluid/framework/parallel_executor.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,13 +48,13 @@ class ParallelExecutor {
4848
const std::string& fetched_var_name,
4949
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
5050

51+
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
52+
5153
private:
5254
void SplitTensorToPlaces(
5355
const std::unordered_map<std::string, LoDTensor>& feed_tensors);
5456

5557
ParallelExecutorPrivate* member_;
56-
57-
void BCastParamsToGPUs(const std::unordered_set<std::string>& vars) const;
5858
};
5959

6060
} // namespace framework

paddle/fluid/operators/detail/grpc_client.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,8 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
6565
}
6666

6767
void ProcGetResponse(const VarHandle& var_h,
68-
// const sendrecv::VariableMessage& ret_msg) {
6968
const ::grpc::ByteBuffer& ret_msg) {
70-
framework::Variable* outvar = NULL;
69+
framework::Variable* outvar = nullptr;
7170
DeserializeFromByteBuffer(ret_msg, *var_h.ctx, var_h.scope, &outvar);
7271
}
7372

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
107107
for (int i = 0; i < tensor_numel; ++i) {
108108
EXPECT_FLOAT_EQ(tensor_data2[i], 32.7);
109109
}
110-
for (int64_t i = 0; i < rows2->size(); ++i) {
110+
for (size_t i = 0; i < rows2->size(); ++i) {
111111
EXPECT_EQ(rows_data2[i], i);
112112
}
113113
EXPECT_EQ(slr2->height(), 1000);

paddle/fluid/pybind/pybind.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,7 @@ All parameter, weight, gradient are variables in Paddle.
553553
bcast_vars, main_program, loss_var_name,
554554
scope, local_scopes, allow_op_delay);
555555
})
556+
.def("bcast_params", &ParallelExecutor::BCastParamsToGPUs)
556557
.def("local_scopes",
557558
[](ParallelExecutor &self) -> std::vector<Scope *> * {
558559
return &self.GetLocalScopes();

python/paddle/fluid/distribute_transpiler.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ def transpile(self,
255255
def get_trainer_program(self):
256256
# remove optimize ops and add a send op to main_program
257257
self.program.global_block().delete_ops(self.optimize_ops)
258+
self.program.sync_with_cpp()
258259
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
259260
self.program.__str__()
260261
return self.program

0 commit comments

Comments
 (0)