@@ -48,6 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4848 " get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)" );
4949 m.def (
5050 " get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)" );
51+ m.def (" get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]" );
5152 m.def (" get_json_metadata(Tensor(a!) decoder) -> str" );
5253 m.def (" get_container_json_metadata(Tensor(a!) decoder) -> str" );
5354 m.def (
@@ -334,6 +335,13 @@ bool _test_frame_pts_equality(
334335 videoDecoder->getPtsSecondsForFrame (stream_index, frame_index);
335336}
336337
338+ std::vector<int64_t > get_key_frame_indices (
339+ at::Tensor& decoder,
340+ int64_t stream_index) {
341+ auto videoDecoder = unwrapTensorToGetDecoder (decoder);
342+ return videoDecoder->getKeyFrameIndices (stream_index);
343+ }
344+
337345std::string get_json_metadata (at::Tensor& decoder) {
338346 auto videoDecoder = unwrapTensorToGetDecoder (decoder);
339347
@@ -526,6 +534,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
526534 m.impl (" add_video_stream" , &add_video_stream);
527535 m.impl (" _add_video_stream" , &_add_video_stream);
528536 m.impl (" get_next_frame" , &get_next_frame);
537+ m.impl (" get_key_frame_indices" , &get_key_frame_indices);
529538 m.impl (" get_json_metadata" , &get_json_metadata);
530539 m.impl (" get_container_json_metadata" , &get_container_json_metadata);
531540 m.impl (" get_stream_json_metadata" , &get_stream_json_metadata);
0 commit comments