Skip to content

Commit d971d5b

Browse files
authored
Merge pull request #14431 from velconia/fix_expand_op_dim_in_compile_time
Fix expand op incorrect infer shape
2 parents b32c13d + 560b29c commit d971d5b

File tree

2 files changed: +43 additions, −1 deletion

paddle/fluid/operators/expand_op.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
4747
out_shape[i] = x_dims[i] * expand_times[i];
4848
}
4949

50+
// set the first dim to -1 in compile time
51+
if (!ctx->IsRuntime()) {
52+
out_shape[0] = x_dims[0];
53+
}
54+
5055
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
5156
if (out_shape[0] == x_dims[0]) {
5257
ctx->ShareLoD("X", "Out");
@@ -109,7 +114,16 @@ class ExpandGradOp : public framework::OperatorWithKernel {
109114
ctx->Attrs().Get<std::vector<int>>("expand_times");
110115
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
111116

112-
for (size_t i = 0; i < expand_times.size(); ++i) {
117+
size_t start_pos = 0u;
118+
if (!ctx->IsRuntime()) {
119+
PADDLE_ENFORCE_EQ(
120+
x_dims[0], out_dims[0],
121+
"The first dimension size of Input(Out@GRAD) should be "
122+
"equal to the corresponding dimension size of Input(X)");
123+
start_pos = 1u;
124+
}
125+
126+
for (size_t i = start_pos; i < expand_times.size(); ++i) {
113127
PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
114128
"Each dimension size of Input(Out@GRAD) should be "
115129
"equal to multiplication of corresponding dimension "

python/paddle/fluid/tests/unittests/test_infer_shape.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,34 @@ def test_mul_op(self):
8383
mul_op_desc.infer_shape(block)
8484
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
8585

86+
def test_expand_op(self):
87+
prog = core.ProgramDesc()
88+
self.assertIsNotNone(prog)
89+
block = prog.block(0)
90+
self.assertIsNotNone(block)
91+
92+
shape = [-1, 20]
93+
expand_times = [3, 1]
94+
95+
# prepare input/output
96+
x1 = block.var(six.b("x"))
97+
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
98+
x1.set_shape(shape)
99+
100+
out = block.var(six.b("out"))
101+
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
102+
103+
# prepare the operator
104+
sum_op_desc = block.append_op()
105+
sum_op_desc.set_type("expand")
106+
sum_op_desc.set_input("X", ["x"])
107+
sum_op_desc.set_output("Out", ["out"])
108+
sum_op_desc._set_attr('expand_times', expand_times)
109+
110+
sum_op_desc.check_attrs()
111+
sum_op_desc.infer_shape(block)
112+
self.assertEqual(out.shape(), shape)
113+
86114

87115
if __name__ == '__main__':
88116
unittest.main()

0 commit comments

Comments (0)