
Commit 417b576

API(dynamic_lstm, dynamic_lstmp) error message enhancement (#24450)
* update err msg for dynamic_lstm and dynamic_lstmp, test=develop
1 parent 53bdee6 commit 417b576

File tree

5 files changed: 269 additions & 123 deletions


paddle/fluid/operators/lstm_op.cc

Lines changed: 67 additions & 56 deletions
```diff
@@ -24,64 +24,80 @@ class LSTMOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTM should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
-                   "Output(Hidden) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchGate) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTM");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                   "BatchCellPreAct", "LSTM");
 
     auto in_dims = ctx->GetInputDim("Input");
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Input(X)'s rank must be 2, but received %d.", in_dims.size()));
 
     if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(Cell) and Input(Hidden) of LSTM should not "
-                     "be null at the same time.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("C0"), true,
+          platform::errors::NotFound("Input(Cell) and Input(Hidden) of LSTM "
+                                     "should not be null at the same time."));
       auto h_dims = ctx->GetInputDim("H0");
       auto c_dims = ctx->GetInputDim("C0");
-      PADDLE_ENFORCE(h_dims == c_dims,
-                     "The dimension of Input(H0) and Input(C0) "
-                     "should be the same.");
+      PADDLE_ENFORCE_EQ(h_dims, c_dims,
+                        platform::errors::InvalidArgument(
+                            "The dimension of Input(H0) and Input(C0) should "
+                            "be the same, but received [%s] (H0) vs [%s] (C0).",
+                            h_dims, c_dims));
     }
 
     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                      "The rank of Input(Weight) should be 2.");
+    PADDLE_ENFORCE_EQ(
+        w_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Weight) should be 2, but received %d.",
+            w_dims.size()));
     PADDLE_ENFORCE_EQ(w_dims[0], frame_size,
-                      "The first dimension of Input(Weight) "
-                      "should be %d.",
-                      frame_size);
+                      platform::errors::InvalidArgument(
+                          "The first dimension of Input(Weight) should be %d, "
+                          "but received %d.",
+                          frame_size, w_dims[0]));
     PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Weight) "
-                      "should be 4 * %d.",
-                      frame_size);
+                      platform::errors::InvalidArgument(
+                          "The second dimension of Input(Weight) should be 4 * "
+                          "%d, but received %d.",
+                          frame_size, w_dims[1]));
 
     auto b_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(
+        b_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Bias) should be 2, but received %d.",
+            b_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        b_dims[0], 1,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Bias) should be 1, but received %d.",
+            b_dims[0]));
 
     if (ctx->Attrs().Get<bool>("use_peepholes")) {
-      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "7 * %d if enable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 7 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 7 * %d if enable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "4 * %d if disable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 4 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 4 * %d if disable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     }
 
     framework::DDim out_dims({in_dims[0], frame_size});
@@ -229,21 +245,16 @@ class LSTMGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
-                   "Input(Hidden) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTM should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTM should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTM should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Hidden"), "Input", "Hidden", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTM@Grad");
+
+    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
+                   "LSTM@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
+                   "LSTM@Grad");
 
     auto SetOutGradDim = [&ctx](const std::string& name) {
      auto g_name = framework::GradVarName(name);
```
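
The rewrite above does two things: OP_INOUT_CHECK replaces the hand-written null checks with a uniform not-found message naming the variable, its role (Input/Output), and the operator, which also removes the copy-paste slips in the old text (both BatchCellPreAct checks reported "BatchGate"); and every PADDLE_ENFORCE_EQ now carries a typed platform::errors payload that reports the received value next to the expected one. Below is a minimal repro sketch (not part of this commit) of how one of these InferShape checks surfaces at the Python layer; it assumes a fluid 1.8-era build, the variable names are illustrative, and the message text is quoted from the diff rather than verified output:

```python
import paddle.fluid as fluid

# Input to dynamic_lstm must have second dimension 4 * hidden_size;
# here hidden_size = 16, so size = 64.
x = fluid.data(name='x', shape=[-1, 64], dtype='float32', lod_level=1)
h0 = fluid.data(name='h0', shape=[-1, 16], dtype='float32')

# Passing h_0 without c_0 should now fail at op-creation (InferShape)
# time with the classified NotFound error added above, roughly:
#   Input(Cell) and Input(Hidden) of LSTM should not be null at the
#   same time.
hidden, cell = fluid.layers.dynamic_lstm(input=x, size=64, h_0=h0)
```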

paddle/fluid/operators/lstmp_op.cc

Lines changed: 81 additions & 66 deletions
```diff
@@ -24,74 +24,92 @@ class LSTMPOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Input"),
-                   "Input(Input) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasOutput("Projection"),
-                   "Output(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Cell"),
-                   "Output(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
-                   "Output(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
-                   "Output(BatchCellPreAct) of LSTMP operator should not be "
-                   "null.");
-    PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
-                   "Output(BatchHidden) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Projection"), "Output", "Projection",
+                   "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("Cell"), "Output", "Cell", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchGate"), "Output", "BatchGate", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchCellPreAct"), "Output",
+                   "BatchCellPreAct", "LSTMP");
+    OP_INOUT_CHECK(ctx->HasOutput("BatchHidden"), "Output", "BatchHidden",
+                   "LSTMP");
 
     auto in_dims = ctx->GetInputDim("Input");
 
-    PADDLE_ENFORCE_EQ(in_dims.size(), 2,
-                      "Input(X)'s rank of LSTMP operator must be 2.");
+    PADDLE_ENFORCE_EQ(
+        in_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "Input(X)'s rank of LSTMP operator must be 2, but received %d.",
+            in_dims.size()));
 
     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
     auto proj_dims = ctx->GetInputDim("ProjWeight");
-    PADDLE_ENFORCE_EQ(w_dims.size(), 2,
-                      "The rank of Input(Weight) should be 2.");
-    PADDLE_ENFORCE_EQ(w_dims[0], proj_dims[1],
-                      "The first dimension of Input(Weight) "
-                      "should be %d.",
-                      proj_dims[1]);
+    PADDLE_ENFORCE_EQ(
+        w_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Weight) should be 2, but received %d.",
+            w_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        w_dims[0], proj_dims[1],
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Weight) and the second dimension of "
+            "Input(ProjWeight) should be the same, but received %d vs %d.",
+            w_dims[0], proj_dims[1]));
     PADDLE_ENFORCE_EQ(w_dims[1], 4 * frame_size,
-                      "The second dimension of Input(Weight) "
-                      "should be 4 * %d.",
-                      frame_size);
-
-    PADDLE_ENFORCE_EQ(proj_dims.size(), 2,
-                      "The rank of Input(ProjWeight) should be 2.");
+                      platform::errors::InvalidArgument(
+                          "The second dimension of Input(Weight) should be 4 * "
+                          "%d, but received %d.",
+                          frame_size, w_dims[1]));
+
+    PADDLE_ENFORCE_EQ(
+        proj_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(ProjWeight) should be 2, but received %d.",
+            proj_dims.size()));
     PADDLE_ENFORCE_EQ(proj_dims[0], frame_size,
-                      "The first dimension of Input(ProjWeight) "
-                      "should be %d.",
-                      frame_size);
+                      platform::errors::InvalidArgument(
+                          "The first dimension of Input(ProjWeight) should be "
+                          "%d, but received %d.",
+                          frame_size, proj_dims[0]));
 
     if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(C0) of LSTMP operator should not be null after "
-                     "Input(H0) provided.");
+      PADDLE_ENFORCE_EQ(
+          ctx->HasInput("C0"), true,
+          platform::errors::NotFound("Input(C0) of LSTMP operator should not "
+                                     "be null after Input(H0) provided."));
     }
 
     auto b_dims = ctx->GetInputDim("Bias");
-    PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
-    PADDLE_ENFORCE_EQ(b_dims[0], 1,
-                      "The first dimension of Input(Bias) should be 1.");
+    PADDLE_ENFORCE_EQ(
+        b_dims.size(), 2,
+        platform::errors::InvalidArgument(
+            "The rank of Input(Bias) should be 2, but received %d.",
+            b_dims.size()));
+    PADDLE_ENFORCE_EQ(
+        b_dims[0], 1,
+        platform::errors::InvalidArgument(
+            "The first dimension of Input(Bias) should be 1, but received %d.",
+            b_dims[0]));
 
     if (ctx->Attrs().Get<bool>("use_peepholes")) {
-      PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "7 * %d if enable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 7 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 7 * %d if enable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     } else {
-      PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
-                        "The second dimension of Input(Bias) should be "
-                        "4 * %d if disable peepholes connection",
-                        frame_size);
+      PADDLE_ENFORCE_EQ(
+          b_dims[1], 4 * frame_size,
+          platform::errors::InvalidArgument(
+              "The second dimension of Input(Bias) should be 4 * %d if disable "
+              "peepholes connection, but received %d.",
+              frame_size, b_dims[1]));
     }
 
     framework::DDim out_dims({in_dims[0], frame_size});
@@ -314,21 +332,18 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("Projection"),
-                   "Input(Projection) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Cell"),
-                   "Input(Cell) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Weight"),
-                   "Input(Weight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
-                   "Input(ProjWeight) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Bias"),
-                   "Input(Bias) of LSTMP operator should not be null.");
-
-    PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
-                   "Input(BatchGate) of LSTMP operator should not be null.");
+    OP_INOUT_CHECK(ctx->HasInput("Projection"), "Input", "Projection",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Cell"), "Input", "Cell", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("ProjWeight"), "Input", "ProjWeight",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("Bias"), "Input", "Bias", "LSTMP@Grad");
+
+    OP_INOUT_CHECK(ctx->HasInput("BatchGate"), "Input", "BatchGate",
+                   "LSTMP@Grad");
+    OP_INOUT_CHECK(ctx->HasInput("BatchCellPreAct"), "Input", "BatchCellPreAct",
+                   "LSTMP@Grad");
 
     auto SetOutGradDim = [&ctx](const std::string& name) {
      auto g_name = framework::GradVarName(name);
```
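
Read together, the LSTMP checks pin down the operator's shape contract: with frame_size H = in_dims[1] / 4 and projection width P, ProjWeight must be [H, P] (proj_dims[0] == frame_size) and Weight must be [P, 4 * H] (w_dims[0] == proj_dims[1] and w_dims[1] == 4 * frame_size), so any mismatch now reports both sides of the failed comparison. A small illustrative sketch, again assuming the fluid 1.8-era Python API, in which the layer creates Weight and ProjWeight internally from size and proj_size:

```python
import paddle.fluid as fluid

H, P = 16, 8  # hidden width and projection width

# size must be 4 * H; dynamic_lstmp then creates
#   Weight     with shape [P, 4 * H]
#   ProjWeight with shape [H, P]
# which is exactly what the InferShape checks above verify.
x = fluid.data(name='x', shape=[-1, 4 * H], dtype='float32', lod_level=1)
proj, cell = fluid.layers.dynamic_lstmp(input=x, size=4 * H, proj_size=P)
```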

python/paddle/fluid/layers/rnn.py

Lines changed: 29 additions & 1 deletion
```diff
@@ -2073,7 +2073,21 @@ def dynamic_lstm(input,
     """
     assert in_dygraph_mode(
     ) is not True, "please use lstm instead of dynamic_lstm in dygraph mode!"
-    assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
+    assert bias_attr is not False, "bias_attr should not be False in dynamic_lstm."
+
+    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                             'dynamic_lstm')
+
+    check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstm')
+    if isinstance(h_0, Variable):
+        check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
+                                 'dynamic_lstm')
+
+    check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstm')
+    if isinstance(c_0, Variable):
+        check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
+                                 'dynamic_lstm')
+
     helper = LayerHelper('lstm', **locals())
     size = size // 4
     weight = helper.create_parameter(
@@ -2439,6 +2453,20 @@ def dynamic_lstmp(input,
     ) is not True, "please use lstm instead of dynamic_lstmp in dygraph mode!"
 
     assert bias_attr is not False, "bias_attr should not be False in dynamic_lstmp."
+
+    check_variable_and_dtype(input, 'input', ['float32', 'float64'],
+                             'dynamic_lstmp')
+
+    check_type(h_0, 'h_0', (Variable, type(None)), 'dynamic_lstmp')
+    if isinstance(h_0, Variable):
+        check_variable_and_dtype(h_0, 'h_0', ['float32', 'float64'],
+                                 'dynamic_lstmp')
+
+    check_type(c_0, 'c_0', (Variable, type(None)), 'dynamic_lstmp')
+    if isinstance(c_0, Variable):
+        check_variable_and_dtype(c_0, 'c_0', ['float32', 'float64'],
+                                 'dynamic_lstmp')
+
     helper = LayerHelper('lstmp', **locals())
     size = size // 4
     weight = helper.create_parameter(
```
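
The Python-side additions mirror the C++ hardening (and fix a typo: the dynamic_lstm assert previously said "dynamic_lstmp"): check_type and check_variable_and_dtype reject a wrong argument type or dtype at layer-construction time, naming the argument and the API. A sketch of the effect, assuming these helpers raise TypeError as they do elsewhere in fluid; the exact message text is not verified:

```python
import paddle.fluid as fluid

# An int64 input is now rejected up front by the new
# check_variable_and_dtype(input, 'input', ['float32', 'float64'],
# 'dynamic_lstm') call, instead of failing later inside the C++ op.
x_bad = fluid.data(name='x_bad', shape=[-1, 64], dtype='int64', lod_level=1)
try:
    fluid.layers.dynamic_lstm(input=x_bad, size=64)
except TypeError as e:
    print(e)  # expected to name 'input', the allowed dtypes, and dynamic_lstm
```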
