11#include < sstream>
22
3+ #include " src/torchcodec/_core/AVIOBytesContext.h"
34#include " src/torchcodec/_core/Encoder.h"
45#include " torch/types.h"
56
@@ -40,9 +41,13 @@ AudioEncoder::~AudioEncoder() {}
4041AudioEncoder::AudioEncoder (
4142 const torch::Tensor wf,
4243 int sampleRate,
43- std::string_view fileName,
44+ std::optional<std::string_view> fileName,
45+ std::optional<std::string_view> formatName,
4446 std::optional<int64_t > bit_rate)
4547 : wf_(wf) {
48+ TORCH_CHECK (
49+ fileName.has_value () ^ formatName.has_value (),
50+ " Pass one of filename OR format, not both." );
4651 TORCH_CHECK (
4752 wf_.dtype () == torch::kFloat32 ,
4853 " waveform must have float32 dtype, got " ,
@@ -52,27 +57,35 @@ AudioEncoder::AudioEncoder(
5257 TORCH_CHECK (
5358 wf_.dim () == 2 , " waveform must have 2 dimensions, got " , wf_.dim ());
5459
60+ avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
61+
5562 setFFmpegLogLevel ();
5663 AVFormatContext* avFormatContext = nullptr ;
57- auto status = avformat_alloc_output_context2 (
58- &avFormatContext, nullptr , nullptr , fileName.data ());
64+ int status = AVSUCCESS;
65+ if (fileName.has_value ()) {
66+ status = avformat_alloc_output_context2 (
67+ &avFormatContext, nullptr , nullptr , fileName->data ());
68+ } else {
69+ status = avformat_alloc_output_context2 (
70+ &avFormatContext, nullptr , formatName->data (), nullptr );
71+ }
5972 TORCH_CHECK (
6073 avFormatContext != nullptr ,
6174 " Couldn't allocate AVFormatContext. " ,
6275 " Check the desired extension? " ,
6376 getFFMPEGErrorStringFromErrorCode (status));
6477 avFormatContext_.reset (avFormatContext);
6578
66- // TODO-ENCODING: Should also support encoding into bytes (use
67- // AVIOBytesContext)
68- TORCH_CHECK (
69- !(avFormatContext-> oformat -> flags & AVFMT_NOFILE),
70- " AVFMT_NOFILE is set. We only support writing to a file. " );
71- status = avio_open (&avFormatContext_-> pb , fileName. data (), AVIO_FLAG_WRITE);
72- TORCH_CHECK (
73- status >= 0 ,
74- " avio_open failed: " ,
75- getFFMPEGErrorStringFromErrorCode (status));
79+ if (fileName. has_value ()) {
80+ status =
81+ avio_open (&avFormatContext_-> pb , fileName-> data (), AVIO_FLAG_WRITE);
82+ TORCH_CHECK (
83+ status >= 0 ,
84+ " avio_open failed: " ,
85+ getFFMPEGErrorStringFromErrorCode (status));
86+ } else {
87+ avFormatContext-> pb = avioContextHolder_-> getAVIOContext ();
88+ }
7689
7790 // We use the AVFormatContext's default codec for that
7891 // specific format/container.
@@ -168,7 +181,18 @@ AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
168181 return avCodec.sample_fmts [0 ];
169182}
170183
184+ torch::Tensor AudioEncoder::encodeToTensor () {
185+ TORCH_CHECK (
186+ avioContextHolder_ != nullptr ,
187+ " Cannot encode to tensor, avio context doesn't exist." );
188+ encode ();
189+ return avioContextHolder_->getOutputTensor ();
190+ }
191+
171192void AudioEncoder::encode () {
193+ // TODO-ENCODING: Need to check, but consecutive calls to encode() are
194+ // probably invalid. We can address this once we (re)design the public and
195+ // private encoding APIs.
172196 UniqueAVFrame avFrame (av_frame_alloc ());
173197 TORCH_CHECK (avFrame != nullptr , " Couldn't allocate AVFrame." );
174198 // Default to 256 like in torchaudio
0 commit comments