We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent b756063 commit df7a266Copy full SHA for df7a266
paddle/fluid/operators/adam_op.cc
@@ -56,9 +56,12 @@ class AdamOp : public framework::OperatorWithKernel {
56
"Beta2 power accumulator should have 1 dimension");
57
58
auto param_dims = ctx->GetInputDim("Param");
59
- PADDLE_ENFORCE_EQ(
60
- param_dims, ctx->GetInputDim("Grad"),
61
- "Param and Grad input of AdamOp should have same dimension");
+ if (ctx->GetInputsVarType("Grad")[0] ==
+ framework::proto::VarType::LOD_TENSOR) {
+ PADDLE_ENFORCE_EQ(
62
+ param_dims, ctx->GetInputDim("Grad"),
63
+ "Param and Grad input of AdamOp should have same dimension");
64
+ }
65
PADDLE_ENFORCE_EQ(
66
param_dims, ctx->GetInputDim("Moment1"),
67
"Param and Moment1 input of AdamOp should have same dimension");
0 commit comments