Skip to content

Commit 8c40ebd

Browse files
authored
Enhance error message of checkpoint_notify_op, fake_init_op, gen_nccl_id_op and listen_and_serv_op (#24554) (#24844)
test=develop
1 parent 343687c commit 8c40ebd

File tree

6 files changed

+89
-24
lines changed

6 files changed

+89
-24
lines changed

paddle/fluid/operators/distributed_ops/checkpoint_notify_op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,10 @@ class CheckpointNotifyOp : public framework::OperatorBase {
4949
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
5050
<< " and dir:" << dir << " to " << epmap[i];
5151
}
52-
PADDLE_ENFORCE(rpc_client->Wait(), "internal error in RPCClient");
52+
PADDLE_ENFORCE_EQ(
53+
rpc_client->Wait(), true,
54+
platform::errors::Fatal("Fail to notify checkpoint."
55+
" Internal error occurs in RPCClient."));
5356
}
5457
};
5558

paddle/fluid/operators/distributed_ops/fake_init_op.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ namespace operators {
1919
class FakeInitInferShape : public framework::InferShapeBase {
2020
public:
2121
void operator()(framework::InferShapeContext *ctx) const override {
22-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
23-
"Output(Out) of FakeInitOp should not be null.");
22+
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "FakeInit");
2423
auto &shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
2524
ctx->SetOutputDim("Out", framework::make_ddim(shape));
2625
}

