@@ -41,6 +41,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4141 " get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)" );
4242 m.def (
4343 " get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)" );
44+ m.def (
45+ " get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)" );
4446 m.def (
4547 " get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)" );
4648 m.def (
@@ -209,6 +211,20 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
209211 return makeOpsDecodedOutput (result);
210212}
211213
214+ OpsBatchDecodedOutput get_frames_at_ptss (
215+ at::Tensor& decoder,
216+ int64_t stream_index,
217+ at::IntArrayRef frame_ptss,
218+ bool sort_ptss) {
219+ auto videoDecoder = unwrapTensorToGetDecoder (decoder);
220+ std::vector<int64_t > framePtssVec (
221+ frame_ptss.begin (), frame_ptss.end ());
222+ auto result = videoDecoder->getFramesAtPtss (
223+ stream_index, framePtssVec, sort_ptss);
224+ return makeOpsBatchDecodedOutput (result);
225+ }
226+
227+
212228OpsDecodedOutput get_frame_at_index (
213229 at::Tensor& decoder,
214230 int64_t stream_index,
@@ -485,6 +501,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
485501 m.impl (" get_frame_at_pts" , &get_frame_at_pts);
486502 m.impl (" get_frame_at_index" , &get_frame_at_index);
487503 m.impl (" get_frames_at_indices" , &get_frames_at_indices);
504+ m.impl (" get_frames_at_ptss" , &get_frames_at_ptss);
488505 m.impl (" get_frames_in_range" , &get_frames_in_range);
489506 m.impl (" get_frames_by_pts_in_range" , &get_frames_by_pts_in_range);
490507 m.impl (" _test_frame_pts_equality" , &_test_frame_pts_equality);
0 commit comments