File tree Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Expand file tree Collapse file tree 1 file changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -70,11 +70,23 @@ void SequenceReshapeLayer::forward(PassType passType) {
70
70
size_t outDim = getSize ();
71
71
72
72
size_t numSequences = input.getNumSequences ();
73
- auto startPositions = input.sequenceStartPositions ->getVector (false );
74
- const int * starts = startPositions->getData ();
75
73
76
- CHECK_EQ (starts[numSequences], input.getBatchSize ());
77
- CHECK_EQ (numSequences, startPositions->getSize () - 1 );
74
+ // by default, we assume each instance as a sequence
75
+ IVectorPtr seqStarts;
76
+ IVector::resizeOrCreate (seqStarts, input.getBatchSize () + 1 , false );
77
+ int * startsData = seqStarts->getData ();
78
+ for (int i = 0 ; i < input.getBatchSize () + 1 ; i++) {
79
+ startsData[i] = i;
80
+ }
81
+ const int * starts = startsData;
82
+
83
+ // if there is sequence, then use start positions
84
+ if (input.sequenceStartPositions ) {
85
+ auto startPositions = input.sequenceStartPositions ->getVector (false );
86
+ starts = startPositions->getData ();
87
+ CHECK_EQ (starts[numSequences], input.getBatchSize ());
88
+ CHECK_EQ (numSequences, startPositions->getSize () - 1 );
89
+ }
78
90
79
91
for (size_t seqID = 0 ; seqID < numSequences; seqID++) {
80
92
size_t inNumIns = starts[seqID + 1 ] - starts[seqID];
You can’t perform that action at this time.
0 commit comments