Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/torchcodec/_core/BetaCudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,7 @@ UniqueAVFrame BetaCudaDeviceInterface::transferCpuFrameToGpuNV12(
void BetaCudaDeviceInterface::convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
[[maybe_unused]] AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
UniqueAVFrame gpuFrame =
cpuFallback_ ? transferCpuFrameToGpuNV12(avFrame) : std::move(avFrame);
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/BetaCudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class BetaCudaDeviceInterface : public DeviceInterface {
void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

Expand Down
141 changes: 138 additions & 3 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,12 @@ void CpuDeviceInterface::initializeVideo(
initialized_ = true;
}

// Initializes this interface for audio decoding: stores the user-requested
// audio options (e.g. target sample rate / channel count, read later during
// sample conversion) and marks the interface as initialized.
void CpuDeviceInterface::initializeAudio(
    const AudioStreamOptions& audioStreamOptions) {
  audioStreamOptions_ = audioStreamOptions;
  initialized_ = true;
}

ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
const FrameDims& outputDims) const {
// swscale requires widths to be multiples of 32:
Expand Down Expand Up @@ -114,6 +120,21 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
}
}

// Entry point for CPU frame conversion: dispatches the decoded AVFrame to the
// audio or video conversion path depending on the media type.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
    UniqueAVFrame& avFrame,
    FrameOutput& frameOutput,
    AVMediaType mediaType,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");

  // Anything that is not audio goes down the video path, matching the
  // original if/else structure.
  if (mediaType != AVMEDIA_TYPE_AUDIO) {
    convertVideoAVFrameToFrameOutput(
        avFrame, frameOutput, preAllocatedOutputTensor);
    return;
  }
  convertAudioAVFrameToFrameOutput(avFrame, frameOutput);
}

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor, where the output.data tensor will
// be stored. This parameter is honored in any case, but it only leads to a
Expand All @@ -123,12 +144,10 @@ ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
// TODO: Figure out whether that's possible!
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
void CpuDeviceInterface::convertAVFrameToFrameOutput(
void CpuDeviceInterface::convertVideoAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");

// Note that we ignore the dimensions from the metadata; we don't even bother
// storing them. The resized dimensions take priority. If we don't have any,
// then we use the dimensions from the actual decoded frame. We use the actual
Expand Down Expand Up @@ -278,6 +297,122 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
return rgbAVFrameToTensor(filterGraph_->convert(avFrame));
}

// Converts a decoded audio AVFrame into frameOutput.data, a float32 tensor of
// shape [numChannels, numSamples]. If the frame's sample format, sample rate,
// or channel count differ from the requested output, the samples are first
// converted through libswresample (swrContext_ is created lazily from the
// first frame that needs conversion).
void CpuDeviceInterface::convertAudioAVFrameToFrameOutput(
    UniqueAVFrame& srcAVFrame,
    FrameOutput& frameOutput) {
  AVSampleFormat srcSampleFormat =
      static_cast<AVSampleFormat>(srcAVFrame->format);
  // Output is always planar float: each channel is contiguous, which maps
  // directly onto one row of the [numChannels, numSamples] output tensor.
  AVSampleFormat outSampleFormat = AV_SAMPLE_FMT_FLTP;

  int srcSampleRate = srcAVFrame->sample_rate;
  int outSampleRate = audioStreamOptions_.sampleRate.value_or(srcSampleRate);

  int srcNumChannels = getNumChannels(codecContext_);
  TORCH_CHECK(
      srcNumChannels == getNumChannels(srcAVFrame),
      "The frame has ",
      getNumChannels(srcAVFrame),
      " channels, expected ",
      srcNumChannels,
      ". If you are hitting this, it may be because you are using "
      "a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
      "valid scenarios. Try to upgrade FFmpeg?");
  int outNumChannels = audioStreamOptions_.numChannels.value_or(srcNumChannels);

  // Resampling is only needed if any of format, rate, or channel count
  // differ between source and desired output.
  bool mustConvert =
      (srcSampleFormat != outSampleFormat || srcSampleRate != outSampleRate ||
       srcNumChannels != outNumChannels);

  UniqueAVFrame convertedAVFrame;
  if (mustConvert) {
    if (!swrContext_) {
      // Lazily create the resampling context from this frame's parameters;
      // it is reused for subsequent frames.
      swrContext_.reset(createSwrContext(
          srcSampleFormat,
          outSampleFormat,
          srcSampleRate,
          outSampleRate,
          srcAVFrame,
          outNumChannels));
    }

    convertedAVFrame = convertAudioAVFrameSamples(
        swrContext_,
        srcAVFrame,
        outSampleFormat,
        outSampleRate,
        outNumChannels);
  }
  // From here on, read from the converted frame if a conversion happened,
  // otherwise from the source frame directly.
  const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;

  // Sanity-check that the frame we are about to copy has the expected
  // format...
  AVSampleFormat format = static_cast<AVSampleFormat>(avFrame->format);
  TORCH_CHECK(
      format == outSampleFormat,
      "Something went wrong, the frame didn't get converted to the desired format. ",
      "Desired format = ",
      av_get_sample_fmt_name(outSampleFormat),
      "source format = ",
      av_get_sample_fmt_name(format));

  // ...and the expected number of channels.
  int numChannels = getNumChannels(avFrame);
  TORCH_CHECK(
      numChannels == outNumChannels,
      "Something went wrong, the frame didn't get converted to the desired ",
      "number of channels = ",
      outNumChannels,
      ". Got ",
      numChannels,
      " instead.");

  auto numSamples = avFrame->nb_samples;

  frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32);

  if (numSamples > 0) {
    // Planar layout: extended_data[channel] points to that channel's samples.
    // Copy each channel into the corresponding row of the output tensor,
    // advancing the destination pointer by one row per iteration.
    uint8_t* outputChannelData =
        static_cast<uint8_t*>(frameOutput.data.data_ptr());
    auto numBytesPerChannel = numSamples * av_get_bytes_per_sample(format);
    for (auto channel = 0; channel < numChannels;
         ++channel, outputChannelData += numBytesPerChannel) {
      std::memcpy(
          outputChannelData,
          avFrame->extended_data[channel],
          numBytesPerChannel);
    }
  }
}

