Skip to content

Commit 466ceb4

Browse files
committed
MOOOOORE
1 parent d3bdfea commit 466ceb4

File tree

9 files changed

+41
-57
lines changed

9 files changed

+41
-57
lines changed

src/torchcodec/decoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
from ._core import AudioStreamMetadata, VideoStreamMetadata
7+
from ._core import VideoStreamMetadata
88
from ._video_decoder import VideoDecoder # noqa
99

1010
SimpleVideoDecoder = VideoDecoder

src/torchcodec/decoders/_audio_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818

1919
class AudioDecoder:
20-
"""TODO-audio docs"""
20+
"""TODO-AUDIO docs"""
2121

2222
def __init__(
2323
self,

src/torchcodec/decoders/_core/FFMPEGCommon.cpp

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -60,22 +60,15 @@ int64_t getDuration(const AVFrame* frame) {
6060
#endif
6161
}
6262

63-
int getNumChannels(const AVFrame* avFrame) {
63+
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext) {
6464
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
6565
(IBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
66-
return avFrame->ch_layout.nb_channels;
66+
int numChannels = avCodecContext->ch_layout.nb_channels;
6767
#else
68-
return av_get_channel_layout_nb_channels(avFrame->channel_layout);
68+
int numChannels = avCodecContext->channels;
6969
#endif
70-
}
7170

72-
int getNumChannels(const UniqueAVCodecContext& avCodecContext) {
73-
#if LIBAVFILTER_VERSION_MAJOR > 8 || \
74-
(IBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
75-
return avCodecContext->ch_layout.nb_channels;
76-
#else
77-
return avCodecContext->channels;
78-
#endif
71+
return static_cast<int64_t>(numChannels);
7972
}
8073

8174
AVIOBytesContext::AVIOBytesContext(

src/torchcodec/decoders/_core/FFMPEGCommon.h

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -139,8 +139,7 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode);
139139
int64_t getDuration(const UniqueAVFrame& frame);
140140
int64_t getDuration(const AVFrame* frame);
141141

142-
int getNumChannels(const AVFrame* avFrame);
143-
int getNumChannels(const UniqueAVCodecContext& avCodecContext);
142+
int64_t getNumChannels(const UniqueAVCodecContext& avCodecContext);
144143

145144
// Returns true if sws_scale can handle unaligned data.
146145
bool canSwsScaleHandleUnalignedData();

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 22 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -169,17 +169,6 @@ void VideoDecoder::initializeDecoder() {
169169
}
170170
containerMetadata_.numVideoStreams++;
171171
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
172-
int numSamplesPerFrame = avStream->codecpar->frame_size;
173-
int sampleRate = avStream->codecpar->sample_rate;
174-
if (numSamplesPerFrame > 0 && sampleRate > 0) {
175-
// This should allow the approximate mode to do its magic.
176-
// fps is numFrames / duration where
177-
// - duration = numSamplesTotal / sampleRate and
178-
// - numSamplesTotal = numSamplesPerFrame * numFrames
179-
// so fps = numFrames * sampleRate / (numSamplesPerFrame * numFrames)
180-
streamMetadata.averageFps =
181-
static_cast<double>(sampleRate) / numSamplesPerFrame;
182-
}
183172
containerMetadata_.numAudioStreams++;
184173
}
185174

@@ -422,7 +411,7 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions(
422411
void VideoDecoder::addStream(
423412
int streamIndex,
424413
AVMediaType mediaType,
425-
const VideoStreamOptions& videoStreamOptions) {
414+
const torch::Device& device) {
426415
TORCH_CHECK(
427416
activeStreamIndex_ == NO_ACTIVE_STREAM,
428417
"Can only add one single stream.");
@@ -457,36 +446,25 @@ void VideoDecoder::addStream(
457446
activeStreamIndex_,
458447
" which is of the wrong media type.");
459448

460-
// TODO_CODE_QUALITY this is meh to have that in the middle
461-
if (mediaType == AVMEDIA_TYPE_VIDEO &&
462-
videoStreamOptions.device.type() == torch::kCUDA) {
449+
// TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
450+
// addStream() which is supposed to be generic
451+
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
463452
avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
464-
findCudaCodec(
465-
videoStreamOptions.device, streamInfo.stream->codecpar->codec_id)
453+
findCudaCodec(device, streamInfo.stream->codecpar->codec_id)
466454
.value_or(avCodec));
467455
}
468456

469457
AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
470458
TORCH_CHECK(codecContext != nullptr);
471-
codecContext->thread_count =
472-
videoStreamOptions.ffmpegThreadCount.value_or(0); // TODO VIDEO ONLY?
473459
streamInfo.codecContext.reset(codecContext);
474460

475461
int retVal = avcodec_parameters_to_context(
476462
streamInfo.codecContext.get(), streamInfo.stream->codecpar);
477463
TORCH_CHECK_EQ(retVal, AVSUCCESS);
478464

479-
// TODO_CODE_QUALITY meh again
480-
if (mediaType == AVMEDIA_TYPE_VIDEO) {
481-
if (videoStreamOptions.device.type() == torch::kCPU) {
482-
// No more initialization needed for CPU.
483-
} else if (videoStreamOptions.device.type() == torch::kCUDA) {
484-
initializeContextOnCuda(videoStreamOptions.device, codecContext);
485-
} else {
486-
TORCH_CHECK(
487-
false, "Invalid device type: " + videoStreamOptions.device.str());
488-
}
489-
streamInfo.videoStreamOptions = videoStreamOptions;
465+
// TODO_CODE_QUALITY same as above.
466+
if (mediaType == AVMEDIA_TYPE_VIDEO && device.type() == torch::kCUDA) {
467+
initializeContextOnCuda(device, codecContext);
490468
}
491469

492470
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
@@ -512,9 +490,16 @@ void VideoDecoder::addStream(
512490
void VideoDecoder::addVideoStream(
513491
int streamIndex,
514492
const VideoStreamOptions& videoStreamOptions) {
515-
addStream(streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions);
493+
TORCH_CHECK(
494+
videoStreamOptions.device.type() == torch::kCPU ||
495+
videoStreamOptions.device.type() == torch::kCUDA,
496+
"Invalid device type: " + videoStreamOptions.device.str());
497+
addStream(streamIndex, AVMEDIA_TYPE_VIDEO, videoStreamOptions.device);
516498

517499
auto& streamInfo = streamInfos_[activeStreamIndex_];
500+
streamInfo.codecContext->thread_count =
501+
videoStreamOptions.ffmpegThreadCount.value_or(0);
502+
518503
containerMetadata_.allStreamMetadata[activeStreamIndex_].width =
519504
streamInfo.codecContext->width;
520505
containerMetadata_.allStreamMetadata[activeStreamIndex_].height =
@@ -547,8 +532,12 @@ void VideoDecoder::addAudioStream(int streamIndex) {
547532

548533
addStream(streamIndex, AVMEDIA_TYPE_AUDIO);
549534

550-
containerMetadata_.allStreamMetadata[activeStreamIndex_].sampleRate =
551-
streamInfo.codecContext->sample_rate;
535+
auto& streamInfo = streamInfos_[activeStreamIndex_];
536+
auto& streamMetadata =
537+
containerMetadata_.allStreamMetadata[activeStreamIndex_];
538+
streamMetadata.sampleRate =
539+
static_cast<int64_t>(streamInfo.codecContext->sample_rate);
540+
streamMetadata.numChannels = getNumChannels(streamInfo.codecContext);
552541
}
553542

554543
// --------------------------------------------------------------------------

src/torchcodec/decoders/_core/VideoDecoder.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,6 +80,7 @@ class VideoDecoder {
8080

8181
// Audio-only fields
8282
std::optional<int64_t> sampleRate;
83+
std::optional<int64_t> numChannels;
8384
};
8485

8586
struct ContainerMetadata {
@@ -428,7 +429,7 @@ class VideoDecoder {
428429
void addStream(
429430
int streamIndex,
430431
AVMediaType mediaType,
431-
const VideoStreamOptions& videoStreamOptions = VideoStreamOptions());
432+
const torch::Device& device = torch::kCPU);
432433

433434
// Returns the "best" stream index for a given media type. The "best" is
434435
// determined by various heuristics in FFMPEG.

src/torchcodec/decoders/_core/VideoDecoderOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -488,6 +488,9 @@ std::string get_stream_json_metadata(
488488
if (streamMetadata.sampleRate.has_value()) {
489489
map["sampleRate"] = std::to_string(*streamMetadata.sampleRate);
490490
}
491+
if (streamMetadata.numChannels.has_value()) {
492+
map["numChannels"] = std::to_string(*streamMetadata.numChannels);
493+
}
491494
if (streamMetadata.mediaType == AVMEDIA_TYPE_VIDEO) {
492495
map["mediaType"] = "\"video\"";
493496
} else if (streamMetadata.mediaType == AVMEDIA_TYPE_AUDIO) {

src/torchcodec/decoders/_core/_metadata.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,7 @@
2222
SPACES = " "
2323

2424

25-
# TODO-audio: docs below are mostly for video streams, we should edit them and /
25+
# TODO-AUDIO: docs below are mostly for video streams, we should edit them and /
2626
# or make sure they're OK for audio streams as well. Not sure how to best handle
2727
# docs for such class hierarchy.
2828
@dataclass
@@ -161,8 +161,9 @@ def __repr__(self):
161161
class AudioStreamMetadata(StreamMetadata):
162162
"""Metadata of a single audio stream."""
163163

164-
# TODO-AUDIO Need sample rate and format and num_channels
164+
# TODO-AUDIO Add sample format field
165165
sample_rate: Optional[int]
166+
num_channels: Optional[int]
166167

167168
def __repr__(self):
168169
return super().__repr__()
@@ -236,6 +237,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
236237
streams_metadata.append(
237238
AudioStreamMetadata(
238239
sample_rate=stream_dict.get("sampleRate"),
240+
num_channels=stream_dict.get("numChannels"),
239241
**common_meta,
240242
)
241243
)

test/decoders/test_video_decoder.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -11,13 +11,9 @@
1111
import torch
1212
from torchcodec import FrameBatch
1313

14-
from torchcodec.decoders import (
15-
_core,
16-
AudioStreamMetadata,
17-
VideoDecoder,
18-
VideoStreamMetadata,
19-
)
14+
from torchcodec.decoders import _core, VideoDecoder, VideoStreamMetadata
2015
from torchcodec.decoders._audio_decoder import AudioDecoder
16+
from torchcodec.decoders._core._metadata import AudioStreamMetadata
2117

2218
from ..utils import (
2319
assert_frames_equal,
@@ -950,3 +946,4 @@ def test_metadata(self):
950946
assert decoder.stream_index == decoder.metadata.stream_index == 4
951947
assert decoder.metadata.duration_seconds == pytest.approx(13.056)
952948
assert decoder.metadata.sample_rate == 16_000
949+
assert decoder.metadata.num_channels == 2

0 commit comments

Comments
 (0)