Skip to content

Commit 50b6e4c

Browse files
committed
Fix expand grad op infer shape
test=develop
1 parent 30147d7 commit 50b6e4c

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

paddle/fluid/operators/expand_op.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,12 @@ class ExpandGradOp : public framework::OperatorWithKernel {
114114
ctx->Attrs().Get<std::vector<int>>("expand_times");
115115
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
116116

117-
for (size_t i = 0; i < expand_times.size(); ++i) {
117+
size_t start_pos = 0u;
118+
if (!ctx->IsRuntime()) {
119+
start_pos = 1u;
120+
}
121+
122+
for (size_t i = start_pos; i < expand_times.size(); ++i) {
118123
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
119124
"Each dimension size of Input(Out@GRAD) should be "
120125
"equal to multiplication of crroresponding dimension "

0 commit comments

Comments
 (0)