std::optional<torch::Tensor> CpuDeviceInterface::maybeFlushAudioBuffers() {
// When sample rate conversion is involved, swresample buffers some of the
// samples in-between calls to swr_convert (see the libswresample docs).
// That's because the last few samples in a given frame require future
// samples from the next frame to be properly converted. This function
// flushes out the samples that are stored in swresample's buffers.
if (!swrContext_) {
return std::nullopt;
}
auto numRemainingSamples = swr_get_out_samples(swrContext_.get(), 0);

if (numRemainingSamples == 0) {
return std::nullopt;
}

int numChannels =
audioStreamOptions_.numChannels.value_or(getNumChannels(codecContext_));
torch::Tensor lastSamples =
torch::empty({numChannels, numRemainingSamples}, torch::kFloat32);

std::vector<uint8_t*> outputBuffers(numChannels);
for (auto i = 0; i < numChannels; i++) {
outputBuffers[i] = static_cast<uint8_t*>(lastSamples[i].data_ptr());
}

auto actualNumRemainingSamples = swr_convert(
swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0);

return lastSamples.narrow(
/*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples);
}

// Returns a short human-readable description of this device interface.
std::string CpuDeviceInterface::getDetails() {
  return "CPU Device Interface.";
}
Expand Down
19 changes: 19 additions & 0 deletions src/torchcodec/_core/CpuDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,30 @@ class CpuDeviceInterface : public DeviceInterface {
const std::vector<std::unique_ptr<Transform>>& transforms,
const std::optional<FrameDims>& resizedOutputDims) override;

virtual void initializeAudio(
const AudioStreamOptions& audioStreamOptions) override;

virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() override;

void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

std::string getDetails() override;

private:
void convertAudioAVFrameToFrameOutput(
UniqueAVFrame& srcAVFrame,
FrameOutput& frameOutput);

void convertVideoAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor);

int convertAVFrameToTensorUsingSwScale(
const UniqueAVFrame& avFrame,
torch::Tensor& outputTensor,
Expand Down Expand Up @@ -108,6 +123,10 @@ class CpuDeviceInterface : public DeviceInterface {
bool userRequestedSwScale_;

bool initialized_ = false;

// Audio-specific members
AudioStreamOptions audioStreamOptions_;
UniqueSwrContext swrContext_;
};

} // namespace facebook::torchcodec
4 changes: 3 additions & 1 deletion src/torchcodec/_core/CudaDeviceInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24(
void CudaDeviceInterface::convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
[[maybe_unused]] AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
validatePreAllocatedTensorShape(preAllocatedOutputTensor, avFrame);

Expand Down Expand Up @@ -271,7 +272,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput(
} else {
// Reason 2 above. We need to do a full conversion which requires an
// actual CPU device.
cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput);
cpuInterface_->convertAVFrameToFrameOutput(
avFrame, cpuFrameOutput, AVMEDIA_TYPE_VIDEO);
}

// Finally, we need to send the frame back to the GPU. Note that the
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/_core/CudaDeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class CudaDeviceInterface : public DeviceInterface {
void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor =
std::nullopt) override;

Expand Down
16 changes: 16 additions & 0 deletions src/torchcodec/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,21 @@ class DeviceInterface {
transforms,
[[maybe_unused]] const std::optional<FrameDims>& resizedOutputDims) {}

// Initialize the device with parameters specific to audio decoding. There is
// a default empty implementation: devices that do not support audio decoding
// simply ignore the provided options.
virtual void initializeAudio(
    [[maybe_unused]] const AudioStreamOptions& audioStreamOptions) {}

// Flush any remaining samples from the audio resampler buffer.
// When sample rate conversion is involved, some samples may be buffered
// between frames for proper interpolation. This function flushes those
// buffered samples.
// Returns an optional tensor containing the flushed samples, or std::nullopt
// if there are no buffered samples or audio is not supported.
// The default implementation never buffers anything, so it always returns
// std::nullopt; audio-capable subclasses override this.
virtual std::optional<torch::Tensor> maybeFlushAudioBuffers() {
  return std::nullopt;
}

// In order for decoding to actually happen on an FFmpeg managed hardware
// device, we need to register the DeviceInterface managed
// AVHardwareDeviceContext with the AVCodecContext. We don't need to do this
Expand All @@ -75,6 +90,7 @@ class DeviceInterface {
virtual void convertAVFrameToFrameOutput(
UniqueAVFrame& avFrame,
FrameOutput& frameOutput,
AVMediaType mediaType,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt) = 0;

// ------------------------------------------
Expand Down
Loading
Loading