
Commit 833d0ad

Merge pull request #4838 from dzhwinter/feature/multigpu
Feature/multigpu
2 parents: b84e822 + 71305e5

12 files changed: 898 additions & 7 deletions


paddle/framework/op_registry.h

Lines changed: 4 additions & 0 deletions
@@ -228,6 +228,10 @@ class OpKernelRegistrar : public Registrar {
   USE_OP_ITSELF(op_type);        \
   USE_OP_DEVICE_KERNEL(op_type, CPU);
 
+#define USE_GPU_ONLY_OP(op_type) \
+  USE_OP_ITSELF(op_type);        \
+  USE_OP_DEVICE_KERNEL(op_type, GPU)
+
 #define USE_OP(op_type)   \
   USE_OP_ITSELF(op_type); \
   USE_OP_KERNEL(op_type)
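
For example, the pybind registration stub generated later in this change references the new NCCL operator through exactly this macro; it pulls in the operator itself plus its GPU kernel and deliberately requires no CPU kernel:

    // In a client translation unit (e.g. the auto-generated pybind file):
    USE_GPU_ONLY_OP(ncclAllReduce);
    // expands to: USE_OP_ITSELF(ncclAllReduce); USE_OP_DEVICE_KERNEL(ncclAllReduce, GPU)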

paddle/framework/operator.h

Lines changed: 11 additions & 1 deletion
@@ -122,7 +122,7 @@ class OperatorBase {
  protected:
   std::string type_;
   // NOTE: in case of OpGrad, inputs_ contains:
-  // I (Inputs)opear
+  // I (Inputs)
   // O (Outputs)
   // OG (Output Gradients)
   VariableNameMap inputs_;
@@ -287,6 +287,16 @@ class ExecutionContext {
     return device_context_;
   }
 
+  //! Get actual name vector for this input.
+  const std::vector<std::string>& Inputs(const std::string& name) const {
+    return op_.Inputs(name);
+  }
+
+  //! Get actual name vector for this output.
+  const std::vector<std::string>& Outputs(const std::string& name) const {
+    return op_.Outputs(name);
+  }
+
 #ifdef PADDLE_WITH_CUDA
   const platform::CUDADeviceContext& cuda_device_context() const {
     PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
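
The two new accessors expose the full list of variable names bound to a (possibly duplicable) input or output slot. A minimal sketch of how a kernel might use them; the slot names "X" and "Out" are illustrative, not tied to any particular operator:

    // Inside some OpKernel<...>::Compute(const framework::ExecutionContext& ctx)
    const std::vector<std::string>& in_names = ctx.Inputs("X");     // all variables bound to "X"
    const std::vector<std::string>& out_names = ctx.Outputs("Out"); // all variables bound to "Out"
    PADDLE_ENFORCE(in_names.size() == out_names.size(),
                   "every input variable needs a matching output variable");
    for (size_t i = 0; i < in_names.size(); ++i) {
      // in_names[i] / out_names[i] identify the i-th input/output variable pair by name.
    }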

paddle/operators/CMakeLists.txt

Lines changed: 16 additions & 0 deletions
@@ -90,6 +90,13 @@ function(op_library TARGET)
     file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
   endif()
 
+  # nccl_op contains several operators
+  if ("${TARGET}" STREQUAL "nccl_op")
+    set(pybind_flag 1)
+    # It's enough to add just one operator to pybind
+    file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n")
+  endif()
+
   # reduce_op contains several operators
   if ("${TARGET}" STREQUAL "reduce_op")
     set(pybind_flag 1)
@@ -121,6 +128,7 @@ function(op_library TARGET)
 endfunction()
 
 add_subdirectory(math)
+add_subdirectory(nccl)
 
 set(DEPS_OPS
     recurrent_op
@@ -130,6 +138,7 @@ set(DEPS_OPS
     sum_op
     pool_op
     pool_with_index_op
+    nccl_op
     sequence_conv_op
     lstm_op)
 
@@ -142,6 +151,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
 op_library(sum_op DEPS net_op selected_rows_functor)
 op_library(pool_op DEPS pooling)
 op_library(pool_with_index_op DEPS pooling)
+if(WITH_GPU)
+  op_library(nccl_op DEPS nccl_common)
+endif()
 op_library(sequence_conv_op DEPS context_project)
 op_library(lstm_op DEPS sequence2batch lstm_compute)
 
@@ -157,4 +169,8 @@ cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)
 cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
 cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memory)
 cc_test(dynamic_recurrent_op_test SRCS dynamic_recurrent_op_test.cc DEPS dynamic_recurrent_op recurrent_op tensor_array)
+
+if(WITH_GPU)
+  nv_test(nccl_op_test SRCS nccl_op_test.cu DEPS nccl_op gpu_info device_context)
+endif()
 cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)

paddle/operators/nccl/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+if(WITH_GPU)
+  nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator)
+endif()

paddle/operators/nccl/nccl_gpu_common.cc

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/nccl/nccl_gpu_common.h"
+#include "paddle/platform/gpu_info.h"
+
+namespace paddle {
+namespace platform {}  // namespace platform
+}  // namespace paddle

paddle/operators/nccl/nccl_gpu_common.h

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/dynload/nccl.h"
+#include "paddle/platform/enforce.h"
+#include "paddle/platform/macros.h"
+
+namespace paddle {
+namespace platform {
+
+constexpr int kInvalidGPUId = -1;
+
+struct Communicator {
+  std::vector<ncclComm_t> comms_;
+  std::unordered_map<int, int> comm_id_map_;
+
+  Communicator() {}
+
+  int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
+
+  void InitAll(const std::vector<int>& gpus) {
+    comms_.resize(gpus.size());
+    for (size_t i = 0; i < gpus.size(); ++i) {
+      comm_id_map_[gpus[i]] = i;
+    }
+    PADDLE_ENFORCE(
+        dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
+  }
+
+  ~Communicator() {
+    for (size_t i = 0; i < comms_.size(); ++i) {
+      // FIXME(dzh): PADDLE_ENFORCE returns void
+      dynload::ncclCommDestroy(comms_[i]);
+    }
+  }
+
+  DISABLE_COPY_AND_ASSIGN(Communicator);
+};
+
+}  // namespace platform
+}  // namespace paddle
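
Taken on its own, Communicator owns one ncclComm_t per participating device and maps a CUDA device id to that communicator's index in comms_. A minimal usage sketch, assuming two visible GPUs (the device ids are illustrative):

    paddle::platform::Communicator comm;
    comm.InitAll({0, 1});                  // one ncclComm_t per listed device, via ncclCommInitAll
    int idx = comm.GetCommId(1);           // index of device 1's communicator inside comms_
    ncclComm_t handle = comm.comms_[idx];  // handle to pass to NCCL collectives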

paddle/operators/nccl_op.cc

Lines changed: 206 additions & 0 deletions
@@ -0,0 +1,206 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/nccl/nccl_gpu_common.h"
+
+namespace paddle {
+namespace operators {
+
+// NCCLInitOp
+class NCCLInitOp : public framework::OperatorBase {
+ public:
+  NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
+             const framework::VariableNameMap &outputs,
+             const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    const auto &name = Output("Communicator");
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
+                            "Can not find variable '%s' in the scope.", name);
+    std::vector<int> gpus = Attr<std::vector<int>>("gpus");
+    PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
+
+    if (scope.FindVar(name) == nullptr) {
+      PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
+    }
+
+    platform::Communicator *comm =
+        scope.FindVar(name)->GetMutable<platform::Communicator>();
+    comm->InitAll(gpus);
+  }
+};
+
+class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  NCCLInitOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddOutput("Communicator",
+              "Create Communicator for communicating between gpus");
+    AddAttr<std::vector<int>>("gpus", "gpu id lists");
+    AddAttr<int>("data_type", "output data type")
+        .SetDefault(framework::DataType::FP32);
+    AddComment(R"DOC(
+Create communicator.
+)DOC");
+  }
+};
+
+// AllReduceOp
+class NCCLAllReduceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of AllReduce op should not be NULL");
+    PADDLE_ENFORCE(
+        ctx->HasInput("Communicator"),
+        "Input(Communicator) of AllReduce op should not be NULL");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of AllReduce op should not be NULL");
+
+    auto x_dims = ctx->GetInputsDim("X");
+
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
+
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+// ReduceOp
+class NCCLReduceOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of Reduce op should not be NULL");
+    PADDLE_ENFORCE(
+        ctx->HasInput("Communicator"),
+        "Input(Communicator) of Reduce op should not be NULL");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of Reduce op should not be NULL");
+
+    std::string reduction = ctx->Attrs().Get<std::string>("reduction");
+    PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
+                    reduction == "ncclMin" || reduction == "ncclMax"),
+                   "invalid reduction.");
+
+    auto x_dims = ctx->GetInputsDim("X");
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+// BcastOp
+class NCCLBcastOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of Bcast op should not be NULL");
+    PADDLE_ENFORCE(ctx->HasInput("Communicator"),
+                   "Input(Communicator) of Bcast op should not be NULL");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of Bcast op should not be NULL");
+
+    int root = ctx->Attrs().Get<int>("root");
+    PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
+
+    auto x_dims = ctx->GetInputsDim("X");
+    ctx->SetOutputsDim("Out", x_dims);
+    ctx->ShareLoD("X", /*->*/ "Out");
+  }
+};
+
+// AllReduceOpMaker
+class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  NCCLAllReduceOpMaker(framework::OpProto *proto,
+                       framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of AllReduce op");
+    AddInput("Communicator", "Communicator for communicating between gpus");
+    AddOutput("Out", "The output of AllReduce op");
+    AddAttr<std::string>("reduction",
+                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
+        .SetDefault("ncclSum");
+    AddComment(R"DOC(
+AllReduce the input tensors.
+)DOC");
+  }
+};
+
+// ReduceOpMaker
+class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  NCCLReduceOpMaker(framework::OpProto *proto,
+                    framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of Reduce op");
+    AddInput("Communicator", "Communicator for communicating between gpus");
+    AddOutput("Out", "The output of Reduce op");
+    AddAttr<std::string>("reduction",
+                         "{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
+        .SetDefault("ncclSum");
+    AddAttr<int>("root",
+                 "root gpu of the parameter. if not "
+                 "set (platform::kInvalidGPUId), hashed by name.")
+        .SetDefault(platform::kInvalidGPUId);
+    AddComment(R"DOC(
+Reduce the tensors)DOC");
+  }
+};
+
+// BcastOpMaker
+class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  NCCLBcastOpMaker(framework::OpProto *proto,
+                   framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of Bcast op");
+    AddInput("Communicator", "Communicator for communicating between gpus");
+    AddOutput("Out", "The output of Bcast op");
+    AddAttr<int>("root",
+                 "root gpu of the parameter. if not "
+                 "set (platform::kInvalidGPUId), hashed by name.")
+        .SetDefault(platform::kInvalidGPUId);
+    AddComment(R"DOC(
+Bcast the tensors.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
+                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
+
+REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
+                             ops::NCCLAllReduceOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
+                             ops::NCCLBcastOpMaker);
+REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
+                             ops::NCCLReduceOpMaker);
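
The GPU kernels for these operators are not part of this file (nccl_op is only built when WITH_GPU is on). At their core they would look up the per-device communicator created by ncclInit and issue the collective on that device's stream; a rough, illustrative sketch of the allreduce call, not the committed kernel:

    // comm: the platform::Communicator produced by ncclInit;
    // gpu_id / stream: the current device and its CUDA stream (names are illustrative).
    int idx = comm->GetCommId(gpu_id);
    PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
        send_buffer, recv_buffer, count, ncclFloat, ncclSum,  // dtype/reduction taken from attributes
        comm->comms_[idx], stream));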
