@@ -8,6 +8,17 @@ namespace facebook::torchcodec {
88
99namespace {
1010
11+ torch::Tensor validateWf (torch::Tensor wf) {
12+ TORCH_CHECK (
13+ wf.dtype () == torch::kFloat32 ,
14+ " waveform must have float32 dtype, got " ,
15+ wf.dtype ());
16+ // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
17+ // planar (fltp).
18+ TORCH_CHECK (wf.dim () == 2 , " waveform must have 2 dimensions, got " , wf.dim ());
19+ return wf;
20+ }
21+
1122void validateSampleRate (const AVCodec& avCodec, int sampleRate) {
1223 if (avCodec.supported_samplerates == nullptr ) {
1324 return ;
@@ -79,51 +90,57 @@ AudioEncoder::~AudioEncoder() {}
7990AudioEncoder::AudioEncoder (
8091 const torch::Tensor wf,
8192 int sampleRate,
82- std::optional<std::string_view> fileName,
83- std::optional<std::string_view> formatName,
93+ std::string_view fileName,
8494 std::optional<int64_t > bitRate)
85- : wf_(wf) {
86- TORCH_CHECK (
87- fileName.has_value () ^ formatName.has_value (),
88- " Pass one of filename OR format, not both." );
95+ : wf_(validateWf(wf)) {
96+ setFFmpegLogLevel ();
97+ AVFormatContext* avFormatContext = nullptr ;
98+ int status = avformat_alloc_output_context2 (
99+ &avFormatContext, nullptr , nullptr , fileName.data ());
100+
89101 TORCH_CHECK (
90- wf_.dtype () == torch::kFloat32 ,
91- " waveform must have float32 dtype, got " ,
92- wf_.dtype ());
93- // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
94- // planar (fltp).
102+ avFormatContext != nullptr ,
103+ " Couldn't allocate AVFormatContext. " ,
104+ " Check the desired extension? " ,
105+ getFFMPEGErrorStringFromErrorCode (status));
106+ avFormatContext_.reset (avFormatContext);
107+
108+ status = avio_open (&avFormatContext_->pb , fileName.data (), AVIO_FLAG_WRITE);
95109 TORCH_CHECK (
96- wf_.dim () == 2 , " waveform must have 2 dimensions, got " , wf_.dim ());
110+ status >= 0 ,
111+ " avio_open failed: " ,
112+ getFFMPEGErrorStringFromErrorCode (status));
97113
114+ initializeEncoder (sampleRate, bitRate);
115+ }
116+
117+ AudioEncoder::AudioEncoder (
118+ const torch::Tensor wf,
119+ int sampleRate,
120+ std::string_view formatName,
121+ std::unique_ptr<AVIOToTensorContext> avioContextHolder,
122+ std::optional<int64_t > bitRate)
123+ : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
98124 setFFmpegLogLevel ();
99125 AVFormatContext* avFormatContext = nullptr ;
100- int status = AVSUCCESS;
101- if (fileName.has_value ()) {
102- status = avformat_alloc_output_context2 (
103- &avFormatContext, nullptr , nullptr , fileName->data ());
104- } else {
105- status = avformat_alloc_output_context2 (
106- &avFormatContext, nullptr , formatName->data (), nullptr );
107- }
126+ int status = avformat_alloc_output_context2 (
127+ &avFormatContext, nullptr , formatName.data (), nullptr );
128+
108129 TORCH_CHECK (
109130 avFormatContext != nullptr ,
110131 " Couldn't allocate AVFormatContext. " ,
111132 " Check the desired extension? " ,
112133 getFFMPEGErrorStringFromErrorCode (status));
113134 avFormatContext_.reset (avFormatContext);
114135
115- if (fileName.has_value ()) {
116- status =
117- avio_open (&avFormatContext_->pb , fileName->data (), AVIO_FLAG_WRITE);
118- TORCH_CHECK (
119- status >= 0 ,
120- " avio_open failed: " ,
121- getFFMPEGErrorStringFromErrorCode (status));
122- } else {
123- avioContextHolder_ = std::make_unique<AVIOToTensorContext>();
124- avFormatContext->pb = avioContextHolder_->getAVIOContext ();
125- }
136+ avFormatContext_->pb = avioContextHolder_->getAVIOContext ();
137+
138+ initializeEncoder (sampleRate, bitRate);
139+ }
126140
141+ void AudioEncoder::initializeEncoder (
142+ int sampleRate,
143+ std::optional<int64_t > bitRate) {
127144 // We use the AVFormatContext's default codec for that
128145 // specific format/container.
129146 const AVCodec* avCodec =
@@ -162,7 +179,7 @@ AudioEncoder::AudioEncoder(
162179
163180 setDefaultChannelLayout (avCodecContext_, numChannels);
164181
165- status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
182+ int status = avcodec_open2 (avCodecContext_.get (), avCodec, nullptr );
166183 TORCH_CHECK (
167184 status == AVSUCCESS,
168185 " avcodec_open2 failed: " ,
0 commit comments