Skip to content

Commit d369577

Browse files
committed
add reversed poolSequenceWithStride
1 parent 08d6622 commit d369577

File tree

7 files changed

+40
-30
lines changed

7 files changed

+40
-30
lines changed

paddle/gserver/layers/SequenceLastInstanceLayer.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class SequenceLastInstanceLayer : public SequencePoolLayer {
4040
protected:
4141
MatrixPtr tmpSrc_;
4242
MatrixPtr tmpDest_;
43-
bool select_first_;
4443
std::vector<int> insId_;
4544

4645
public:
@@ -59,7 +58,7 @@ REGISTER_LAYER(seqlastins, SequenceLastInstanceLayer);
5958
bool SequenceLastInstanceLayer::init(const LayerMap& layerMap,
6059
const ParameterMap& parameterMap) {
6160
SequencePoolLayer::init(layerMap, parameterMap);
62-
select_first_ = config_.select_first();
61+
reversed_ = config_.select_first();
6362

6463
tmpSrc_ =
6564
Matrix::create(nullptr, /* height= */ 1, 1, /* trans= */ false, useGpu_);
@@ -83,7 +82,7 @@ void SequenceLastInstanceLayer::forward(PassType passType) {
8382

8483
insId_.clear();
8584
for (size_t seqId = 0; seqId < newBatchSize_; ++seqId) {
86-
int insId = select_first_ ? starts[seqId] : starts[seqId + 1] - 1;
85+
int insId = reversed_ ? starts[seqId] : starts[seqId + 1] - 1;
8786
insId_.push_back(insId);
8887

8988
outputValue->subMatrix(seqId, 1, tmpDest_)

paddle/gserver/layers/SequencePoolLayer.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,9 @@ void SequencePoolLayer::forward(PassType passType) {
6868
}
6969
if (stride_ > 0) {
7070
CHECK_EQ(input.hasSubseq(), 0UL)
71-
<< "sequence stride pooling is not suitable for hasSubseq now";
72-
output_.poolSequenceWithStride(input, stride_, &stridePositions_);
71+
<< "sequence stride pooling is invalid for hasSubseq now";
72+
output_.poolSequenceWithStride(
73+
input, stride_, &stridePositions_, reversed_);
7374
newBatchSize_ = stridePositions_->getSize() - 1;
7475
}
7576

paddle/gserver/layers/SequencePoolLayer.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class SequencePoolLayer : public Layer {
4949
int stride_;
5050
// store the start position of each stride window
5151
IVectorPtr stridePositions_;
52+
// Whether it is reversed sequence
53+
bool reversed_ = false;
5254

5355
public:
5456
explicit SequencePoolLayer(const LayerConfig& config) : Layer(config) {}

paddle/parameter/Argument.cpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -561,11 +561,13 @@ void Argument::degradeSequence(const Argument& input) {
561561

562562
void Argument::poolSequenceWithStride(const Argument& input,
563563
size_t stride,
564-
IVectorPtr* stridePostions) {
564+
IVectorPtr* stridePostions,
565+
bool reversed) {
565566
/*
566567
* If input.sequenceStartPositions = [0, 9, 14, 17, 30] and stride = 5,
567-
* then sequenceStartPositions = [0, 2, 3, 4, 7],
568-
* and stridePostions = [0, 5, 9, 14, 17, 22, 27, 30]
568+
* then sequenceStartPositions = [0, 2, 3, 4, 7].
569+
* If reversed = false, stridePostions = [0, 5, 9, 14, 17, 22, 27, 30];
570+
* else reversed = true, stridePostions = [0, 4, 9, 14, 17, 20, 25, 30]
569571
*/
570572
CHECK(input.sequenceStartPositions);
571573
CHECK_EQ(input.hasSubseq(), 0UL);
@@ -584,14 +586,13 @@ void Argument::poolSequenceWithStride(const Argument& input,
584586
if (seqLength == 0) {
585587
// empty sequence
586588
tgtBuf[seqId + 1] = tgtBuf[seqId];
587-
} else if (seqLength < stride) {
588-
tgtBuf[seqId + 1] = tgtBuf[seqId] + 1;
589589
} else {
590-
tgtBuf[seqId + 1] = tgtBuf[seqId] + ceil((float)seqLength / stride);
591-
int size =
592-
(seqLength % stride) ? seqLength / stride : seqLength / stride - 1;
593-
for (int i = 0; i < size; i++) {
594-
stridePos.emplace_back(stridePos.back() + stride);
590+
int size = ceil((float)seqLength / stride);
591+
tgtBuf[seqId + 1] = tgtBuf[seqId] + size;
592+
for (int i = 0; i < size - 1; i++) {
593+
int cur = reversed ? starts[seqId + 1] - (size - 1 - i) * stride
594+
: stridePos.back() + stride;
595+
stridePos.emplace_back(cur);
595596
}
596597
}
597598
}

paddle/parameter/Argument.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,8 @@ struct Argument {
298298
*/
299299
void poolSequenceWithStride(const Argument& input,
300300
size_t stride,
301-
IVectorPtr* stridePositions);
301+
IVectorPtr* stridePositions,
302+
bool reversed = false);
302303
/**
303304
* @brief getValueString will return the argument's output in string. There
304305
* are several kinds of output. The keys of output dictionary are 'value',

paddle/parameter/tests/test_argument.cpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,26 @@ TEST(Argument, poolSequenceWithStride) {
2727
inStart[3] = 17;
2828
inStart[4] = 30;
2929

30-
IVectorPtr stridePositions;
31-
output.poolSequenceWithStride(input, 5 /* stride */, &stridePositions);
32-
33-
const int* outStart = output.sequenceStartPositions->getData(false);
34-
CHECK_EQ(outStart[0], 0);
35-
CHECK_EQ(outStart[1], 2);
36-
CHECK_EQ(outStart[2], 3);
37-
CHECK_EQ(outStart[3], 4);
38-
CHECK_EQ(outStart[4], 7);
39-
40-
CHECK_EQ(stridePositions->getSize(), 8);
4130
int strideResult[] = {0, 5, 9, 14, 17, 22, 27, 30};
42-
for (int i = 0; i < 8; i++) {
43-
CHECK_EQ(stridePositions->getData()[i], strideResult[i]);
31+
int strideResultReversed[] = {0, 4, 9, 14, 17, 20, 25, 30};
32+
33+
for (auto reversed : {false, true}) {
34+
IVectorPtr stridePositions;
35+
output.poolSequenceWithStride(
36+
input, 5 /* stride */, &stridePositions, reversed);
37+
38+
const int* outStart = output.sequenceStartPositions->getData(false);
39+
CHECK_EQ(outStart[0], 0);
40+
CHECK_EQ(outStart[1], 2);
41+
CHECK_EQ(outStart[2], 3);
42+
CHECK_EQ(outStart[3], 4);
43+
CHECK_EQ(outStart[4], 7);
44+
45+
CHECK_EQ(stridePositions->getSize(), 8);
46+
auto result = reversed ? strideResultReversed : strideResult;
47+
for (int i = 0; i < 8; i++) {
48+
CHECK_EQ(stridePositions->getData()[i], result[i]);
49+
}
4450
}
4551
}
4652

python/paddle/trainer/config_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2497,7 +2497,7 @@ def __init__(self,
24972497
config_assert(
24982498
len(inputs) == 1, 'SequenceLastInstanceLayer must have 1 input')
24992499
if trans_type == 'seq':
2500-
config_assert(stride == -1, 'subseq do not support stride window')
2500+
config_assert(stride == -1, 'subseq does not support stride window')
25012501
self.config.trans_type = trans_type
25022502
self.config.seq_pool_stride = stride
25032503
self.set_layer_size(self.get_input_layer(0).size)

0 commit comments

Comments
 (0)