Skip to content

Commit f0c6101

Browse files
authored
fix op error, test=develop (#24451) (#24539)
1 parent 27dee22 commit f0c6101

File tree

5 files changed

+30
-12
lines changed

5 files changed

+30
-12
lines changed

paddle/fluid/operators/distributed_ops/recv_op.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,9 @@ class RecvOp : public framework::OperatorBase {
8484
}
8585
for (size_t i = 0; i < rets.size(); i++) {
8686
VLOG(7) << "before sync_recv " << outs[i] << "from " << epmap[i];
87-
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
87+
PADDLE_ENFORCE_NE(
88+
rets[i]->Wait(), 0U,
89+
platform::errors::ExecutionTimeout("internal error in RPCClient"));
8890
VLOG(7) << "after sync_recv " << outs[i] << "from " << epmap[i];
8991
}
9092
}

paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,14 +27,23 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
2727
: OperatorWithKernel(type, inputs, outputs, attrs) {}
2828

2929
void InferShape(framework::InferShapeContext *ctx) const override {
30-
PADDLE_ENFORCE(ctx->HasInputs("X"),
31-
"Input(X) of RefByTrainerIdOp should not be null.");
32-
PADDLE_ENFORCE(ctx->HasInput("TrainerId"),
33-
"Input(TrainerId) of RefByTrainerIdOp should not be null.");
34-
PADDLE_ENFORCE(ctx->HasOutput("Out"),
35-
"Output(Out) of RefByTrainerIdOp should not be null.");
36-
PADDLE_ENFORCE_EQ(ctx->GetInputDim("TrainerId").size(), 1,
37-
"TrainerId should be a scalar.");
30+
PADDLE_ENFORCE_EQ(ctx->HasInputs("X"), true,
31+
platform::errors::InvalidArgument(
32+
"Input(X) of RefByTrainerIdOp should not be null."));
33+
34+
PADDLE_ENFORCE_EQ(
35+
ctx->HasInput("TrainerId"), true,
36+
platform::errors::InvalidArgument(
37+
"Input(TrainerId) of RefByTrainerIdOp should not be null."));
38+
39+
PADDLE_ENFORCE_EQ(
40+
ctx->HasOutput("Out"), true,
41+
platform::errors::InvalidArgument(
42+
"Output(Out) of RefByTrainerIdOp should not be null."));
43+
44+
PADDLE_ENFORCE_EQ(
45+
ctx->GetInputDim("TrainerId").size(), 1,
46+
platform::errors::InvalidArgument("TrainerId should be a scalar."));
3847
// Out's shape is determined at runtime.
3948
}
4049

paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.h

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,10 @@ class RefByTrainerIdKernel : public framework::OpKernel<T> {
3838
} else {
3939
trainer_id = *trainer_id_data;
4040
}
41-
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size());
41+
PADDLE_ENFORCE_LT((size_t)trainer_id, in_list.size(),
42+
platform::errors::InvalidArgument(
43+
"X' size must >= TrainerId: [%s], but received [%s]",
44+
trainer_id, in_list.size()));
4245
out->mutable_data<T>(context.GetPlace());
4346
framework::TensorCopy(*(in_list[trainer_id]), in_list[trainer_id]->place(),
4447
out);

paddle/fluid/operators/distributed_ops/send_barrier_op.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,9 @@ class SendBarrierOp : public framework::OperatorBase {
5959
}
6060

6161
for (size_t i = 0; i < rets.size(); i++) {
62-
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
62+
PADDLE_ENFORCE_NE(
63+
rets[i]->Wait(), 0U,
64+
platform::errors::ExecutionTimeout("internal error in RPCClient"));
6365
}
6466
}
6567
};

paddle/fluid/operators/distributed_ops/send_op.cc

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -83,7 +83,9 @@ class SendOp : public framework::OperatorBase {
8383
}
8484
for (size_t i = 0; i < rets.size(); i++) {
8585
VLOG(7) << "before sync_send " << ins[i] << "from " << epmap[i];
86-
PADDLE_ENFORCE_NE(rets[i]->Wait(), 0U, "internal error in RPCClient");
86+
PADDLE_ENFORCE_NE(
87+
rets[i]->Wait(), 0U,
88+
platform::errors::ExecutionTimeout("internal error in RPCClient"));
8789
VLOG(7) << "after sync_send " << ins[i] << "from " << epmap[i];
8890
}
8991
}

0 commit comments

Comments
 (0)