Skip to content

Commit 14fe40a

Browse files
authored
Refine/nccl (#9009)
* "Refine nccl op" * "refine code " * "refine nccl code"
1 parent 788c600 commit 14fe40a

File tree

2 files changed

+89
-142
lines changed

2 files changed

+89
-142
lines changed

paddle/fluid/operators/nccl_op.cc

Lines changed: 45 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -104,19 +104,38 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
104104
" Input(Communicator) of AllReduce op input should not be NULL");
105105
PADDLE_ENFORCE(ctx->HasOutput("Out"),
106106
" Output(Out) of AllReduce op output should not be NULL");
107-
108-
auto x_dims = ctx->GetInputsDim("X");
109-
110107
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
111108
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
112109
reduction == "ncclMin" || reduction == "ncclMax"),
113110
"invalid reduction.");
114111

112+
auto x_dims = ctx->GetInputsDim("X");
115113
ctx->SetOutputsDim("Out", x_dims);
116114
ctx->ShareLoD("X", /*->*/ "Out");
117115
}
118116
};
119117

118+
// AllReduceOp
119+
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
120+
public:
121+
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
122+
: OpProtoAndCheckerMaker(proto, op_checker) {
123+
AddInput("X", "The input of AllReduce op");
124+
AddInput("Communicator", "Communicator for communicating between gpus");
125+
AddOutput("Out", "The output of AllReduce op");
126+
AddAttr<std::string>("reduction",
127+
"(string, default 'ncclSum') "
128+
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
129+
.SetDefault("ncclSum");
130+
AddComment(R"DOC(
131+
NCCLAllReduce Operator.
132+
133+
AllReduce the input tensors.
134+
135+
)DOC");
136+
}
137+
};
138+
120139
// ReduceOp
121140
class NCCLReduceOp : public framework::OperatorWithKernel {
122141
public:
@@ -143,50 +162,6 @@ class NCCLReduceOp : public framework::OperatorWithKernel {
143162
}
144163
};
145164

146-
// BcastOp
147-
class NCCLBcastOp : public framework::OperatorWithKernel {
148-
public:
149-
using framework::OperatorWithKernel::OperatorWithKernel;
150-
151-
protected:
152-
void InferShape(framework::InferShapeContext *ctx) const override {
153-
PADDLE_ENFORCE(ctx->HasInput("X"),
154-
" Input(X) of Bcast op input should not be NULL");
155-
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
156-
" Input(Communicator) of Bcast op input should not be NULL");
157-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
158-
" Output(Out) of Bcast op output should not be NULL");
159-
160-
int root = ctx->Attrs().Get<int>("root");
161-
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
162-
163-
auto x_dims = ctx->GetInputsDim("X");
164-
ctx->SetOutputsDim("Out", x_dims);
165-
ctx->ShareLoD("X", /*->*/ "Out");
166-
}
167-
};
168-
169-
// AllreduceOp
170-
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
171-
public:
172-
NCCLAllReduceOpMaker(OpProto *proto, OpAttrChecker *op_checker)
173-
: OpProtoAndCheckerMaker(proto, op_checker) {
174-
AddInput("X", "The input of AllReduce op");
175-
AddInput("Communicator", "Communicator for communicating between gpus");
176-
AddOutput("Out", "The output of AllReduce op");
177-
AddAttr<std::string>("reduction",
178-
"(string, default 'ncclSum') "
179-
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
180-
.SetDefault("ncclSum");
181-
AddComment(R"DOC(
182-
NCCLAllReduce Operator.
183-
184-
AllReduce the input tensors.
185-
186-
)DOC");
187-
}
188-
};
189-
190165
// ReduceOp
191166
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
192167
public:
@@ -213,6 +188,29 @@ Reduce the tensors.
213188
}
214189
};
215190

