11#include < sstream>
22
3+ #include " src/torchcodec/_core/AVIOBytesContext.h"
34#include " src/torchcodec/_core/Encoder.h"
45#include " torch/types.h"
56
@@ -40,9 +41,13 @@ AudioEncoder::~AudioEncoder() {}
4041AudioEncoder::AudioEncoder (
4142 const torch::Tensor wf,
4243 int sampleRate,
43- std::string_view fileName,
44+ std::optional<std::string_view> fileName,
45+ std::optional<std::string_view> formatName,
4446 std::optional<int64_t > bit_rate)
4547 : wf_(wf) {
48+ TORCH_CHECK (
49+ fileName.has_value () ^ formatName.has_value (),
50+ " Pass one of filename OR format, not both." );
4651 TORCH_CHECK (
4752 wf_.dtype () == torch::kFloat32 ,
4853 " waveform must have float32 dtype, got " ,
@@ -54,25 +59,32 @@ AudioEncoder::AudioEncoder(
5459
5560 setFFmpegLogLevel ();
5661 AVFormatContext* avFormatContext = nullptr ;
57- auto status = avformat_alloc_output_context2 (
58- &avFormatContext, nullptr , nullptr , fileName.data ());
62+ int status = AVSUCCESS;
63+ if (fileName.has_value ()) {
64+ status = avformat_alloc_output_context2 (
65+ &avFormatContext, nullptr , nullptr , fileName->data ());
66+ } else {
67+ status = avformat_alloc_output_context2 (
68+ &avFormatContext, nullptr , formatName->data (), nullptr );
69+ }
5970 TORCH_CHECK (
6071 avFormatContext != nullptr ,
6172 " Couldn't allocate AVFormatContext. " ,
6273 " Check the desired extension? " ,
6374 getFFMPEGErrorStringFromErrorCode (status));
6475 avFormatContext_.reset (avFormatContext);
6576
66- // TODO-ENCODING: Should also support encoding into bytes (use
67- // AVIOBytesContext)
68- TORCH_CHECK (
69- !(avFormatContext->oformat ->flags & AVFMT_NOFILE),
70- " AVFMT_NOFILE is set. We only support writing to a file." );
71- status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
72- TORCH_CHECK (
73- status >= 0 ,
74- " avio_open failed: " ,
75- getFFMPEGErrorStringFromErrorCode (status));
77+ if (fileName.has_value ()) {
78+ status =
79+ avio_open (&avFormatContext_->pb , fileName->data (), AVIO_FLAG_WRITE);
80+ TORCH_CHECK (
81+ status >= 0 ,
82+ " avio_open failed: " ,
83+ getFFMPEGErrorStringFromErrorCode (status));
84+ } else {
85+ avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
86+ avFormatContext->pb = avioContextHolder_->getAVIOContext ();
87+ }
7688
7789 // We use the AVFormatContext's default codec for that
7890 // specific format/container.
@@ -168,7 +180,18 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
168180 return avCodec.sample_fmts [0 ];
169181}
170182
183+ torch::Tensor AudioEncoder::encodeToTensor () {
184+ TORCH_CHECK (
185+ avioContextHolder_ != nullptr ,
186+ " Cannot encode to tensor, avio context doesn't exist." );
187+ encode ();
188+ return avioContextHolder_->getOutputTensor ();
189+ }
190+
171191void AudioEncoder::encode () {
192+ // TODO-ENCODING: Need to check, but consecutive calls to encode() are
193+ // probably invalid. We can address this once we (re)design the public and
194+ // private encoding APIs.
172195 UniqueAVFrame avFrame (av_frame_alloc ());
173196 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
174197 // Default to 256 like in torchaudio
0 commit comments