File tree Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Expand file tree Collapse file tree 2 files changed +10
-3
lines changed Original file line number Diff line number Diff line change @@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
56
56
" Beta2 power accumulator should have 1 dimension" );
57
57
58
58
auto param_dims = ctx->GetInputDim (" Param" );
59
- PADDLE_ENFORCE_EQ (
60
- param_dims, ctx->GetInputDim (" Grad" ),
61
- " Param and Grad input of AdamOp should have same dimension" );
59
+ if (ctx->GetInputsVarType (" Grad" )[0 ] ==
60
+ framework::proto::VarType::LOD_TENSOR) {
61
+ PADDLE_ENFORCE_EQ (
62
+ param_dims, ctx->GetInputDim (" Grad" ),
63
+ " Param and Grad input of AdamOp should have same dimension" );
64
+ }
62
65
PADDLE_ENFORCE_EQ (
63
66
param_dims, ctx->GetInputDim (" Moment1" ),
64
67
" Param and Moment1 input of AdamOp should have same dimension" );
Original file line number Diff line number Diff line change @@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
282
282
} else if (grad_var->IsType <framework::SelectedRows>()) {
283
283
auto & grad =
284
284
Ref (ctx.Input <framework::SelectedRows>(" Grad" ), " Must set Grad" );
285
+ if (grad.rows ().size () == 0 ) {
286
+ VLOG (3 ) << " grad row size is 0!!" ;
287
+ return ;
288
+ }
285
289
// merge duplicated rows if any.
286
290
scatter::MergeAdd<DeviceContext, T> merge_func;
287
291
auto grad_merge =
You can’t perform that action at this time.
0 commit comments