Skip to content

Commit 733ea0d

Browse files
committed
adjust infershape details
1 parent 2969aba commit 733ea0d

File tree

4 files changed

+11
-23
lines changed

4 files changed

+11
-23
lines changed

paddle/fluid/operators/sequence_enumerate_op.cc

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/operators/sequence_enumerate_op.h"
16-
#include <vector>
1716

1817
namespace paddle {
1918
namespace operators {
@@ -34,18 +33,12 @@ class SequenceEnumerateOp : public framework::OperatorWithKernel {
3433
PADDLE_ENFORCE_EQ(
3534
x_dims.size(), 2UL,
3635
"Input(X) of SequenceEnumerate operator's rank should be 2.");
36+
PADDLE_ENFORCE_EQ(
37+
x_dims[1], 1UL,
38+
"Input(X) of SequenceEnumerate operator's 2nd dimension should be 1.");
3739

3840
const auto win_size = ctx->Attrs().Get<int>("win_size");
39-
// TODO(chenweihang): unittest doesn't has batch size, but test_layers has
40-
auto first_dim = x_dims[0] == -1 ? x_dims[1] : x_dims[0];
41-
PADDLE_ENFORCE(win_size <= first_dim,
42-
"The enumerate window size should be less than or equal to "
43-
"input sequence length.");
44-
45-
std::vector<int64_t> out_shape(x_dims.size() + 1, 0);
46-
for (int i = 0; i < x_dims.size(); ++i) out_shape.emplace_back(x_dims[i]);
47-
out_shape.emplace_back(win_size);
48-
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
41+
ctx->SetOutputDim("Out", {x_dims[0], win_size});
4942
ctx->ShareLoD("X", "Out");
5043
}
5144
};

python/paddle/fluid/layers/nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5563,15 +5563,15 @@ def sequence_enumerate(input, win_size, pad_value, name=None):
55635563
out = fluid.layers.sequence_enumerate(input=x, win_size=3, pad_value=0)
55645564
"""
55655565
helper = LayerHelper('sequence_enumerate', **locals())
5566-
out = helper.create_tmp_variable(helper.input_dtype())
5566+
out = helper.create_tmp_variable(helper.input_dtype(), stop_gradient=True)
55675567
helper.append_op(
55685568
type='sequence_enumerate',
55695569
inputs={'X': input},
55705570
outputs={'Out': out},
55715571
attrs={'win_size': win_size,
55725572
'pad_value': pad_value})
55735573

5574-
5574+
55755575
def stack(x, axis=0):
55765576
helper = LayerHelper('stack', **locals())
55775577
axis = 0 if axis is None else axis

python/paddle/fluid/tests/unittests/test_layers.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -522,10 +522,8 @@ def test_prelu(self):
522522
def test_sequence_enumerate(self):
523523
program = Program()
524524
with program_guard(program):
525-
x = layers.data(
526-
name="input", shape=[30], dtype='int32', lod_level=1)
525+
x = layers.data(name="input", shape=[1], dtype='int32', lod_level=1)
527526
out = layers.sequence_enumerate(input=x, win_size=2, pad_value=0)
528-
self.assertIsNotNone(out)
529527
print(str(program))
530528

531529

python/paddle/fluid/tests/unittests/test_sequence_enumerate_op.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from op_test import OpTest
2020

2121

22-
def sequence_enumerate(input_seq, lod0, win_size, pad_value):
22+
def sequence_enumerate(input_seq, win_size, pad_value):
2323
out_seq = []
2424
for idx in range(0, len(input_seq)):
2525
single_seq = []
@@ -48,8 +48,7 @@ def init_test_case(self):
4848
self.lod = [[9, 4, 11, 6]]
4949
self.win_size = 2
5050
self.pad_value = 0
51-
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
52-
self.pad_value)
51+
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
5352
self.out_seq = np.array(out_seq).astype("int32")
5453

5554

@@ -59,8 +58,7 @@ def init_test_case(self):
5958
self.lod = [[9, 4, 11, 6]]
6059
self.win_size = 2
6160
self.pad_value = 0
62-
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
63-
self.pad_value)
61+
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
6462
self.out_seq = np.array(out_seq).astype("int64")
6563

6664

@@ -70,8 +68,7 @@ def init_test_case(self):
7068
self.lod = [[9, 4, 11, 6]]
7169
self.win_size = 30
7270
self.pad_value = 0
73-
out_seq = sequence_enumerate(self.in_seq, self.lod[0], self.win_size,
74-
self.pad_value)
71+
out_seq = sequence_enumerate(self.in_seq, self.win_size, self.pad_value)
7572
self.out_seq = np.array(out_seq).astype("int32")
7673

7774

0 commit comments

Comments
 (0)