Skip to content

Commit 9349163

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into approx
2 parents 99b0d4f + 547dd36 commit 9349163

File tree

6 files changed

+34
-16
lines changed

6 files changed

+34
-16
lines changed

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void releaseContextOnCuda(
3535
throwUnsupportedDeviceError(device);
3636
}
3737

38-
std::optional<AVCodecPtr> findCudaCodec(
38+
std::optional<const AVCodec*> findCudaCodec(
3939
const torch::Device& device,
4040
[[maybe_unused]] const AVCodecID& codecId) {
4141
throwUnsupportedDeviceError(device);

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void releaseContextOnCuda(
4040
const torch::Device& device,
4141
AVCodecContext* codecContext);
4242

43-
std::optional<AVCodecPtr> findCudaCodec(
43+
std::optional<const AVCodec*> findCudaCodec(
4444
const torch::Device& device,
4545
const AVCodecID& codecId);
4646

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@
1010

1111
namespace facebook::torchcodec {
1212

13+
// Converts a `const AVCodec*` into the AVCodecOnlyUseForCallingAVFindBestStream
// alias. The pointer value is returned unchanged; only its constness is
// adjusted, depending on the libavcodec version being compiled against.
AVCodecOnlyUseForCallingAVFindBestStream
14+
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) {
15+
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
16+
// With libavcodec older than 59.18.100 the alias is a non-const `AVCodec*`,
// so constness has to be cast away to produce the alias type.
return const_cast<AVCodec*>(codec);
17+
#else
18+
// From 59.18.100 onwards the alias is already `const AVCodec*`; no cast is
// needed.
return codec;
19+
#endif
20+
}
21+
1322
std::string getFFMPEGErrorStringFromErrorCode(int errorCode) {
1423
char errorBuffer[AV_ERROR_MAX_STRING_SIZE] = {0};
1524
av_strerror(errorCode, errorBuffer, AV_ERROR_MAX_STRING_SIZE);

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,12 +75,17 @@ using UniqueSwsContext =
7575
// which was released in FFMPEG version=5.0.3
7676
// with libavcodec's version=59.18.100
7777
// (https://www.ffmpeg.org/olddownload.html).
78+
// Note that the alias is so-named so that it is only used when interacting with
79+
// av_find_best_stream(). It is not needed elsewhere.
7880
#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100)
79-
using AVCodecPtr = AVCodec*;
81+
using AVCodecOnlyUseForCallingAVFindBestStream = AVCodec*;
8082
#else
81-
using AVCodecPtr = const AVCodec*;
83+
using AVCodecOnlyUseForCallingAVFindBestStream = const AVCodec*;
8284
#endif
8385

86+
AVCodecOnlyUseForCallingAVFindBestStream
87+
makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec);
88+
8489
// Success code from FFMPEG is just a 0. We define it to make the code more
8590
// readable.
8691
const int AVSUCCESS = 0;

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -436,7 +436,7 @@ void VideoDecoder::createFilterGraph(
436436
}
437437

438438
int VideoDecoder::getBestStreamIndex(AVMediaType mediaType) {
439-
AVCodecPtr codec = nullptr;
439+
AVCodecOnlyUseForCallingAVFindBestStream codec = nullptr;
440440
int streamNumber =
441441
av_find_best_stream(formatContext_.get(), mediaType, -1, -1, &codec, 0);
442442
return streamNumber;
@@ -452,7 +452,7 @@ void VideoDecoder::addVideoStreamDecoder(
452452
}
453453
TORCH_CHECK(formatContext_.get() != nullptr);
454454

455-
AVCodecPtr codec = nullptr;
455+
AVCodecOnlyUseForCallingAVFindBestStream codec = nullptr;
456456
int streamNumber = av_find_best_stream(
457457
formatContext_.get(),
458458
AVMEDIA_TYPE_VIDEO,
@@ -465,14 +465,6 @@ void VideoDecoder::addVideoStreamDecoder(
465465
}
466466
TORCH_CHECK(codec != nullptr);
467467

468-
StreamMetadata& streamMetadata = containerMetadata_.streams[streamNumber];
469-
if (seekMode_ == SeekMode::approximate &&
470-
!streamMetadata.averageFps.has_value()) {
471-
throw std::runtime_error(
472-
"Seek mode is approximate, but stream " + std::to_string(streamNumber) +
473-
" does not have an average fps in its metadata.");
474-
}
475-
476468
StreamInfo& streamInfo = streams_[streamNumber];
477469
streamInfo.streamIndex = streamNumber;
478470
streamInfo.timeBase = formatContext_->streams[streamNumber]->time_base;
@@ -485,8 +477,17 @@ void VideoDecoder::addVideoStreamDecoder(
485477
}
486478

487479
if (options.device.type() == torch::kCUDA) {
488-
codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id)
489-
.value_or(codec);
480+
codec = makeAVCodecOnlyUseForCallingAVFindBestStream(
481+
findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id)
482+
.value_or(codec));
483+
}
484+
485+
StreamMetadata& streamMetadata = containerMetadata_.streams[streamNumber];
486+
if (seekMode_ == SeekMode::approximate &&
487+
!streamMetadata.averageFps.has_value()) {
488+
throw std::runtime_error(
489+
"Seek mode is approximate, but stream " + std::to_string(streamNumber) +
490+
" does not have an average fps in its metadata.");
490491
}
491492

492493
AVCodecContext* codecContext = avcodec_alloc_context3(codec);

test/decoders/test_video_decoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,9 @@ def test_get_frames_at_fails(self, device, seek_mode):
428428

429429
@pytest.mark.parametrize("device", cpu_and_cuda())
430430
def test_get_frame_at_av1(self, device):
431+
if in_fbcode() and device == "cuda":
432+
return
433+
431434
decoder = VideoDecoder(AV1_VIDEO.path, device=device)
432435
ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10)
433436
ref_frame_info10 = AV1_VIDEO.get_frame_info(10)

0 commit comments

Comments
 (0)