Skip to content

Commit dcd8e30

Browse files
author
Yibing Liu
authored
Remove redundant infershape in linear chain crf grad, test=release/1.6 (#20629) (#20634)
1 parent 8493097 commit dcd8e30

File tree

1 file changed

+1
-47
lines changed

1 file changed

+1
-47
lines changed

paddle/fluid/operators/linear_chain_crf_op.cc

Lines changed: 1 addition & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -242,60 +242,14 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
242242
"Input(LogLikelihood@GRAD) shoudl be not null.");
243243

244244
auto transition_exps_dims = ctx->GetInputDim("TransitionExps");
245-
PADDLE_ENFORCE_EQ(transition_exps_dims.size(), 2,
246-
"The Input(TransitionExps) should be a 2-D tensor.");
247-
bool check = true;
248-
if ((!ctx->IsRuntime()) &&
249-
(transition_exps_dims[0] <= 0 || transition_exps_dims[1] <= 0)) {
250-
check = false;
251-
}
252-
if (check) {
253-
PADDLE_ENFORCE_EQ(
254-
transition_exps_dims[0] - 2, transition_exps_dims[1],
255-
"An invalid dimension for the Input(TransitionExps), which should "
256-
"be a 2-D tensor with shape [(D + 2) x D].");
257-
}
258-
259245
auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
260-
auto label_dims = ctx->GetInputDim("Label");
261-
if (ctx->HasInput("Length")) {
262-
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 3,
263-
"The Input(EmissionExps) should be a 3-D tensor.");
264-
PADDLE_INFERSHAPE_ENFORCE_EQ(
265-
ctx, emission_exps_dims[2], transition_exps_dims[1],
266-
"The 3nd dimension of the Input(EmissionExps) and the "
267-
"Input(TransitionExps) should be equal to the tag number.");
268-
PADDLE_ENFORCE_EQ(label_dims.size(), 3,
269-
"The Input(Label) should be a 3-D tensor with the 3nd "
270-
"dimensions fixed to 1.");
271-
} else {
272-
PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2,
273-
"The Input(EmissionExps) should be a 2-D tensor.");
274-
PADDLE_INFERSHAPE_ENFORCE_EQ(
275-
ctx, emission_exps_dims[1], transition_exps_dims[1],
276-
"The 2nd dimension of the Input(EmissionExps) and the "
277-
"Input(TransitionExps) should be equal to the tag number.");
278-
PADDLE_ENFORCE_EQ(label_dims.size(), 2,
279-
"The Input(Label) should be a 2-D tensor");
280-
PADDLE_ENFORCE_EQ(label_dims[1], 1,
281-
"The Input(Label) 2nd dimensions fixed to 1.");
282-
}
283-
PADDLE_ENFORCE_NE(emission_exps_dims[0], 0,
284-
"An empty mini-batch is not allowed.");
285-
286-
PADDLE_INFERSHAPE_ENFORCE_EQ(
287-
ctx, emission_exps_dims[0], label_dims[0],
288-
"The height of Input(EmissionExps) and the height of Input(Label) "
289-
"should be the same.");
290-
291246
if (ctx->HasOutput(framework::GradVarName("Emission"))) {
292247
ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
293248
if (ctx->HasInput("Length") == false) {
294249
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
295250
}
296251
}
297-
// ctx->SetOutputDim(framework::GradVarName("Emission"),
298-
// emission_exps_dims);
252+
299253
if (ctx->HasOutput(framework::GradVarName("Transition"))) {
300254
ctx->SetOutputDim(framework::GradVarName("Transition"),
301255
transition_exps_dims);

0 commit comments

Comments
 (0)