Skip to content

Commit 593bbfe

Browse files
authored
Merge pull request #11765 from jacquesqiao/fix-adam-op-for-selectedrows
fix adam op for selected rows
2 parents 958823f + 20fae68 commit 593bbfe

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

paddle/fluid/operators/adam_op.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
5656
"Beta2 power accumulator should have 1 dimension");
5757

5858
auto param_dims = ctx->GetInputDim("Param");
59-
PADDLE_ENFORCE_EQ(
60-
param_dims, ctx->GetInputDim("Grad"),
61-
"Param and Grad input of AdamOp should have same dimension");
59+
if (ctx->GetInputsVarType("Grad")[0] ==
60+
framework::proto::VarType::LOD_TENSOR) {
61+
PADDLE_ENFORCE_EQ(
62+
param_dims, ctx->GetInputDim("Grad"),
63+
"Param and Grad input of AdamOp should have same dimension");
64+
}
6265
PADDLE_ENFORCE_EQ(
6366
param_dims, ctx->GetInputDim("Moment1"),
6467
"Param and Moment1 input of AdamOp should have same dimension");

paddle/fluid/operators/adam_op.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ class AdamOpKernel : public framework::OpKernel<T> {
282282
} else if (grad_var->IsType<framework::SelectedRows>()) {
283283
auto& grad =
284284
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
285+
if (grad.rows().size() == 0) {
286+
VLOG(3) << "grad row size is 0!!";
287+
return;
288+
}
285289
// merge duplicated rows if any.
286290
scatter::MergeAdd<DeviceContext, T> merge_func;
287291
auto grad_merge =

0 commit comments

Comments
 (0)