@@ -606,25 +606,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
606
606
}
607
607
608
608
FrameBatchOutput SingleStreamDecoder::getFramesAtIndices (
609
- const std::vector< int64_t > & frameIndices) {
609
+ const torch::Tensor & frameIndices) {
610
610
validateActiveStream (AVMEDIA_TYPE_VIDEO);
611
611
612
- auto indicesAreSorted =
613
- std::is_sorted (frameIndices.begin (), frameIndices.end ());
612
+ auto frameIndicesAccessor = frameIndices.accessor <int64_t , 1 >();
613
+
614
+ bool indicesAreSorted = true ;
615
+ for (int64_t i = 1 ; i < frameIndices.numel (); ++i) {
616
+ if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1 ]) {
617
+ indicesAreSorted = false ;
618
+ break ;
619
+ }
620
+ }
614
621
615
622
std::vector<size_t > argsort;
616
623
if (!indicesAreSorted) {
617
624
// if frameIndices is [13, 10, 12, 11]
618
625
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
619
626
// to use to decode the frames
620
627
// and argsort is [ 1, 3, 2, 0]
621
- argsort.resize (frameIndices.size ());
628
+ argsort.resize (frameIndices.numel ());
622
629
for (size_t i = 0 ; i < argsort.size (); ++i) {
623
630
argsort[i] = i;
624
631
}
625
632
std::sort (
626
- argsort.begin (), argsort.end (), [&frameIndices](size_t a, size_t b) {
627
- return frameIndices[a] < frameIndices[b];
633
+ argsort.begin (),
634
+ argsort.end (),
635
+ [&frameIndicesAccessor](size_t a, size_t b) {
636
+ return frameIndicesAccessor[a] < frameIndicesAccessor[b];
628
637
});
629
638
}
630
639
@@ -633,12 +642,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
633
642
const auto & streamInfo = streamInfos_[activeStreamIndex_];
634
643
const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
635
644
FrameBatchOutput frameBatchOutput (
636
- frameIndices.size (), videoStreamOptions, streamMetadata);
645
+ frameIndices.numel (), videoStreamOptions, streamMetadata);
637
646
638
647
auto previousIndexInVideo = -1 ;
639
- for (size_t f = 0 ; f < frameIndices.size (); ++f) {
648
+ for (int64_t f = 0 ; f < frameIndices.numel (); ++f) {
640
649
auto indexInOutput = indicesAreSorted ? f : argsort[f];
641
- auto indexInVideo = frameIndices [indexInOutput];
650
+ auto indexInVideo = frameIndicesAccessor [indexInOutput];
642
651
643
652
if ((f > 0 ) && (indexInVideo == previousIndexInVideo)) {
644
653
// Avoid decoding the same frame twice
@@ -780,7 +789,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
780
789
frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
781
790
}
782
791
783
- return getFramesAtIndices (frameIndices);
792
+ // TODO: Support tensors natively instead of a vector to avoid a copy.
793
+ return getFramesAtIndices (torch::tensor (frameIndices));
784
794
}
785
795
786
796
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange (
@@ -1202,6 +1212,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
1202
1212
if (status == AVERROR_EOF) {
1203
1213
// End of file reached. We must drain the decoder
1204
1214
if (useCustomInterface) {
1215
+ // TODONVDEC P0: Re-think this. This should be simpler.
1205
1216
AutoAVPacket eofAutoPacket;
1206
1217
ReferenceAVPacket eofPacket (eofAutoPacket);
1207
1218
eofPacket->data = nullptr ;
0 commit comments