Skip to content

Commit 30147d7

Browse files
committed
Fix expand op incorrect infer shape
test=develop
1 parent 9be99b1 commit 30147d7

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

paddle/fluid/operators/expand_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ class ExpandOp : public framework::OperatorWithKernel {
4747
out_shape[i] = x_dims[i] * expand_times[i];
4848
}
4949

50+
// set the first dim to -1 in compile time
51+
if (!ctx->IsRuntime()) {
52+
out_shape[0] = x_dims[0];
53+
}
54+
5055
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
5156
if (out_shape[0] == x_dims[0]) {
5257
ctx->ShareLoD("X", "Out");

python/paddle/fluid/tests/unittests/test_infer_shape.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,34 @@ def test_mul_op(self):
8383
mul_op_desc.infer_shape(block)
8484
self.assertEqual(out.shape(), [x_shape[0], y_shape[1]])
8585

86+
def test_expand_op(self):
87+
prog = core.ProgramDesc()
88+
self.assertIsNotNone(prog)
89+
block = prog.block(0)
90+
self.assertIsNotNone(block)
91+
92+
shape = [-1, 20]
93+
expand_times = [3, 1]
94+
95+
# prepare input/output
96+
x1 = block.var(six.b("x"))
97+
x1.set_type(core.VarDesc.VarType.LOD_TENSOR)
98+
x1.set_shape(shape)
99+
100+
out = block.var(six.b("out"))
101+
out.set_type(core.VarDesc.VarType.LOD_TENSOR)
102+
103+
# prepare the operator
104+
sum_op_desc = block.append_op()
105+
sum_op_desc.set_type("expand")
106+
sum_op_desc.set_input("X", ["x"])
107+
sum_op_desc.set_output("Out", ["out"])
108+
sum_op_desc._set_attr('expand_times', expand_times)
109+
110+
sum_op_desc.check_attrs()
111+
sum_op_desc.infer_shape(block)
112+
self.assertEqual(out.shape(), shape)
113+
86114

87115
if __name__ == '__main__':
88116
unittest.main()

0 commit comments

Comments
 (0)