Skip to content

Commit 121a9fd

Browse files
committed
Change return type to a 1D tensor of int64
1 parent 736401c commit 121a9fd

File tree

6 files changed

+21
-14
lines changed

6 files changed

+21
-14
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -553,15 +553,17 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
553553
return containerMetadata_;
554554
}
555555

556-
std::vector<int64_t> VideoDecoder::getKeyFrameIndices(int streamIndex) {
556+
torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
557557
validateUserProvidedStreamIndex(streamIndex);
558558
validateScannedAllStreams("getKeyFrameIndices");
559559

560-
std::vector<int64_t> keyFrameIndices;
561-
const StreamInfo& streamInfo = streamInfos_[streamIndex];
562-
for (const FrameInfo& frameInfo : streamInfo.keyFrames) {
563-
keyFrameIndices.push_back(getKeyFrameIndexForPtsUsingScannedIndex(
564-
streamInfo.keyFrames, frameInfo.pts));
560+
const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
561+
torch::Tensor keyFrameIndices =
562+
torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
563+
for (size_t i = 0; i < keyFrames.size(); ++i) {
564+
int64_t pts = keyFrames[i].pts;
565+
keyFrameIndices[i] =
566+
getKeyFrameIndexForPtsUsingScannedIndex(keyFrames, pts);
565567
}
566568

567569
return keyFrameIndices;

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,9 @@ class VideoDecoder {
100100
// Returns the metadata for the container.
101101
ContainerMetadata getContainerMetadata() const;
102102

103-
std::vector<int64_t> getKeyFrameIndices(int streamIndex);
103+
// Returns the key frame indices as a tensor. The tensor is 1D and contains
104+
// int64 values, where each value is the frame index for a key frame.
105+
torch::Tensor getKeyFrameIndices(int streamIndex);
104106

105107
// --------------------------------------------------------------------------
106108
// ADDING STREAMS API

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
5050
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51-
m.def("_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]");
51+
m.def(
52+
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
5253
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5354
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5455
m.def(
@@ -335,7 +336,7 @@ bool _test_frame_pts_equality(
335336
videoDecoder->getPtsSecondsForFrame(stream_index, frame_index);
336337
}
337338

338-
std::vector<int64_t> _get_key_frame_indices(
339+
torch::Tensor _get_key_frame_indices(
339340
at::Tensor& decoder,
340341
int64_t stream_index) {
341342
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,7 @@ bool _test_frame_pts_equality(
137137
int64_t frame_index,
138138
double pts_seconds_to_test);
139139

140-
std::vector<int64_t> get_key_frame_indices(
141-
at::Tensor& decoder,
142-
int64_t stream_index);
140+
torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
143141

144142
// Get the metadata from the video as a string.
145143
std::string get_json_metadata(at::Tensor& decoder);

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,9 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
186186
)
187187

188188
def _get_key_frame_indices(self) -> list[int]:
189-
return core._get_key_frame_indices(self._decoder, stream_index=self.stream_index)
189+
return core._get_key_frame_indices(
190+
self._decoder, stream_index=self.stream_index
191+
)
190192

191193
def get_frame_at(self, index: int) -> Frame:
192194
"""Return a single frame at the given index.

test/decoders/test_video_decoder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,9 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
835835
def test_get_key_frame_indices(self, device):
836836
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
837837
key_frame_indices = decoder._get_key_frame_indices()
838-
assert len(key_frame_indices) > 0
838+
size = key_frame_indices.size()
839+
assert size[0] > 0
840+
assert len(size) == 1
839841

840842

841843
if __name__ == "__main__":

0 commit comments

Comments
 (0)