Skip to content

Commit 71305e5

Browse files
committed
"polish code based on comment"
1 parent 6f009cf commit 71305e5

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

paddle/framework/operator.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,12 @@ class ExecutionContext {
290290
return device_context_;
291291
}
292292

293-
//! Get variables vector with same input name.
293+
//! Get actual name vector for this input.
294294
const std::vector<std::string>& Inputs(const std::string& name) const {
295295
return op_.Inputs(name);
296296
}
297297

298-
//! Get variables vector with same output name.
298+
//! Get actual name vector for this output.
299299
const std::vector<std::string>& Outputs(const std::string& name) const {
300300
return op_.Outputs(name);
301301
}

paddle/operators/nccl_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class NCCLInitOp : public framework::OperatorBase {
3030
"Can not find variable '%s' in the scope.", name);
3131
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
3232
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
33+
34+
if (scope.FindVar(name) == nullptr) {
35+
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
36+
}
37+
3338
platform::Communicator *comm =
3439
scope.FindVar(name)->GetMutable<platform::Communicator>();
3540
comm->InitAll(gpus);

paddle/operators/nccl_op.cu

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
99
See the License for the specific language governing permissions and
1010
limitations under the License. */
1111

12-
#define EIGEN_USE_GPU
1312
#include <functional>
1413

1514
#include "paddle/framework/lod_tensor.h"
@@ -60,7 +59,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
6059
} else if (reduction == "ncclProd") {
6160
reduction_op_ = ncclProd;
6261
} else {
63-
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
62+
PADDLE_THROW("Invalid reduction. default ncclSum.");
6463
}
6564

6665
auto* comm = ctx.Input<Communicator>("Communicator");
@@ -113,7 +112,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
113112
} else if (reduction == "ncclProd") {
114113
reduction_op_ = ncclProd;
115114
} else {
116-
PADDLE_ENFORCE(false, "Invalid reduction. default ncclSum.");
115+
PADDLE_THROW("Invalid reduction. default ncclSum.");
117116
}
118117

119118
int root = ctx.Attr<int>("root");

paddle/operators/nccl_op_test.cu

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#define EIGEN_USE_GPU
16-
1715
#include <glog/logging.h>
1816
#include <gtest/gtest.h>
1917
#include <algorithm>
@@ -193,15 +191,15 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
193191
}
194192
}
195193

196-
// ncclAReduceOp with desc
194+
// ncclReduceOp with desc
197195
TEST_F(NCCLTester, ncclReduceOp) {
198196
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
199197
const int kRoot = 0;
200198
op2->SetType("ncclReduce");
201199
op2->SetInput("X", {"st"});
202200
op2->SetInput("Communicator", {"comm"});
203201
op2->SetOutput("Out", {"rt"});
204-
op2->SetAttr("root", {kRoot});
202+
op2->SetAttr("root", kRoot);
205203

206204
std::vector<f::Scope *> dev_scopes;
207205

@@ -241,15 +239,15 @@ TEST_F(NCCLTester, ncclReduceOp) {
241239
}
242240
}
243241

244-
// // ncclBcastOp with desc
242+
// ncclBcastOp with desc
245243
TEST_F(NCCLTester, ncclBcastOp) {
246244
std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
247245
const int kRoot = 5;
248246
op2->SetType("ncclBcast");
249247
op2->SetInput("X", {"st"});
250248
op2->SetInput("Communicator", {"comm"});
251249
op2->SetOutput("Out", {"rt"});
252-
op2->SetAttr("root", {kRoot});
250+
op2->SetAttr("root", kRoot);
253251

254252
std::vector<f::Scope *> dev_scopes;
255253

0 commit comments

Comments
 (0)