Skip to content

Commit bcdedec

Browse files
author
Haonan
authored
handle non-sequence data in sequenceReshapeLayer (#5188)
1 parent 0318f47 commit bcdedec

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

paddle/gserver/layers/SequenceReshapeLayer.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) {
7070
size_t outDim = getSize();
7171

7272
size_t numSequences = input.getNumSequences();
73-
auto startPositions = input.sequenceStartPositions->getVector(false);
74-
const int* starts = startPositions->getData();
7573

76-
CHECK_EQ(starts[numSequences], input.getBatchSize());
77-
CHECK_EQ(numSequences, startPositions->getSize() - 1);
74+
// by default, we assume each instance as a sequence
75+
IVectorPtr seqStarts;
76+
IVector::resizeOrCreate(seqStarts, input.getBatchSize() + 1, false);
77+
int* startsData = seqStarts->getData();
78+
for (int i = 0; i < input.getBatchSize() + 1; i++) {
79+
startsData[i] = i;
80+
}
81+
const int* starts = startsData;
82+
83+
// if there is sequence, then use start positions
84+
if (input.sequenceStartPositions) {
85+
auto startPositions = input.sequenceStartPositions->getVector(false);
86+
starts = startPositions->getData();
87+
CHECK_EQ(starts[numSequences], input.getBatchSize());
88+
CHECK_EQ(numSequences, startPositions->getSize() - 1);
89+
}
7890

7991
for (size_t seqID = 0; seqID < numSequences; seqID++) {
8092
size_t inNumIns = starts[seqID + 1] - starts[seqID];

0 commit comments

Comments
 (0)