Skip to content

Commit d6a2cfa

Browse files
committed
Merge branch 'nvdec-rework-frame-ordering' into nvdec-h265
2 parents 7176788 + 70873bf commit d6a2cfa

13 files changed

+81
-40
lines changed

src/torchcodec/_core/BetaCudaDeviceInterface.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,17 @@ static UniqueCUvideodecoder createDecoder(CUVIDEOFORMAT* videoFormat) {
9797
"x",
9898
caps.nMaxHeight);
9999

100+
// See nMaxMBCount in cuviddec.h
101+
constexpr unsigned int macroblockConstant = 256;
100102
TORCH_CHECK(
101-
videoFormat->coded_width * videoFormat->coded_height / 256 <=
103+
videoFormat->coded_width * videoFormat->coded_height /
104+
macroblockConstant <=
102105
caps.nMaxMBCount,
103106
"Video is too large (too many macroblocks). "
104-
"Provided (width * height / 256): ",
105-
videoFormat->coded_width * videoFormat->coded_height / 256,
107+
"Provided (width * height / ",
108+
macroblockConstant,
109+
"): ",
110+
videoFormat->coded_width * videoFormat->coded_height / macroblockConstant,
106111
" vs supported:",
107112
caps.nMaxMBCount);
108113

@@ -376,8 +381,10 @@ void BetaCudaDeviceInterface::applyBSF(ReferenceAVPacket& packet) {
376381
// fields of the filtered packet into the original packet. The filtered packet
377382
// fields are re-set by av_packet_move_ref, so when it goes out of scope and
378383
// gets destructed, it's not going to affect the original packet.
379-
av_packet_unref(packet.get());
380-
av_packet_move_ref(packet.get(), filteredPacket.get());
384+
packet.reset(filteredPacket);
385+
// TODONVDEC P0: consider cleaner ways to do this. Maybe we should let
386+
// applyBSF return a new packet, and maybe that new packet needs to be a field
387+
// on the interface to avoid complex lifetime issues.
381388
}
382389

383390
// Parser triggers this callback within cuvidParseVideoData when a frame is
@@ -477,6 +484,9 @@ UniqueAVFrame BetaCudaDeviceInterface::convertCudaFrameToAVFrame(
477484
avFrame->format = AV_PIX_FMT_CUDA;
478485
avFrame->pts = dispInfo.timestamp;
479486

487+
// TODONVDEC P0: Zero division error!!!
488+
// TODONVDEC P0: Move AVRational arithmetic to FFMPEGCommon, and put the
489+
// similar SingleStreamDecoder stuff there too.
480490
unsigned int frameRateNum = videoFormat_.frame_rate.numerator;
481491
unsigned int frameRateDen = videoFormat_.frame_rate.denominator;
482492
int64_t duration = static_cast<int64_t>((frameRateDen * timeBase_.den)) /

src/torchcodec/_core/FFMPEGCommon.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ AVPacket* ReferenceAVPacket::operator->() {
3333
return avPacket_;
3434
}
3535

36+
void ReferenceAVPacket::reset(ReferenceAVPacket& other) {
37+
if (this != &other) {
38+
av_packet_unref(avPacket_);
39+
av_packet_move_ref(avPacket_, other.avPacket_);
40+
}
41+
}
42+
3643
AVCodecOnlyUseForCallingAVFindBestStream
3744
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
3845
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)

src/torchcodec/_core/FFMPEGCommon.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ class ReferenceAVPacket {
135135
~ReferenceAVPacket();
136136
AVPacket* get();
137137
AVPacket* operator->();
138+
void reset(ReferenceAVPacket& other);
138139
};
139140

140141
// av_find_best_stream is not const-correct before commit:

src/torchcodec/_core/NVDECCache.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ class NVDECCache {
4949
// all these parameters match.
5050
struct CacheKey {
5151
cudaVideoCodec codecType;
52-
unsigned int width;
53-
unsigned int height;
52+
uint32_t width;
53+
uint32_t height;
5454
cudaVideoChromaFormat chromaFormat;
55-
unsigned int bitDepthLumaMinus8;
56-
unsigned char numDecodeSurfaces;
55+
uint32_t bitDepthLumaMinus8;
56+
uint8_t numDecodeSurfaces;
5757

5858
CacheKey() = delete;
5959

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -606,25 +606,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
606606
}
607607

608608
FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
609-
const std::vector<int64_t>& frameIndices) {
609+
const torch::Tensor& frameIndices) {
610610
validateActiveStream(AVMEDIA_TYPE_VIDEO);
611611

612-
auto indicesAreSorted =
613-
std::is_sorted(frameIndices.begin(), frameIndices.end());
612+
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();
613+
614+
bool indicesAreSorted = true;
615+
for (int64_t i = 1; i < frameIndices.numel(); ++i) {
616+
if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
617+
indicesAreSorted = false;
618+
break;
619+
}
620+
}
614621

615622
std::vector<size_t> argsort;
616623
if (!indicesAreSorted) {
617624
// if frameIndices is [13, 10, 12, 11]
618625
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
619626
// to use to decode the frames
620627
// and argsort is [ 1, 3, 2, 0]
621-
argsort.resize(frameIndices.size());
628+
argsort.resize(frameIndices.numel());
622629
for (size_t i = 0; i < argsort.size(); ++i) {
623630
argsort[i] = i;
624631
}
625632
std::sort(
626-
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
627-
return frameIndices[a] < frameIndices[b];
633+
argsort.begin(),
634+
argsort.end(),
635+
[&frameIndicesAccessor](size_t a, size_t b) {
636+
return frameIndicesAccessor[a] < frameIndicesAccessor[b];
628637
});
629638
}
630639

