Skip to content

Commit 7d0e903

Browse files
authored
update error message for unstack op and lamb op; test=develop (#24487)
1 parent 6885d15 commit 7d0e903

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

paddle/fluid/operators/optimizers/lamb_op.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,12 @@ class LambOpKernel : public framework::OpKernel<T> {
177177
public:
178178
void Compute(const framework::ExecutionContext& ctx) const override {
179179
const auto* param_var = ctx.InputVar("Param");
180-
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
181-
"The Var(%s)'s type should be LoDTensor, "
182-
"but the received is %s",
183-
ctx.InputNames("Param").front(),
184-
framework::ToTypeName(param_var->Type()));
180+
PADDLE_ENFORCE_EQ(param_var->IsType<framework::LoDTensor>(), true,
181+
platform::errors::InvalidArgument(
182+
"The Var(%s)'s type should be LoDTensor, "
183+
"but the received is %s",
184+
ctx.InputNames("Param").front(),
185+
framework::ToTypeName(param_var->Type())));
185186

186187
using paddle::framework::LoDTensor;
187188

@@ -274,7 +275,10 @@ class LambOpKernel : public framework::OpKernel<T> {
274275
row_numel, grad_merge.rows().size());
275276
for_range(moment_update_functor);
276277
} else {
277-
PADDLE_THROW("Variable type not supported by lamb_op.");
278+
PADDLE_THROW(platform::errors::InvalidArgument(
279+
"Variable type not supported by lamb_op. Expect LoDTensor or "
280+
"SelectedRows, but got %s",
281+
framework::ToTypeName(param_var->Type())));
278282
}
279283

280284
// Update parameter

paddle/fluid/operators/unstack_op.cc

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,35 @@ class UnStackOp : public framework::OperatorWithKernel {
2727
using framework::OperatorWithKernel::OperatorWithKernel;
2828

2929
void InferShape(framework::InferShapeContext *ctx) const override {
30-
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
31-
30+
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UnStack");
3231
int axis = ctx->Attrs().Get<int>("axis");
3332
int num = ctx->Attrs().Get<int>("num");
3433
auto x_dim = ctx->GetInputDim("X");
3534
int rank = x_dim.size();
36-
PADDLE_ENFORCE_GE(
37-
axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
38-
rank);
39-
PADDLE_ENFORCE_LT(
40-
axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
41-
rank);
35+
PADDLE_ENFORCE_GE(axis, -rank,
36+
platform::errors::InvalidArgument(
37+
"The attribute axis is out of range, it must be "
38+
"inside [-rank, rank), where rank = %d",
39+
rank));
40+
PADDLE_ENFORCE_LT(axis, rank,
41+
platform::errors::InvalidArgument(
42+
"The attribute axis is out of range, it must be "
43+
"inside [-rank, rank), where rank = %d",
44+
rank));
4245
if (axis < 0) axis += rank;
4346

4447
PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
45-
"Number of Outputs(Y) is wrong");
48+
platform::errors::InvalidArgument(
49+
"Number of Outputs(Y) is wrong. Got %d , but it must "
50+
"equal to attribute num which is %d.",
51+
ctx->Outputs("Y").size(), static_cast<size_t>(num)));
4652
if (x_dim[axis] > 0) {
47-
PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
53+
PADDLE_ENFORCE_EQ(
54+
num, x_dim[axis],
55+
platform::errors::InvalidArgument(
56+
"The number of attribute num is not equal to the length of the "
57+
"%d axis of Input(X). Expect %d but got %d.",
58+
axis, x_dim[axis], num));
4859
}
4960
auto vec = framework::vectorize<int>(x_dim);
5061
vec.erase(vec.begin() + axis);
@@ -89,24 +100,29 @@ class UnStackGradOp : public framework::OperatorWithKernel {
89100

90101
void InferShape(framework::InferShapeContext *ctx) const override {
91102
PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
92-
"Number of Inputs(Y@Grad) must be larger than 0");
93-
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
94-
"Output(X@Grad) must exist.");
95-
103+
platform::errors::InvalidArgument(
104+
"Number of Inputs(Y@Grad) must be larger than 0"));
105+
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output", "X",
106+
"UnStackGrad");
96107
auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
97108
for (size_t i = 1; i < input_dims.size(); ++i) {
98109
PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
99-
"Dims of all Inputs(Y@Grad) must be the same");
110+
platform::errors::InvalidArgument(
111+
"Dims of all Inputs(Y@Grad) must be the same"));
100112
}
101113

102114
int axis = ctx->Attrs().Get<int>("axis");
103115
int rank = input_dims[0].size();
104-
PADDLE_ENFORCE_GE(
105-
axis, -(rank + 1),
106-
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
107-
PADDLE_ENFORCE_LT(
108-
axis, rank + 1,
109-
"Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
116+
PADDLE_ENFORCE_GE(axis, -(rank + 1),
117+
platform::errors::InvalidArgument(
118+
"The attribute axis is out of range, it must be "
119+
"inside [-(rank+1), rank+1), where rank = %d",
120+
rank));
121+
PADDLE_ENFORCE_LT(axis, rank + 1,
122+
platform::errors::InvalidArgument(
123+
"The attribute axis is out of range, it must be "
124+
"inside [-(rank+1), rank+1), where rank = %d",
125+
rank));
110126
if (axis < 0) axis += (rank + 1);
111127

112128
auto vec = framework::vectorize<int>(input_dims[0]);

0 commit comments

Comments
 (0)