Skip to content

Commit eaad3e4

Browse files
authored
Add check of input in sequence_expand op. (#15466)
* Add check of input in sequence_expand op. test=develop * Correct the unittest of sequence_expand op. test=develop
1 parent f4dec5c commit eaad3e4

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

paddle/fluid/operators/sequence_ops/sequence_expand_op.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
6868
"Level number of Input(X)'s lod could be 0. Otherwise "
6969
"size of Input(X)'s first level lod should be equal to "
7070
"size of Input(Y)'s referred level lod.");
71+
} else {
72+
PADDLE_ENFORCE_EQ(x_dims[0], y_lod[ref_level].size() - 1,
73+
"When Input(X)'s lod is null, the dims[0] of "
74+
"Input(X) should match the "
75+
"size of Input(Y)'s referred level lod.");
7176
}
7277

7378
int64_t out_first_dim = 0;

python/paddle/fluid/tests/unittests/test_sequence_expand.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,10 @@ def test_check_grad(self):
8181
class TestSequenceExpandCase1(TestSequenceExpand):
8282
def set_data(self):
8383
x_data = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
84-
x_lod = [[2, 3]]
8584
y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
8685
y_lod = [[2, 3], [2, 2, 3, 3, 3]]
8786
self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
88-
self.attrs = {'ref_level': 0}
87+
self.attrs = {'ref_level': 1}
8988

9089

9190
class TestSequenceExpandCase2(TestSequenceExpand):

0 commit comments

Comments
 (0)