Skip to content

Commit 0219f92

Browse files
committed
Rough implementation of getting key frame indices
1 parent 3253c8f commit 0219f92

File tree

8 files changed

+45
-0
lines changed

8 files changed

+45
-0
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
553553
return containerMetadata_;
554554
}
555555

556+
std::vector<int64_t> VideoDecoder::getKeyFrameIndices(int streamIndex) {
557+
validateUserProvidedStreamIndex(streamIndex);
558+
validateScannedAllStreams("getKeyFrameIndices");
559+
560+
std::vector<int64_t> keyFrameIndices;
561+
const StreamInfo& streamInfo = streamInfos_[streamIndex];
562+
for (const FrameInfo& frameInfo : streamInfo.keyFrames) {
563+
keyFrameIndices.push_back(
564+
getKeyFrameIndexForPtsUsingScannedIndex(streamInfo.keyFrames, frameInfo.pts));
565+
}
566+
567+
return keyFrameIndices;
568+
}
569+
556570
int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex(
557571
AVStream* stream,
558572
int64_t pts) const {

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class VideoDecoder {
100100
// Returns the metadata for the container.
101101
ContainerMetadata getContainerMetadata() const;
102102

103+
// Returns one index per key frame recorded for the stream at streamIndex.
// The implementation validates streamIndex and calls
// validateScannedAllStreams(), so a full scan is presumably required first.
std::vector<int64_t> getKeyFrameIndices(int streamIndex);
104+
103105
// --------------------------------------------------------------------------
104106
// ADDING STREAMS API
105107
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
4949
m.def(
5050
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
51+
m.def("get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]");
5152
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
5253
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
5354
m.def(
@@ -334,6 +335,13 @@ bool _test_frame_pts_equality(
334335
videoDecoder->getPtsSecondsForFrame(stream_index, frame_index);
335336
}
336337

338+
// Op entry point: forwards to VideoDecoder::getKeyFrameIndices() on the
// decoder wrapped inside `decoder`, for the stream at `stream_index`.
std::vector<int64_t> get_key_frame_indices(
    at::Tensor& decoder,
    int64_t stream_index) {
  // Recover the decoder from its tensor wrapper and delegate directly.
  return unwrapTensorToGetDecoder(decoder)->getKeyFrameIndices(stream_index);
}
344+
337345
std::string get_json_metadata(at::Tensor& decoder) {
338346
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
339347

@@ -526,6 +534,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
526534
m.impl("add_video_stream", &add_video_stream);
527535
m.impl("_add_video_stream", &_add_video_stream);
528536
m.impl("get_next_frame", &get_next_frame);
537+
m.impl("get_key_frame_indices", &get_key_frame_indices);
529538
m.impl("get_json_metadata", &get_json_metadata);
530539
m.impl("get_container_json_metadata", &get_container_json_metadata);
531540
m.impl("get_stream_json_metadata", &get_stream_json_metadata);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ bool _test_frame_pts_equality(
137137
int64_t frame_index,
138138
double pts_seconds_to_test);
139139

140+
// Get the key frame indices of the stream at stream_index, as reported by
// the decoder wrapped in the given tensor.
std::vector<int64_t> get_key_frame_indices(
141+
at::Tensor& decoder,
142+
int64_t stream_index);
143+
140144
// Get the metadata from the video as a string.
141145
std::string get_json_metadata(at::Tensor& decoder);
142146

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
get_frames_by_pts_in_range,
2727
get_frames_in_range,
2828
get_json_metadata,
29+
get_key_frame_indices,
2930
get_next_frame,
3031
scan_all_streams_to_update_metadata,
3132
seek_to_pts,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def load_torchcodec_extension():
7777
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
7878
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
7979
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
80+
get_key_frame_indices = torch.ops.torchcodec_ns.get_key_frame_indices.default
8081
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
8182
_test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default
8283
_get_container_json_metadata = (
@@ -255,6 +256,11 @@ def get_frames_by_pts_in_range_abstract(
255256
)
256257

257258

259+
# Fake implementation registered for torchcodec_ns::get_key_frame_indices.
# It does not touch the decoder and always returns an empty list.
@register_fake("torchcodec_ns::get_key_frame_indices")
260+
def get_key_frame_indices_abstract(decoder: torch.Tensor, *, stream_index: int) -> List[int]:
261+
return []
262+
263+
258264
@register_fake("torchcodec_ns::get_json_metadata")
259265
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
260266
return ""

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,9 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
185185
f"Unsupported key type: {type(key)}. Supported types are int and slice."
186186
)
187187

188+
# Private helper: asks the core get_key_frame_indices op for the key frame
# indices of this decoder's stream (self.stream_index).
def _get_key_frame_indices(self) -> list[int]:
189+
return core.get_key_frame_indices(self._decoder, stream_index=self.stream_index)
190+
188191
def get_frame_at(self, index: int) -> Frame:
189192
"""Return a single frame at the given index.
190193

test/decoders/test_video_decoder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -831,6 +831,12 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
831831
with pytest.raises(ValueError, match="Invalid stop seconds"):
832832
frame = decoder.get_frames_played_in_range(0, 23) # noqa
833833

834+
# Smoke test: a decoder opened on NASA_VIDEO with seek_mode="exact" must
# report at least one key frame index. The index values are not checked.
@pytest.mark.parametrize("device", cpu_and_cuda())
835+
def test_get_key_frame_indices(self, device):
836+
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
837+
key_frame_indices = decoder._get_key_frame_indices()
838+
assert len(key_frame_indices) > 0
839+
834840

835841
if __name__ == "__main__":
836842
pytest.main()

0 commit comments

Comments
 (0)