Skip to content

Commit 14e2876

Browse files
committed
Added logic
1 parent d475890 commit 14e2876

File tree

5 files changed

+45
-13
lines changed

5 files changed

+45
-13
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,10 +1090,43 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
10901090

10911091
// Returns the frames of stream `streamIndex` that are displayed at the given
// timestamps, expressed in seconds.
//
// Every timestamp is checked against the [min, max) pts range recorded for the
// stream and converted to a frame index; decoding is then delegated to
// getFramesAtIndices(), to which `sortPtss` is forwarded unchanged.
//
// Fails a TORCH_CHECK if any timestamp lies outside [minSeconds, maxSeconds).
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtPtss(
    int streamIndex,
    const std::vector<double>& framePtss,
    const bool sortPtss) {
  validateUserProvidedStreamIndex(streamIndex);
  validateScannedAllStreams("getFramesAtPtss");

  // The frame displayed at timestamp `t` and the one displayed at timestamp
  // `t + eps` are probably the same frame, with the same index. The easiest
  // way to avoid decoding that frame twice is to convert the input timestamps
  // to indices and rely on the de-duplication logic of getFramesAtIndices.

  const auto& streamMetadata = containerMetadata_.streams[streamIndex];
  const auto& stream = streams_[streamIndex];
  const double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
  const double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();

  // Iterate over the values directly instead of using an `int` index compared
  // against the unsigned `size()`.
  std::vector<int64_t> frameIndices;
  frameIndices.reserve(framePtss.size());
  for (const double framePts : framePtss) {
    TORCH_CHECK(
        framePts >= minSeconds && framePts < maxSeconds,
        "frame pts is " + std::to_string(framePts) + "; must be in range [" +
            std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) +
            ").");

    // Finds the first frame whose nextPts (in seconds) is strictly greater
    // than framePts.
    // NOTE(review): this assumes stream.allFrames is ordered by pts, as
    // std::lower_bound requires a partitioned range -- confirm.
    const auto it = std::lower_bound(
        stream.allFrames.begin(),
        stream.allFrames.end(),
        framePts,
        [&stream](const FrameInfo& info, double start) {
          return ptsToSeconds(info.nextPts, stream.timeBase) <= start;
        });
    frameIndices.push_back(
        static_cast<int64_t>(it - stream.allFrames.begin()));
  }

  return getFramesAtIndices(streamIndex, frameIndices, sortPtss);
}
10971130

10981131
VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
10991132
int streamIndex,

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class VideoDecoder {
247247

248248
BatchDecodedOutput getFramesAtPtss(
249249
int streamIndex,
250-
const std::vector<int64_t>& framePtss,
250+
const std::vector<double>& framePtss,
251251
const bool sortPtss = false);
252252

253253
// Returns frames within a given range for a given stream as a single stacked

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
4242
m.def(
4343
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices, bool sort_indices=False) -> (Tensor, Tensor, Tensor)");
4444
m.def(
45-
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, int[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)");
45+
"get_frames_at_ptss(Tensor(a!) decoder, *, int stream_index, float[] frame_ptss, bool sort_ptss=False) -> (Tensor, Tensor, Tensor)");
4646
m.def(
4747
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
4848
m.def(
@@ -214,17 +214,15 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
214214
// Op entry point: decodes the frames of `stream_index` at the timestamps in
// `frame_ptss` by forwarding to VideoDecoder::getFramesAtPtss, then converts
// the batch result with makeOpsBatchDecodedOutput.
OpsBatchDecodedOutput get_frames_at_ptss(
    at::Tensor& decoder,
    int64_t stream_index,
    at::ArrayRef<double> frame_ptss,
    bool sort_ptss) {
  auto decoderHandle = unwrapTensorToGetDecoder(decoder);

  // getFramesAtPtss takes a std::vector<double>, so copy the ArrayRef
  // contents into one.
  std::vector<double> ptsValues(frame_ptss.begin(), frame_ptss.end());

  auto batchOutput =
      decoderHandle->getFramesAtPtss(stream_index, ptsValues, sort_ptss);
  return makeOpsBatchDecodedOutput(batchOutput);
}
226225

227-
228226
OpsDecodedOutput get_frame_at_index(
229227
at::Tensor& decoder,
230228
int64_t stream_index,

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ OpsDecodedOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
7979
OpsBatchDecodedOutput get_frames_at_ptss(
8080
at::Tensor& decoder,
8181
int64_t stream_index,
82-
at::IntArrayRef frame_ptss,
82+
at::ArrayRef<double> frame_ptss,
8383
bool sort_ptss = false);
8484

8585
// Return the frame that is visible at a given index in the video.

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,13 @@ def get_frame_at_pts_abstract(
172172
torch.empty([], dtype=torch.float),
173173
)
174174

175+
175176
@register_fake("torchcodec_ns::get_frames_at_ptss")
176177
def get_frames_at_pts_abstract(
177178
decoder: torch.Tensor,
178179
*,
179180
stream_index: int,
180-
frame_ptss: List[int],
181+
frame_ptss: List[float],
181182
sort_ptss: bool = False,
182183
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
183184
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]

0 commit comments

Comments
 (0)