Skip to content

Commit 55a8f46

Browse files
committed
super basic Audio POC
1 parent eb88cd6 commit 55a8f46

File tree

2 files changed

+37
-6
lines changed

2 files changed

+37
-6
lines changed

src/torchcodec/_frame.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class Frame(Iterable):
4141
def __post_init__(self):
4242
# This is called after __init__() when a Frame is created. We can run
4343
# input validation checks here.
44-
if not self.data.ndim == 3:
45-
raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
44+
# if not self.data.ndim == 3:
45+
# raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
4646
self.pts_seconds = float(self.pts_seconds)
4747
self.duration_seconds = float(self.duration_seconds)
4848

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
124124
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
125125
int streamIndex,
126126
torch::Tensor& hwcTensor) {
127+
return hwcTensor;
127128
if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") {
128129
return hwcTensor;
129130
}
@@ -439,11 +440,12 @@ void VideoDecoder::addVideoStreamDecoder(
439440
activeStreamIndex_ == NO_ACTIVE_STREAM,
440441
"Can only add one single stream.");
441442
TORCH_CHECK(formatContext_.get() != nullptr);
443+
printf("Adding stream %d\n", preferredStreamIndex);
442444

443445
AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr;
444446
int streamIndex = av_find_best_stream(
445447
formatContext_.get(),
446-
AVMEDIA_TYPE_VIDEO,
448+
AVMEDIA_TYPE_AUDIO,
447449
preferredStreamIndex,
448450
-1,
449451
&avCodec,
@@ -458,7 +460,7 @@ void VideoDecoder::addVideoStreamDecoder(
458460
streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base;
459461
streamInfo.stream = formatContext_->streams[streamIndex];
460462

461-
if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) {
463+
if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_AUDIO) {
462464
throw std::invalid_argument(
463465
"Stream with index " + std::to_string(streamIndex) +
464466
" is not a video stream.");
@@ -915,18 +917,47 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
915917

916918
VideoDecoder::FrameOutput VideoDecoder::convertAVFrameToFrameOutput(
917919
VideoDecoder::AVFrameStream& avFrameStream,
918-
std::optional<torch::Tensor> preAllocatedOutputTensor) {
920+
[[maybe_unused]] std::optional<torch::Tensor> preAllocatedOutputTensor) {
919921
// Convert the frame to tensor.
920922
FrameOutput frameOutput;
921923
int streamIndex = avFrameStream.streamIndex;
922924
AVFrame* avFrame = avFrameStream.avFrame.get();
923925
frameOutput.streamIndex = streamIndex;
924926
auto& streamInfo = streamInfos_[streamIndex];
925-
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
927+
TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO);
926928
frameOutput.ptsSeconds = ptsToSeconds(
927929
avFrame->pts, formatContext_->streams[streamIndex]->time_base);
928930
frameOutput.durationSeconds = ptsToSeconds(
929931
getDuration(avFrame), formatContext_->streams[streamIndex]->time_base);
932+
933+
auto numSamples = avFrame->nb_samples;
934+
auto sampleRate = avFrame->sample_rate;
935+
auto numChannels = avFrame->ch_layout.nb_channels;
936+
937+
printf("numSamples: %d\n", numSamples);
938+
printf("sample rate: %d\n", sampleRate);
939+
940+
printf("numChannels: %d\n", numChannels);
941+
int bytesPerSample =
942+
av_get_bytes_per_sample(streamInfo.codecContext->sample_fmt);
943+
printf("bytes per sample: %d\n", bytesPerSample);
944+
945+
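// NOTE: assumes the sample format is FLTP (planar 32-bit float): each
// avFrame->data[channel] plane is read below as numSamples floats. Other
// sample formats are not handled yet.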
946+
947+
// TODO: this element-by-element copy is slow; use a tensor accessor or
// memcpy each channel plane instead.
948+
torch::Tensor data = torch::empty({numChannels, numSamples}, torch::kFloat32);
949+
for (auto channel = 0; channel < numChannels; ++channel) {
950+
// auto channelDataPtr = data[channel].data_ptr<uint8_t>();
951+
// std::memcpy(channelDataPtr, avFrame->data[channel], numSamples *
952+
// bytesPerSample);
953+
float* dataFloatPtr = (float*)(avFrame->data[channel]);
954+
for (auto sampleIndex = 0; sampleIndex < numSamples; ++sampleIndex) {
955+
data[channel][sampleIndex] = dataFloatPtr[sampleIndex];
956+
}
957+
}
958+
frameOutput.data = data;
959+
return frameOutput;
960+
930961
// TODO: we should fold preAllocatedOutputTensor into AVFrameStream.
931962
if (streamInfo.videoStreamOptions.device.type() == torch::kCPU) {
932963
convertAVFrameToFrameOutputOnCPU(

0 commit comments

Comments
 (0)