paddle/fluid/operators/distributed_ops/gen_nccl_id_op.cc

Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,15 @@ class GenNCCLIdOp : public framework::OperatorBase {
4444

4545
std::vector<std::string> trainers =
4646
Attr<std::vector<std::string>>("trainers");
47-
PADDLE_ENFORCE(
48-
trainer_id >= 0 && trainer_id < static_cast<int>(trainers.size()),
49-
"trainer_id:%d must be in trainers.size range", trainer_id);
47+
PADDLE_ENFORCE_GE(trainer_id, 0, platform::errors::InvalidArgument(
48+
"trainer_id %d is less than 0. Its "
49+
"valid range is [0, trainer_size)"));
50+
PADDLE_ENFORCE_LT(
51+
trainer_id, static_cast<int>(trainers.size()),
52+
platform::errors::OutOfRange("trainer_id %d is out of range. Its valid "
53+
"range is [0, trainer_size)",
54+
trainer_id));
55+
5056
std::string endpoint = trainers[trainer_id];
5157

5258
framework::Scope& local_scope = scope.NewScope();
@@ -58,12 +64,20 @@ class GenNCCLIdOp : public framework::OperatorBase {
5864
int inter_trainer_id = -1;
5965
int exter_trainer_id = -1;
6066
if (use_hierarchical_allreduce) {
61-
PADDLE_ENFORCE(trainers.size() > 1, "trainers.size():%llu < 1",
62-
trainers.size());
63-
PADDLE_ENFORCE(inter_nranks > 1, "inter_nranks:%d < 1", inter_nranks);
64-
PADDLE_ENFORCE((trainers.size() % inter_nranks == 0),
65-
"trainers.size():%llu mod inter_nranks:%d != 0",
66-
trainers.size(), inter_nranks);
67+
PADDLE_ENFORCE_GT(
68+
trainers.size(), 1,
69+
platform::errors::PreconditionNotMet(
70+
"The number of collective trainers %llu <= 1", trainers.size()));
71+
PADDLE_ENFORCE_GT(
72+
inter_nranks, 1,
73+
platform::errors::PreconditionNotMet(
74+
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
75+
inter_nranks));
76+
PADDLE_ENFORCE_EQ(
77+
trainers.size() % inter_nranks, 0,
78+
platform::errors::PreconditionNotMet(
79+
"The number of trainers %llu mod inter_nranks %d is not equal 0",
80+
trainers.size(), inter_nranks));
6781

6882
inter_trainer_id = trainer_id % inter_nranks;
6983

@@ -106,10 +120,16 @@ class GenNCCLIdOp : public framework::OperatorBase {
106120
return;
107121
}
108122

109-
PADDLE_ENFORCE(trainers.size() % inter_nranks == 0,
110-
"enpoints.size:%llu mod inter_nranks:%d should ==0",
111-
trainers.size(), inter_nranks);
112-
PADDLE_ENFORCE(inter_nranks > 1, "inter_nranks:%d must > 1", inter_nranks);
123+
PADDLE_ENFORCE_EQ(
124+
trainers.size() % inter_nranks, 0,
125+
platform::errors::PreconditionNotMet(
126+
"The number of trainers %llu mod inter_nranks %d is not equal 0",
127+
trainers.size(), inter_nranks));
128+
PADDLE_ENFORCE_GT(
129+
inter_nranks, 1,
130+
platform::errors::PreconditionNotMet(
131+
"inter_nranks %d <= 1 while in hierarchical allreduce mode",
132+
inter_nranks));
113133

114134
// hierarchical inter ncclid
115135
if (inter_trainer_id == 0) {
@@ -156,10 +176,11 @@ class GenNCCLIdOp : public framework::OperatorBase {
156176
const std::string& nccl_id_name,
157177
const std::vector<std::string>& endpoint_list) const {
158178
auto var = scope->FindVar(nccl_id_name);
159-
PADDLE_ENFORCE_NOT_NULL(var, "can't find nccl_id_var_name:%s",
160-
nccl_id_name);
179+
PADDLE_ENFORCE_NOT_NULL(
180+
var, platform::errors::NotFound("Variable with name %s is not found",
181+
nccl_id_name.c_str()));
161182
auto id = var->GetMutable<ncclUniqueId>();
162-
PADDLE_ENFORCE(platform::dynload::ncclGetUniqueId(id));
183+
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclGetUniqueId(id));
163184

164185
distributed::RPCClient* client =
165186
distributed::RPCClient::GetInstance<RPCCLIENT_T>(0);

paddle/fluid/operators/distributed_ops/listen_and_serv_op.cc

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -315,8 +315,9 @@ void ListenAndServOp::CacheVarsType(const std::vector<std::string> &varnames,
315315
const framework::Scope &scope) const {
316316
for (const auto &varname : varnames) {
317317
auto var = scope.FindVar(varname);
318-
PADDLE_ENFORCE(var != nullptr,
319-
"Received var should be initialized in the received scope.");
318+
PADDLE_ENFORCE_NOT_NULL(
319+
var, platform::errors::PreconditionNotMet(
320+
"Received var is not initialized in the received scope."));
320321
if (var->IsType<framework::SelectedRows>()) {
321322
sparse_vars_.push_back(varname);
322323
} else if (var->IsType<framework::LoDTensor>() ||
@@ -344,7 +345,9 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
344345
auto pserver_id = Attr<int>("pserver_id");
345346
auto inputs = Inputs("X");
346347

347-
PADDLE_ENFORCE(!rpc_service_);
348+
PADDLE_ENFORCE_EQ(rpc_service_, nullptr,
349+
platform::errors::PreconditionNotMet(
350+
"RPC service has been created unexpectedly."));
348351
std::string endpoint = Attr<std::string>("endpoint");
349352
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
350353
int lr_decay_block_id = Attr<int>(kLRDecayBlockId);
@@ -390,8 +393,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
390393

391394
auto optimize_blocks =
392395
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
393-
PADDLE_ENFORCE(optimize_blocks.size() >= 1,
394-
"optimize blocks should be 1 at least on the pserver side.");
396+
PADDLE_ENFORCE_GE(optimize_blocks.size(), 1,
397+
platform::errors::PreconditionNotMet(
398+
"optimize blocks is less than 1. Optimize blocks "
399+
"should be 1 at least on the pserver side."));
395400
auto *program = optimize_blocks[0]->Program();
396401
framework::Executor executor(dev_place);
397402

python/paddle/fluid/tests/unittests/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ if(WIN32)
4949
LIST(REMOVE_ITEM TEST_OPS test_trainer_desc)
5050
LIST(REMOVE_ITEM TEST_OPS test_multiprocess_reader_exception)
5151
LIST(REMOVE_ITEM TEST_OPS test_avoid_twice_initialization)
52+
LIST(REMOVE_ITEM TEST_OPS test_checkpoint_notify_op)
5253
endif()
5354

5455
if (NOT ${WITH_GPU})
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import print_function
16+
17+
import unittest
18+
import paddle.fluid as fluid
19+
20+
21+
class TestCheckpointNotifyOp(unittest.TestCase):
22+
def test_checkpoint_notify_op(self):
23+
program = fluid.Program()
24+
attrs = {}
25+
attrs['epmap'] = []
26+
attrs['dir'] = ''
27+
attrs['lookup_table'] = ''
28+
program.current_block().append_op(
29+
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
30+
31+
exe = fluid.Executor(fluid.CPUPlace())
32+
exe.run(program)
33+
34+
35+
if __name__ == '__main__':
36+
unittest.main()

0 commit comments

Comments (0)