Skip to content

Commit d475890

Browse files
committed
scaffolding
1 parent f391582 commit d475890

File tree

6 files changed

+56
-2
lines changed

6 files changed

+56
-2
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,6 +1088,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10881088
return output;
10891089
}
10901090

1091+
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
1092+
int streamIndex,
1093+
const std::vector<int64_t>& framePtss,
1094+
const bool sortPtss) {
1095+
return getFramesAtIndices(streamIndex, framePtss, sortPtss);
1096+
}
1097+
10911098
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10921099
int streamIndex,
10931100
int64_t start,

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,7 @@ class VideoDecoder {
223223
// i.e. it will be returned when this function is called with seconds=5.0 or
224224
// seconds=5.999, etc.
225225
DecodedOutput getFrameDisplayedAtTimestampNoDemux(double seconds);
226+
226227
DecodedOutput getFrameAtIndex(
227228
int streamIndex,
228229
int64_t frameIndex,
@@ -243,6 +244,12 @@ class VideoDecoder {
243244
int streamIndex,
244245
const std::vector<int64_t>& frameIndices,
245246
const bool sortIndices = false);
247+
248+
BatchDecodedOutput getFramesAtPtss(
249+
int streamIndex,
250+
const std::vector<int64_t>& framePtss,
251+
const bool sortPtss = false);
252+
246253
// Returns frames within a given range for a given stream as a single stacked
247254
// Tensor. The range is defined by [start, stop). The values retrieved from
248255
// the range are:

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4141
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
4242
m.def(
4343
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
44+
m.def(
45+
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)");
4446
m.def(
4547
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4648
m.def(
@@ -209,6 +211,20 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
209211
return makeOpsDecodedOutput(result);
210212
}
211213

214+
OpsBatchDecodedOutput get_frames_at_ptss(
215+
at::Tensor& decoder,
216+
int64_t stream_index,
217+
at::IntArrayRef frame_ptss,
218+
bool sort_ptss) {
219+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
220+
std::vector<int64_t> framePtssVec(
221+
frame_ptss.begin(), frame_ptss.end());
222+
auto result = videoDecoder->getFramesAtPtss(
223+
stream_index, framePtssVec, sort_ptss);
224+
return makeOpsBatchDecodedOutput(result);
225+
}
226+
227+
212228
OpsDecodedOutput get_frame_at_index(
213229
at::Tensor& decoder,
214230
int64_t stream_index,
@@ -485,6 +501,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
485501
m.impl("get_frame_at_pts", &get_frame_at_pts);
486502
m.impl("get_frame_at_index", &get_frame_at_index);
487503
m.impl("get_frames_at_indices", &get_frames_at_indices);
504+
m.impl("get_frames_at_ptss", &get_frames_at_ptss);
488505
m.impl("get_frames_in_range", &get_frames_in_range);
489506
m.impl("get_frames_by_pts_in_range", &get_frames_by_pts_in_range);
490507
m.impl("_test_frame_pts_equality", &_test_frame_pts_equality);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ using OpsBatchDecodedOutput = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
7575
// given timestamp T has T >= PTS and T < PTS + Duration.
7676
OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7777

78+
// Return the frames at given ptss for a given stream
79+
OpsBatchDecodedOutput get_frames_at_ptss(
80+
at::Tensor& decoder,
81+
int64_t stream_index,
82+
at::IntArrayRef frame_ptss,
83+
bool sort_ptss = false);
84+
7885
// Return the frame that is visible at a given index in the video.
7986
OpsDecodedOutput get_frame_at_index(
8087
at::Tensor& decoder,
@@ -85,8 +92,7 @@ OpsDecodedOutput get_frame_at_index(
8592
// duration as tensors.
8693
OpsDecodedOutput get_next_frame(at::Tensor& decoder);
8794

88-
// Return the frames at a given index for a given stream as a single stacked
89-
// Tensor.
95+
// Return the frames at given indices for a given stream
9096
OpsBatchDecodedOutput get_frames_at_indices(
9197
at::Tensor& decoder,
9298
int64_t stream_index,

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
get_frame_at_index,
2323
get_frame_at_pts,
2424
get_frames_at_indices,
25+
get_frames_at_ptss,
2526
get_frames_by_pts_in_range,
2627
get_frames_in_range,
2728
get_json_metadata,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def load_torchcodec_extension():
7171
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
7272
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
7373
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
74+
get_frames_at_ptss = torch.ops.torchcodec_ns.get_frames_at_ptss.default
7475
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
7576
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
7677
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
@@ -171,6 +172,21 @@ def get_frame_at_pts_abstract(
171172
torch.empty([], dtype=torch.float),
172173
)
173174

175+
@register_fake("torchcodec_ns::get_frames_at_ptss")
176+
def get_frames_at_pts_abstract(
177+
decoder: torch.Tensor,
178+
*,
179+
stream_index: int,
180+
frame_ptss: List[int],
181+
sort_ptss: bool = False,
182+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
183+
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
184+
return (
185+
torch.empty(image_size),
186+
torch.empty([], dtype=torch.float),
187+
torch.empty([], dtype=torch.float),
188+
)
189+
174190

175191
@register_fake("torchcodec_ns::get_frame_at_index")
176192
def get_frame_at_index_abstract(

0 commit comments

Comments
 (0)