Skip to content

Commit ace0bd4

Browse files
committed
Create AddAudioStream
1 parent 0f50aba commit ace0bd4

File tree

8 files changed

+118
-43
lines changed

8 files changed

+118
-43
lines changed

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -418,8 +418,9 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
418418
}
419419
}
420420

421-
void VideoDecoder::addVideoStreamDecoder(
421+
void VideoDecoder::addStream(
422422
int streamIndex,
423+
AVMediaType mediaType,
423424
const VideoStreamOptions& videoStreamOptions) {
424425
TORCH_CHECK(
425426
activeStreamIndex_ == NO_ACTIVE_STREAM,
@@ -429,30 +430,37 @@ void VideoDecoder::addVideoStreamDecoder(
429430
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
430431

431432
activeStreamIndex_ = av_find_best_stream(
432-
formatContext_.get(), AVMEDIA_TYPE_VIDEO, streamIndex, -1, &avCodec, 0);
433+
formatContext_.get(), mediaType, streamIndex, -1, &avCodec, 0);
434+
433435
if (activeStreamIndex_ < 0) {
434-
throw std::invalid_argument("No valid stream found in input file.");
436+
throw std::invalid_argument(
437+
"No valid stream found in input file. Is " +
438+
std::to_string(streamIndex) + " of the desired media type?");
435439
}
440+
436441
TORCH_CHECK(avCodec != nullptr);
437442

438443
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
439444
streamInfo.streamIndex = activeStreamIndex_;
440445
streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base;
441446
streamInfo.stream = formatContext_->streams[activeStreamIndex_];
447+
streamInfo.avMediaType = mediaType;
442448

443-
if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) {
444-
throw std::invalid_argument(
445-
"Stream with index " + std::to_string(activeStreamIndex_) +
446-
" is not a video stream.");
447-
}
449+
// This should never happen, checking just to be safe.
450+
TORCH_CHECK(
451+
streamInfo.stream->codecpar->codec_type == mediaType,
452+
"FFmpeg found stream with index ", activeStreamIndex_, " which is of the wrong media type.");
448453

449-
if (videoStreamOptions.device.type() == torch::kCUDA) {
454+
455+
if (mediaType == AVMEDIA_TYPE_VIDEO &&
456+
videoStreamOptions.device.type() == torch::kCUDA) {
450457
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
451458
findCudaCodec(
452459
videoStreamOptions.device, streamInfo.stream->codecpar->codec_id)
453460
.value_or(avCodec));
454461
}
455462

463+
// TODO figure out whether this should be VIDEO only
456464
StreamMetadata& streamMetadata =
457465
containerMetadata_.allStreamMetadata[activeStreamIndex_];
458466
if (seekMode_ == SeekMode::approximate &&
@@ -465,37 +473,34 @@ void VideoDecoder::addVideoStreamDecoder(
465473

466474
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
467475
TORCH_CHECK(codecContext != nullptr);
468-
codecContext->thread_count = videoStreamOptions.ffmpegThreadCount.value_or(0);
476+
codecContext->thread_count =
477+
videoStreamOptions.ffmpegThreadCount.value_or(0); // TODO VIDEO ONLY?
469478
streamInfo.codecContext.reset(codecContext);
470479

471480
int retVal = avcodec_parameters_to_context(
472481
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
473482
TORCH_CHECK_EQ(retVal, AVSUCCESS);
474483

475-
if (videoStreamOptions.device.type() == torch::kCPU) {
476-
// No more initialization needed for CPU.
477-
} else if (videoStreamOptions.device.type() == torch::kCUDA) {
478-
initializeContextOnCuda(videoStreamOptions.device, codecContext);
479-
} else {
480-
TORCH_CHECK(
481-
false, "Invalid device type: " + videoStreamOptions.device.str());
484+
if (mediaType == AVMEDIA_TYPE_VIDEO) {
485+
if (videoStreamOptions.device.type() == torch::kCPU) {
486+
// No more initialization needed for CPU.
487+
} else if (videoStreamOptions.device.type() == torch::kCUDA) {
488+
initializeContextOnCuda(videoStreamOptions.device, codecContext);
489+
} else {
490+
TORCH_CHECK(
491+
false, "Invalid device type: " + videoStreamOptions.device.str());
492+
}
493+
streamInfo.videoStreamOptions = videoStreamOptions;
482494
}
483-
streamInfo.videoStreamOptions = videoStreamOptions;
484495

485496
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
486497
if (retVal < AVSUCCESS) {
487498
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
488499
}
489500

490501
codecContext->time_base = streamInfo.stream->time_base;
491-
492-
containerMetadata_.allStreamMetadata[activeStreamIndex_].width =
493-
codecContext->width;
494-
containerMetadata_.allStreamMetadata[activeStreamIndex_].height =
495-
codecContext->height;
496-
auto codedId = codecContext->codec_id;
497502
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
498-
std::string(avcodec_get_name(codedId));
503+
std::string(avcodec_get_name(codecContext->codec_id));
499504

500505
// We will only need packets from the active stream, so we tell FFmpeg to
501506
// discard packets from the other streams. Note that av_read_frame() may still
@@ -506,6 +511,18 @@ void VideoDecoder::addVideoStreamDecoder(
506511
formatContext_->streams[i]->discard = AVDISCARD_ALL;
507512
}
508513
}
514+
}
515+
516+
void VideoDecoder::addVideoStream(
517+
int streamIndex,
518+
const VideoStreamOptions& videoStreamOptions) {
519+
addStream(streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions);
520+
521+
auto& streamInfo = streamInfos_[activeStreamIndex_];
522+
containerMetadata_.allStreamMetadata[activeStreamIndex_].width =
523+
streamInfo.codecContext->width;
524+
containerMetadata_.allStreamMetadata[activeStreamIndex_].height =
525+
streamInfo.codecContext->height;
509526

510527
// By default, we want to use swscale for color conversion because it is
511528
// faster. However, it has width requirements, so we may need to fall back
@@ -514,7 +531,7 @@ void VideoDecoder::addVideoStreamDecoder(
514531
// swscale's width requirements to be violated. We don't expose the ability to
515532
// choose color conversion library publicly; we only use this ability
516533
// internally.
517-
int width = videoStreamOptions.width.value_or(codecContext->width);
534+
int width = videoStreamOptions.width.value_or(streamInfo.codecContext->width);
518535

519536
// swscale requires widths to be multiples of 32:
520537
// https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
@@ -527,6 +544,10 @@ void VideoDecoder::addVideoStreamDecoder(
527544
videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary);
528545
}
529546

547+
void VideoDecoder::addAudioStream(int streamIndex) {
548+
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
549+
}
550+
530551
// --------------------------------------------------------------------------
531552
// HIGH-LEVEL DECODING ENTRY-POINTS
532553
// --------------------------------------------------------------------------
@@ -1051,7 +1072,6 @@ VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
10511072
AVFrame* avFrame = avFrameStream.avFrame.get();
10521073
frameOutput.streamIndex = streamIndex;
10531074
auto& streamInfo = streamInfos_[streamIndex];
1054-
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
10551075
frameOutput.ptsSeconds = ptsToSeconds(
10561076
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
10571077
frameOutput.durationSeconds = ptsToSeconds(

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,10 @@ class VideoDecoder {
136136

137137
struct AudioStreamOptions {};
138138

139-
void addVideoStreamDecoder(
139+
void addVideoStream(
140140
int streamIndex,
141141
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
142-
void addAudioStreamDecoder(
143-
int streamIndex,
144-
const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());
142+
void addAudioStream(int streamIndex);
145143

146144
// --------------------------------------------------------------------------
147145
// DECODING AND SEEKING APIs
@@ -322,6 +320,8 @@ class VideoDecoder {
322320
struct StreamInfo {
323321
int streamIndex = -1;
324322
AVStream* stream = nullptr;
323+
AVMediaType avMediaType = AVMEDIA_TYPE_UNKNOWN;
324+
325325
AVRational timeBase = {};
326326
UniqueAVCodecContext codecContext;
327327

@@ -433,6 +433,11 @@ class VideoDecoder {
433433
// STREAM AND METADATA APIS
434434
// --------------------------------------------------------------------------
435435

436+
void addStream(
437+
int streamIndex,
438+
AVMediaType mediaType,
439+
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
440+
436441
// Returns the "best" stream index for a given media type. The "best" is
437442
// determined by various heuristics in FFMPEG.
438443
// See

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
3434
"_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? color_conversion_library=None) -> ()");
3535
m.def(
3636
"add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()");
37+
m.def(
38+
"add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None) -> ()");
3739
m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()");
3840
m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)");
3941
m.def(
@@ -220,8 +222,14 @@ void _add_video_stream(
220222
}
221223

222224
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
223-
videoDecoder->addVideoStreamDecoder(
224-
stream_index.value_or(-1), videoStreamOptions);
225+
videoDecoder->addVideoStream(stream_index.value_or(-1), videoStreamOptions);
226+
}
227+
228+
void add_audio_stream(
229+
at::Tensor& decoder,
230+
std::optional<int64_t> stream_index) {
231+
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
232+
videoDecoder->addAudioStream(stream_index.value_or(-1));
225233
}
226234

227235
void seek_to_pts(at::Tensor& decoder, double seconds) {
@@ -533,6 +541,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
533541
m.impl("seek_to_pts", &seek_to_pts);
534542
m.impl("add_video_stream", &add_video_stream);
535543
m.impl("_add_video_stream", &_add_video_stream);
544+
m.impl("add_audio_stream", &add_audio_stream);
536545
m.impl("get_next_frame", &get_next_frame);
537546
m.impl("_get_key_frame_indices", &_get_key_frame_indices);
538547
m.impl("get_json_metadata", &get_json_metadata);

src/torchcodec/decoders/_core/VideoDecoderOps.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ void _add_video_stream(
5555
std::optional<std::string_view> device = std::nullopt,
5656
std::optional<std::string_view> color_conversion_library = std::nullopt);
5757

58+
void add_audio_stream(
59+
at::Tensor& decoder,
60+
std::optional<int64_t> stream_index = std::nullopt);
61+
5862
// Seek to a particular presentation timestamp in the video in seconds.
5963
void seek_to_pts(at::Tensor& decoder, double seconds);
6064

src/torchcodec/decoders/_core/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
_add_video_stream,
1616
_get_key_frame_indices,
1717
_test_frame_pts_equality,
18+
add_audio_stream,
1819
add_video_stream,
1920
create_from_bytes,
2021
create_from_file,

src/torchcodec/decoders/_core/video_decoder_ops.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def load_torchcodec_extension():
6969
)
7070
add_video_stream = torch.ops.torchcodec_ns.add_video_stream.default
7171
_add_video_stream = torch.ops.torchcodec_ns._add_video_stream.default
72+
add_audio_stream = torch.ops.torchcodec_ns.add_audio_stream.default
7273
seek_to_pts = torch.ops.torchcodec_ns.seek_to_pts.default
7374
get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default
7475
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
@@ -150,6 +151,15 @@ def add_video_stream_abstract(
150151
return
151152

152153

154+
@register_fake("torchcodec_ns::add_audio_stream")
155+
def add_audio_stream_abstract(
156+
decoder: torch.Tensor,
157+
*,
158+
stream_index: Optional[int] = None,
159+
) -> None:
160+
return
161+
162+
153163
@register_fake("torchcodec_ns::seek_to_pts")
154164
def seek_abstract(decoder: torch.Tensor, seconds: float) -> None:
155165
return

test/decoders/VideoDecoderTest.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) {
148148
VideoDecoder::VideoStreamOptions videoStreamOptions;
149149
videoStreamOptions.width = 100;
150150
videoStreamOptions.height = 120;
151-
decoder->addVideoStreamDecoder(-1, videoStreamOptions);
151+
decoder->addVideoStream(-1, videoStreamOptions);
152152
torch::Tensor tensor = decoder->getNextFrame().data;
153153
EXPECT_EQ(tensor.sizes(), std::vector<long>({3, 120, 100}));
154154
}
@@ -158,7 +158,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) {
158158
std::unique_ptr<VideoDecoder> decoder = std::make_unique<VideoDecoder>(path);
159159
VideoDecoder::VideoStreamOptions videoStreamOptions;
160160
videoStreamOptions.dimensionOrder = "NHWC";
161-
decoder->addVideoStreamDecoder(-1, videoStreamOptions);
161+
decoder->addVideoStream(-1, videoStreamOptions);
162162
torch::Tensor tensor = decoder->getNextFrame().data;
163163
EXPECT_EQ(tensor.sizes(), std::vector<long>({270, 480, 3}));
164164
}
@@ -167,7 +167,7 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
167167
std::string path = getResourcePath("nasa_13013.mp4");
168168
std::unique_ptr<VideoDecoder> ourDecoder =
169169
createDecoderFromPath(path, GetParam());
170-
ourDecoder->addVideoStreamDecoder(-1);
170+
ourDecoder->addVideoStream(-1);
171171
auto output = ourDecoder->getNextFrame();
172172
torch::Tensor tensor0FromOurDecoder = output.data;
173173
EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
@@ -206,7 +206,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
206206
ourDecoder->scanFileAndUpdateMetadataAndIndex();
207207
int bestVideoStreamIndex =
208208
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
209-
ourDecoder->addVideoStreamDecoder(bestVideoStreamIndex);
209+
ourDecoder->addVideoStream(bestVideoStreamIndex);
210210
// Frame with index 180 corresponds to timestamp 6.006.
211211
auto output = ourDecoder->getFramesAtIndices({0, 180});
212212
auto tensor = output.data;
@@ -228,7 +228,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
228228
ourDecoder->scanFileAndUpdateMetadataAndIndex();
229229
int bestVideoStreamIndex =
230230
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
231-
ourDecoder->addVideoStreamDecoder(
231+
ourDecoder->addVideoStream(
232232
bestVideoStreamIndex,
233233
VideoDecoder::VideoStreamOptions("dimension_order=NHWC"));
234234
// Frame with index 180 corresponds to timestamp 6.006.
@@ -250,7 +250,7 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) {
250250
std::string path = getResourcePath("nasa_13013.mp4");
251251
std::unique_ptr<VideoDecoder> ourDecoder =
252252
createDecoderFromPath(path, GetParam());
253-
ourDecoder->addVideoStreamDecoder(-1);
253+
ourDecoder->addVideoStream(-1);
254254
ourDecoder->setCursorPtsInSeconds(388388. / 30'000);
255255
auto output = ourDecoder->getNextFrame();
256256
EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000);
@@ -263,7 +263,7 @@ TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) {
263263
std::string path = getResourcePath("nasa_13013.mp4");
264264
std::unique_ptr<VideoDecoder> ourDecoder =
265265
createDecoderFromPath(path, GetParam());
266-
ourDecoder->addVideoStreamDecoder(-1);
266+
ourDecoder->addVideoStream(-1);
267267
auto output = ourDecoder->getFramePlayedAt(6.006);
268268
EXPECT_EQ(output.ptsSeconds, 6.006);
269269
// The frame's duration is 0.033367 according to ffprobe,
@@ -293,7 +293,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
293293
std::string path = getResourcePath("nasa_13013.mp4");
294294
std::unique_ptr<VideoDecoder> ourDecoder =
295295
createDecoderFromPath(path, GetParam());
296-
ourDecoder->addVideoStreamDecoder(-1);
296+
ourDecoder->addVideoStream(-1);
297297
ourDecoder->setCursorPtsInSeconds(6.0);
298298
auto output = ourDecoder->getNextFrame();
299299
torch::Tensor tensor6FromOurDecoder = output.data;
@@ -393,7 +393,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) {
393393
ourDecoder->scanFileAndUpdateMetadataAndIndex();
394394
int bestVideoStreamIndex =
395395
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
396-
ourDecoder->addVideoStreamDecoder(
396+
ourDecoder->addVideoStream(
397397
bestVideoStreamIndex,
398398
VideoDecoder::VideoStreamOptions("color_conversion_library=filtergraph"));
399399
auto output =
@@ -410,7 +410,7 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) {
410410
ourDecoder->scanFileAndUpdateMetadataAndIndex();
411411
int bestVideoStreamIndex =
412412
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
413-
ourDecoder->addVideoStreamDecoder(
413+
ourDecoder->addVideoStream(
414414
bestVideoStreamIndex,
415415
VideoDecoder::VideoStreamOptions("color_conversion_library=swscale"));
416416
auto output =

test/decoders/test_video_decoder_ops.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from torchcodec.decoders._core import (
2020
_add_video_stream,
2121
_test_frame_pts_equality,
22+
add_audio_stream,
2223
add_video_stream,
2324
create_from_bytes,
2425
create_from_file,
@@ -63,6 +64,31 @@ def seek(self, pts: float):
6364

6465

6566
class TestOps:
67+
def test_add_stream(self):
68+
valid_video_stream, valid_audio_stream = 0, 1
69+
70+
decoder = create_from_file(str(NASA_VIDEO.path))
71+
add_video_stream(decoder, stream_index=valid_video_stream)
72+
with pytest.raises(RuntimeError, match="Can only add one single stream"):
73+
add_video_stream(decoder, stream_index=valid_video_stream)
74+
75+
decoder = create_from_file(str(NASA_VIDEO.path))
76+
add_audio_stream(decoder, stream_index=valid_audio_stream)
77+
with pytest.raises(RuntimeError, match="Can only add one single stream"):
78+
add_audio_stream(decoder, stream_index=valid_audio_stream)
79+
80+
decoder = create_from_file(str(NASA_VIDEO.path))
81+
with pytest.raises(
82+
ValueError, match=f"Is {valid_audio_stream} of the desired media type"
83+
):
84+
add_video_stream(decoder, stream_index=valid_audio_stream)
85+
86+
decoder = create_from_file(str(NASA_VIDEO.path))
87+
with pytest.raises(
88+
ValueError, match=f"Is {valid_video_stream} of the desired media type"
89+
):
90+
add_audio_stream(decoder, stream_index=valid_video_stream)
91+
6692
@pytest.mark.parametrize("device", cpu_and_cuda())
6793
def test_seek_and_next(self, device):
6894
decoder = create_from_file(str(NASA_VIDEO.path))

0 commit comments

Comments
 (0)