@@ -48,48 +48,33 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
48
48
using framework::OperatorWithKernel::OperatorWithKernel;
49
49
50
50
void InferShape (framework::InferShapeContext* ctx) const override {
51
- PADDLE_ENFORCE (
52
- ctx->HasInput (" param" ),
53
- " Input (param) of average_accumulates op should not be null." );
54
- PADDLE_ENFORCE (
55
- ctx->HasInput (" in_sum_1" ),
56
- " Input (sum_1) of average_accumulates op should not be null." );
57
- PADDLE_ENFORCE (
58
- ctx->HasInput (" in_sum_2" ),
59
- " Input (sum_2) of average_accumulates op should not be null." );
60
- PADDLE_ENFORCE (
61
- ctx->HasInput (" in_sum_3" ),
62
- " Input (sum_3) of average_accumulates op should not be null." );
63
- PADDLE_ENFORCE (
64
- ctx->HasInput (" in_num_accumulates" ),
65
- " Input (in_num_accumulates) of average_accumulates op should "
66
- " not be null." );
67
- PADDLE_ENFORCE (ctx->HasInput (" in_old_num_accumulates" ),
68
- " Input (old_num_accumulates) of average_accumulates op "
69
- " should not be null." );
70
- PADDLE_ENFORCE (
71
- ctx->HasInput (" in_num_updates" ),
72
- " Input (num_updates) of average_accumulates op should not be null." );
73
-
74
- PADDLE_ENFORCE (
75
- ctx->HasOutput (" out_sum_1" ),
76
- " Output (sum_1) of average_accumulates op should not be null." );
77
- PADDLE_ENFORCE (
78
- ctx->HasOutput (" out_sum_2" ),
79
- " Output (sum_2) of average_accumulates op should not be null." );
80
- PADDLE_ENFORCE (
81
- ctx->HasOutput (" out_sum_3" ),
82
- " Output (sum_3) of average_accumulates op should not be null." );
83
- PADDLE_ENFORCE (ctx->HasOutput (" out_num_accumulates" ),
84
- " Output (num_accumulates) of average_accumulates op should "
85
- " not be null." );
86
- PADDLE_ENFORCE (ctx->HasOutput (" out_old_num_accumulates" ),
87
- " Output (old_num_accumulates) of average_accumulates op "
88
- " should not be null." );
89
- PADDLE_ENFORCE (
90
- ctx->HasOutput (" out_num_updates" ),
91
- " Output (num_updates) of average_accumulates op should not be null." );
92
-
51
+ OP_INOUT_CHECK (ctx->HasInput (" param" ), " Input" , " param" ,
52
+ " AverageAccumulates" );
53
+ OP_INOUT_CHECK (ctx->HasInput (" in_sum_1" ), " Input" , " in_sum_1" ,
54
+ " AverageAccumulates" );
55
+ OP_INOUT_CHECK (ctx->HasInput (" in_sum_2" ), " Input" , " in_sum_2" ,
56
+ " AverageAccumulates" );
57
+ OP_INOUT_CHECK (ctx->HasInput (" in_sum_3" ), " Input" , " in_sum_3" ,
58
+ " AverageAccumulates" );
59
+ OP_INOUT_CHECK (ctx->HasInput (" in_num_accumulates" ), " Input" ,
60
+ " in_num_accumulates" , " AverageAccumulates" );
61
+ OP_INOUT_CHECK (ctx->HasInput (" in_old_num_accumulates" ), " Input" ,
62
+ " in_old_num_accumulates" , " AverageAccumulates" );
63
+ OP_INOUT_CHECK (ctx->HasInput (" in_num_updates" ), " Input" , " in_num_updates" ,
64
+ " AverageAccumulates" );
65
+
66
+ OP_INOUT_CHECK (ctx->HasOutput (" out_sum_1" ), " Output" , " out_sum_1" ,
67
+ " AverageAccumulates" );
68
+ OP_INOUT_CHECK (ctx->HasOutput (" out_sum_2" ), " Output" , " out_sum_2" ,
69
+ " AverageAccumulates" );
70
+ OP_INOUT_CHECK (ctx->HasOutput (" out_sum_3" ), " Output" , " out_sum_3" ,
71
+ " AverageAccumulates" );
72
+ OP_INOUT_CHECK (ctx->HasOutput (" out_num_accumulates" ), " Output" ,
73
+ " out_num_accumulates" , " AverageAccumulates" );
74
+ OP_INOUT_CHECK (ctx->HasOutput (" out_old_num_accumulates" ), " Output" ,
75
+ " out_old_num_accumulates" , " AverageAccumulates" );
76
+ OP_INOUT_CHECK (ctx->HasOutput (" out_num_updates" ), " Output" ,
77
+ " out_num_updates" , " AverageAccumulates" );
93
78
auto in_dim = ctx->GetInputDim (" param" );
94
79
95
80
ctx->SetOutputDim (" out_sum_1" , in_dim);
0 commit comments