@@ -633,12 +642,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
633642
const auto& streamInfo = streamInfos_[activeStreamIndex_];
634643
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
635644
FrameBatchOutput frameBatchOutput(
636-
frameIndices.size(), videoStreamOptions, streamMetadata);
645+
frameIndices.numel(), videoStreamOptions, streamMetadata);
637646

638647
auto previousIndexInVideo = -1;
639-
for (size_t f = 0; f < frameIndices.size(); ++f) {
648+
for (int64_t f = 0; f < frameIndices.numel(); ++f) {
640649
auto indexInOutput = indicesAreSorted ? f : argsort[f];
641-
auto indexInVideo = frameIndices[indexInOutput];
650+
auto indexInVideo = frameIndicesAccessor[indexInOutput];
642651

643652
if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
644653
// Avoid decoding the same frame twice
@@ -780,7 +789,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
780789
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
781790
}
782791

783-
return getFramesAtIndices(frameIndices);
792+
// TODO: Support tensors natively instead of a vector to avoid a copy.
793+
return getFramesAtIndices(torch::tensor(frameIndices));
784794
}
785795

786796
FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
@@ -1202,6 +1212,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
12021212
if (status == AVERROR_EOF) {
12031213
// End of file reached. We must drain the decoder
12041214
if (useCustomInterface) {
1215+
// TODONVDEC P0: Re-think this. This should be simpler.
12051216
AutoAVPacket eofAutoPacket;
12061217
ReferenceAVPacket eofPacket(eofAutoPacket);
12071218
eofPacket->data = nullptr;

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ class SingleStreamDecoder {
106106

107107
// Returns frames at the given indices for a given stream as a single stacked
108108
// Tensor.
109-
FrameBatchOutput getFramesAtIndices(const std::vector<int64_t>& frameIndices);
109+
FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices);
110110

111111
// Returns frames within a given range. The range is defined by [start, stop).
112112
// The values retrieved from the range are: [start, start+step,

src/torchcodec/_core/custom_ops.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
5555
m.def(
5656
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
5757
m.def(
58-
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
58+
"get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)");
5959
m.def(
6060
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
6161
m.def(
@@ -384,11 +384,9 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
384384
// Return the frames at given indices for a given stream
385385
OpsFrameBatchOutput get_frames_at_indices(
386386
at::Tensor& decoder,
387-
at::IntArrayRef frame_indices) {
387+
const at::Tensor& frame_indices) {
388388
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
389-
std::vector<int64_t> frameIndicesVec(
390-
frame_indices.begin(), frame_indices.end());
391-
auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
389+
auto result = videoDecoder->getFramesAtIndices(frame_indices);
392390
return makeOpsFrameBatchOutput(result);
393391
}
394392

src/torchcodec/_core/ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,9 @@ def load_torchcodec_shared_libraries():
114114
get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default
115115
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
116116
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
117-
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
117+
_get_frames_at_indices_tensor_input = (
118+
torch.ops.torchcodec_ns.get_frames_at_indices.default
119+
)
118120
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
119121
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
120122
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
@@ -198,6 +200,18 @@ def encode_audio_to_file_like(
198200
)
199201

200202

203+
def get_frames_at_indices(
204+
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
205+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
206+
if isinstance(frame_indices, torch.Tensor):
207+
# Ensure indices is the correct dtype (int64)
208+
frame_indices = frame_indices.to(torch.int64)
209+
else:
210+
# Convert list to tensor for dispatch
211+
frame_indices = torch.tensor(frame_indices)
212+
return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)
213+
214+
201215
# ==============================
202216
# Abstract impl for the operators. Needed by torch.compile.
203217
# ==============================
@@ -373,9 +387,7 @@ def get_frame_at_index_abstract(
373387

374388
@register_fake("torchcodec_ns::get_frames_at_indices")
375389
def get_frames_at_indices_abstract(
376-
decoder: torch.Tensor,
377-
*,
378-
frame_indices: List[int],
390+
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, List[int]]
379391
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
380392
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
381393
return (

src/torchcodec/_samplers/video_clip_sampler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@ def _get_clips_for_index_based_sampling(
227227
clip_start_idx + i * index_based_sampler_args.video_frame_dilation
228228
for i in range(index_based_sampler_args.frames_per_clip)
229229
]
230+
# Need torch.stack to convert List[Tensor[int]] into 1D Tensor[int]
231+
batch_indexes = torch.stack(batch_indexes)
230232
frames, *_ = get_frames_at_indices(
231233
video_decoder,
232234
frame_indices=batch_indexes,

src/torchcodec/decoders/_video_decoder.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,24 +253,20 @@ def get_frame_at(self, index: int) -> Frame:
253253
duration_seconds=duration_seconds.item(),
254254
)
255255

256-
def get_frames_at(self, indices: list[int]) -> FrameBatch:
256+
def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
257257
"""Return frames at the given indices.
258258
259259
Args:
260-
indices (list of int): The indices of the frames to retrieve.
260+
indices (torch.Tensor or list of int): The indices of the frames to retrieve.
261261
262262
Returns:
263263
FrameBatch: The frames at the given indices.
264264
"""
265-
if isinstance(indices, torch.Tensor):
266-
# TODO we should avoid converting tensors to lists and just let the
267-
# core ops and C++ code natively accept tensors. See
268-
# https://github.com/pytorch/torchcodec/issues/879
269-
indices = indices.to(torch.int).tolist()
270265

271266
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
272267
self._decoder, frame_indices=indices
273268
)
269+
274270
return FrameBatch(
275271
data=data,
276272
pts_seconds=pts_seconds,

0 commit comments

Comments
 (0)