Skip to content

Commit b368f13

Browse files
committed
Fix output dims of sequence expand op
1 parent ef8cb8f commit b368f13

File tree

2 files changed

+18
-1
lines changed

2 files changed

+18
-1
lines changed

paddle/operators/sequence_expand_op.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
2929
PADDLE_ENFORCE(ctx->HasOutput("Out"));
3030
PADDLE_ENFORCE(ctx->HasInput("Y"));
3131
framework::DDim out_dim;
32-
out_dim = ctx->GetInputDim("Y");
32+
auto y_dim = ctx->GetInputDim("Y");
33+
out_dim = ctx->GetInputDim("X");
34+
out_dim[0] = y_dim[0];
3335
ctx->ShareLoD("Y", "Out");
3436
ctx->SetOutputDim("Out", out_dim);
3537
}

python/paddle/v2/fluid/tests/test_sequence_expand.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,5 +73,20 @@ def set_data(self):
7373
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
7474

7575

76+
class TestSequenceExpandCase4(TestSequenceExpand):
77+
def set_data(self):
78+
x_data = np.array(
79+
[0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape(
80+
[2, 5]).astype('float32')
81+
x_lod = [[
82+
0,
83+
1,
84+
2,
85+
]]
86+
y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
87+
y_lod = [[0, 1, 2], [0, 1, 2]]
88+
self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
89+
90+
7691
if __name__ == '__main__':
7792
unittest.main()

0 commit comments

Comments
 (0)