191+
// BcastOp
192+
class NCCLBcastOp : public framework::OperatorWithKernel {
193+
public:
194+
using framework::OperatorWithKernel::OperatorWithKernel;
195+
196+
protected:
197+
void InferShape(framework::InferShapeContext *ctx) const override {
198+
PADDLE_ENFORCE(ctx->HasInput("X"),
199+
" Input(X) of Bcast op input should not be NULL");
200+
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
201+
" Input(Communicator) of Bcast op input should not be NULL");
202+
PADDLE_ENFORCE(ctx->HasOutput("Out"),
203+
" Output(Out) of Bcast op output should not be NULL");
204+
205+
int root = ctx->Attrs().Get<int>("root");
206+
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
207+
208+
auto x_dims = ctx->GetInputsDim("X");
209+
ctx->SetOutputsDim("Out", x_dims);
210+
ctx->ShareLoD("X", /*->*/ "Out");
211+
}
212+
};
213+
216214
// BcastOp
217215
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
218216
public:

paddle/fluid/operators/nccl_op.cu.cc

Lines changed: 44 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,12 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
4343
void Compute(const framework::ExecutionContext& ctx) const override {
4444
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
4545
"This kernel only runs on GPU device.");
46-
47-
auto ins = ctx.MultiInput<LoDTensor>("X");
48-
auto outs = ctx.MultiOutput<LoDTensor>("Out");
49-
46+
auto* x = ctx.Input<LoDTensor>("X");
47+
auto* out = ctx.Output<LoDTensor>("Out");
48+
auto* comm = ctx.Input<Communicator>("Communicator");
5049
std::string reduction = ctx.Attr<std::string>("reduction");
51-
ncclRedOp_t reduction_op_ = ncclSum;
5250

51+
ncclRedOp_t reduction_op_ = ncclSum;
5352
if (reduction == "ncclMin") {
5453
reduction_op_ = ncclMin;
5554
} else if (reduction == "ncclMax") {
@@ -61,30 +60,19 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
6160
} else {
6261
PADDLE_THROW("Invalid reduction. default ncclSum.");
6362
}
64-
65-
auto* comm = ctx.Input<Communicator>("Communicator");
66-
67-
auto stream = ctx.cuda_device_context().stream();
68-
6963
// device id
7064
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
7165
int idx = comm->GetCommId(gpu_id);
72-
73-
for (size_t i = 0; i < ins.size(); ++i) {
74-
VLOG(1) << "gpu : "
75-
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
76-
<< outs[i]->numel();
77-
78-
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
79-
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
80-
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
81-
comm->comms().at(idx), stream));
82-
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
83-
84-
VLOG(1) << "gpu : "
85-
<< " finished allreduce. send " << ins[i]->numel() << " recv "
86-
<< outs[i]->numel();
87-
}
66+
VLOG(3) << "gpu : "
67+
<< " invoke allreduce. send " << x->numel() << " recv "
68+
<< out->numel();
69+
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
70+
x->data<T>(), out->mutable_data<T>(ctx.GetPlace()), out->numel(),
71+
NCCLTypeWrapper<T>::type, reduction_op_, comm->comms().at(idx),
72+
ctx.cuda_device_context().stream()));
73+
VLOG(3) << "gpu : "
74+
<< " finished allreduce. send " << x->numel() << " recv "
75+
<< out->numel();
8876
}
8977
};
9078

@@ -94,13 +82,13 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
9482
void Compute(const framework::ExecutionContext& ctx) const override {
9583
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
9684
"This kernel only runs on GPU device.");
97-
98-
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2
99-
auto outs = ctx.MultiOutput<LoDTensor>("Out");
100-
85+
auto x = ctx.Input<LoDTensor>("X"); // x0, x1, x2
86+
auto out = ctx.Output<LoDTensor>("Out");
87+
auto* comm = ctx.Input<Communicator>("Communicator");
88+
int root = ctx.Attr<int>("root");
10189
std::string reduction = ctx.Attr<std::string>("reduction");
102-
ncclRedOp_t reduction_op_ = ncclSum;
10390

