@@ -242,60 +242,14 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
242
242
" Input(LogLikelihood@GRAD) shoudl be not null." );
243
243
244
244
auto transition_exps_dims = ctx->GetInputDim (" TransitionExps" );
245
- PADDLE_ENFORCE_EQ (transition_exps_dims.size (), 2 ,
246
- " The Input(TransitionExps) should be a 2-D tensor." );
247
- bool check = true ;
248
- if ((!ctx->IsRuntime ()) &&
249
- (transition_exps_dims[0 ] <= 0 || transition_exps_dims[1 ] <= 0 )) {
250
- check = false ;
251
- }
252
- if (check) {
253
- PADDLE_ENFORCE_EQ (
254
- transition_exps_dims[0 ] - 2 , transition_exps_dims[1 ],
255
- " An invalid dimension for the Input(TransitionExps), which should "
256
- " be a 2-D tensor with shape [(D + 2) x D]." );
257
- }
258
-
259
245
auto emission_exps_dims = ctx->GetInputDim (" EmissionExps" );
260
- auto label_dims = ctx->GetInputDim (" Label" );
261
- if (ctx->HasInput (" Length" )) {
262
- PADDLE_ENFORCE_EQ (emission_exps_dims.size (), 3 ,
263
- " The Input(EmissionExps) should be a 3-D tensor." );
264
- PADDLE_INFERSHAPE_ENFORCE_EQ (
265
- ctx, emission_exps_dims[2 ], transition_exps_dims[1 ],
266
- " The 3nd dimension of the Input(EmissionExps) and the "
267
- " Input(TransitionExps) should be equal to the tag number." );
268
- PADDLE_ENFORCE_EQ (label_dims.size (), 3 ,
269
- " The Input(Label) should be a 3-D tensor with the 3nd "
270
- " dimensions fixed to 1." );
271
- } else {
272
- PADDLE_ENFORCE_EQ (emission_exps_dims.size (), 2 ,
273
- " The Input(EmissionExps) should be a 2-D tensor." );
274
- PADDLE_INFERSHAPE_ENFORCE_EQ (
275
- ctx, emission_exps_dims[1 ], transition_exps_dims[1 ],
276
- " The 2nd dimension of the Input(EmissionExps) and the "
277
- " Input(TransitionExps) should be equal to the tag number." );
278
- PADDLE_ENFORCE_EQ (label_dims.size (), 2 ,
279
- " The Input(Label) should be a 2-D tensor" );
280
- PADDLE_ENFORCE_EQ (label_dims[1 ], 1 ,
281
- " The Input(Label) 2nd dimensions fixed to 1." );
282
- }
283
- PADDLE_ENFORCE_NE (emission_exps_dims[0 ], 0 ,
284
- " An empty mini-batch is not allowed." );
285
-
286
- PADDLE_INFERSHAPE_ENFORCE_EQ (
287
- ctx, emission_exps_dims[0 ], label_dims[0 ],
288
- " The height of Input(EmissionExps) and the height of Input(Label) "
289
- " should be the same." );
290
-
291
246
if (ctx->HasOutput (framework::GradVarName (" Emission" ))) {
292
247
ctx->SetOutputDim (framework::GradVarName (" Emission" ), emission_exps_dims);
293
248
if (ctx->HasInput (" Length" ) == false ) {
294
249
ctx->ShareLoD (" Emission" , framework::GradVarName (" Emission" ));
295
250
}
296
251
}
297
- // ctx->SetOutputDim(framework::GradVarName("Emission"),
298
- // emission_exps_dims);
252
+
299
253
if (ctx->HasOutput (framework::GradVarName (" Transition" ))) {
300
254
ctx->SetOutputDim (framework::GradVarName (" Transition" ),
301
255
transition_exps_dims);
0 commit comments