91+
ncclRedOp_t reduction_op_ = ncclSum;
10492
if (reduction == "ncclMin") {
10593
reduction_op_ = ncclMin;
10694
} else if (reduction == "ncclMax") {
@@ -112,40 +100,21 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
112100
} else {
113101
PADDLE_THROW("Invalid reduction. default ncclSum.");
114102
}
115-
116-
int root = ctx.Attr<int>("root");
117-
auto* comm = ctx.Input<Communicator>("Communicator");
118-
119-
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
120-
ctx.device_context())
121-
.stream();
122103
// device id
123104
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
124105
int idx = comm->GetCommId(gpu_id);
125-
126-
auto ins_names = ctx.Inputs("X");
127-
std::hash<std::string> hasher;
128-
for (size_t i = 0; i < ins.size(); ++i) {
129-
if (root == platform::kInvalidGPUId) {
130-
root = hasher(ins_names[i]) % comm->comms().size();
131-
}
132-
T* recvbuffer = nullptr;
133-
if (root == gpu_id) {
134-
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
135-
}
136-
137-
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send "
138-
<< ins[i]->numel() << " recv " << outs[i]->numel();
139-
140-
PADDLE_ENFORCE(platform::dynload::ncclReduce(
141-
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
142-
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms().at(idx),
143-
stream));
144-
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
145-
146-
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
147-
<< ins[i]->numel() << " recv " << outs[i]->numel();
106+
T* recvbuffer = nullptr;
107+
if (root == gpu_id) {
108+
recvbuffer = out->mutable_data<T>(ctx.GetPlace());
148109
}
110+
VLOG(3) << "gpu : " << gpu_id << " invoke reduce. send " << x->numel()
111+
<< " recv " << out->numel();
112+
PADDLE_ENFORCE(platform::dynload::ncclReduce(
113+
x->data<T>(), recvbuffer, x->numel(), NCCLTypeWrapper<T>::type,
114+
reduction_op_, root, comm->comms().at(idx),
115+
ctx.cuda_device_context().stream()));
116+
VLOG(3) << "gpu : " << gpu_id << " finished reduce. send " << x->numel()
117+
<< " recv " << out->numel();
149118
}
150119
};
151120

@@ -155,47 +124,27 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
155124
void Compute(const framework::ExecutionContext& ctx) const override {
156125
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
157126
"This kernel only runs on GPU device.");
158-
159127
int root = ctx.Attr<int>("root");
160-
161128
auto* comm = ctx.Input<Communicator>("Communicator");
162-
163-
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
164-
ctx.device_context())
165-
.stream();
166129
// device id
167130
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
168131
int idx = comm->GetCommId(gpu_id);
169-
170132
if (idx == root) {
171-
auto ins = ctx.MultiInput<LoDTensor>("X");
172-
for (size_t i = 0; i < ins.size(); ++i) {
173-
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
174-
<< ins[i]->numel();
175-
176-
VLOG(1) << " before ncclBcast";
177-
PADDLE_ENFORCE(platform::dynload::ncclBcast(
178-
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
179-
root, comm->comms().at(idx), stream));
180-
VLOG(1) << " after ncclBcast";
181-
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
182-
183-
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
184-
}
133+
auto* x = ctx.Input<LoDTensor>("X");
134+
VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. send " << x->numel();
135+
PADDLE_ENFORCE(platform::dynload::ncclBcast(
136+
(void*)x->data<T>(), x->numel(), NCCLTypeWrapper<T>::type, root,
137+
comm->comms().at(idx), ctx.cuda_device_context().stream()));
138+
VLOG(3) << "gpu : " << gpu_id << " finished Bcast.";
185139
} else {
186-
auto outs = ctx.MultiOutput<LoDTensor>("Out");
187-
for (size_t i = 0; i < outs.size(); ++i) {
188-
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
189-
<< framework::product(outs[i]->dims());
190-
191-
PADDLE_ENFORCE(platform::dynload::ncclBcast(
192-
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
193-
NCCLTypeWrapper<T>::type, root, comm->comms().at(idx), stream));
194-
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
195-
196-
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
197-
<< outs[i]->numel();
198-
}
140+
auto* out = ctx.Output<LoDTensor>("Out");
141+
VLOG(3) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
142+
<< framework::product(out->dims());
143+
PADDLE_ENFORCE(platform::dynload::ncclBcast(
144+
out->mutable_data<T>(ctx.GetPlace()), out->numel(),
145+
NCCLTypeWrapper<T>::type, root, comm->comms().at(idx),
146+
ctx.cuda_device_context().stream()));
147+
VLOG(3) << "gpu : " << gpu_id << " finished Bcast. recv " << out->numel();
199148
}
200149
}
201150
};

0 commit comments

Comments
 (0)