Skip to content

Commit ff6c1e0

Browse files
committed
Create 2 separate constructors
1 parent 9cb31c9 commit ff6c1e0

File tree

3 files changed

+62
-37
lines changed

3 files changed

+62
-37
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,17 @@ namespace facebook::torchcodec {
88

99
namespace {
1010

11+
// Checks that the waveform tensor is acceptable for encoding and hands it
// back unchanged, so it can be used directly in a constructor's member
// initializer list. Fails through TORCH_CHECK if the dtype is not float32 or
// if the tensor is not 2-dimensional.
torch::Tensor validateWf(torch::Tensor wf) {
  const auto dtype = wf.dtype();
  TORCH_CHECK(
      dtype == torch::kFloat32,
      "waveform must have float32 dtype, got ",
      dtype);

  // TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
  // planar (fltp).
  const auto numDims = wf.dim();
  TORCH_CHECK(
      numDims == 2, "waveform must have 2 dimensions, got ", numDims);

  return wf;
}
21+
1122
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
1223
if (avCodec.supported_samplerates == nullptr) {
1324
return;
@@ -79,51 +90,57 @@ AudioEncoder::~AudioEncoder() {}
7990
// Builds an encoder that writes the encoded waveform to the file `fileName`.
//
// - wf: waveform to encode; validated by validateWf() (float32, 2-D).
// - sampleRate: forwarded to initializeEncoder().
// - fileName: output path. It is handed to FFmpeg both to allocate the output
//   format context and to open the output file for writing.
// - bitRate: optional bit rate, forwarded to initializeEncoder().
//
// Fails through TORCH_CHECK if the format context cannot be allocated or if
// the output file cannot be opened.
AudioEncoder::AudioEncoder(
    const torch::Tensor wf,
    int sampleRate,
    std::string_view fileName,
    std::optional<int64_t> bitRate)
    : wf_(validateWf(wf)) {
  setFFmpegLogLevel();

  // FFmpeg's C API expects NUL-terminated strings, but
  // std::string_view::data() gives no such guarantee (the view may point into
  // the middle of a larger buffer). Make an owned, terminated copy once and
  // use it for every FFmpeg call below.
  const std::string fileNameStr(fileName);

  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, nullptr, fileNameStr.c_str());

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "Check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  // Hand ownership to the member holder before any further check can throw.
  avFormatContext_.reset(avFormatContext);

  status =
      avio_open(&avFormatContext_->pb, fileNameStr.c_str(), AVIO_FLAG_WRITE);
  TORCH_CHECK(
      status >= 0,
      "avio_open failed: ",
      getFFMPEGErrorStringFromErrorCode(status));

  initializeEncoder(sampleRate, bitRate);
}
116+
117+
// Builds an encoder whose output goes through a caller-supplied AVIO context
// holder instead of a file on disk.
//
// - wf: waveform to encode; validated by validateWf() (float32, 2-D).
// - sampleRate: forwarded to initializeEncoder().
// - formatName: name of the output format, passed to FFmpeg to allocate the
//   output format context.
// - avioContextHolder: takes ownership; its AVIOContext becomes the format
//   context's I/O layer. Must not be null.
// - bitRate: optional bit rate, forwarded to initializeEncoder().
//
// Fails through TORCH_CHECK if the holder is null or if the format context
// cannot be allocated.
AudioEncoder::AudioEncoder(
    const torch::Tensor wf,
    int sampleRate,
    std::string_view formatName,
    std::unique_ptr<AVIOToTensorContext> avioContextHolder,
    std::optional<int64_t> bitRate)
    : wf_(validateWf(wf)), avioContextHolder_(std::move(avioContextHolder)) {
  // Reject a null holder up front: it is dereferenced below, and a clear
  // error is preferable to a null-pointer crash.
  TORCH_CHECK(
      avioContextHolder_ != nullptr, "avioContextHolder must not be null.");

  setFFmpegLogLevel();

  // FFmpeg's C API expects NUL-terminated strings, but
  // std::string_view::data() gives no such guarantee. Make an owned,
  // terminated copy before handing it to FFmpeg.
  const std::string formatNameStr(formatName);

  AVFormatContext* avFormatContext = nullptr;
  int status = avformat_alloc_output_context2(
      &avFormatContext, nullptr, formatNameStr.c_str(), nullptr);

  TORCH_CHECK(
      avFormatContext != nullptr,
      "Couldn't allocate AVFormatContext. ",
      "Check the desired extension? ",
      getFFMPEGErrorStringFromErrorCode(status));
  // Hand ownership to the member holder before any further step can throw.
  avFormatContext_.reset(avFormatContext);

  // Route the muxer's output through the custom AVIO context.
  avFormatContext_->pb = avioContextHolder_->getAVIOContext();

  initializeEncoder(sampleRate, bitRate);
}
126140

141+
void AudioEncoder::initializeEncoder(
142+
int sampleRate,
143+
std::optional<int64_t> bitRate) {
127144
// We use the AVFormatContext's default codec for that
128145
// specific format/container.
129146
const AVCodec* avCodec =
@@ -162,7 +179,7 @@ AudioEncoder::AudioEncoder(
162179

163180
setDefaultChannelLayout(avCodecContext_, numChannels);
164181

165-
status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
182+
int status = avcodec_open2(avCodecContext_.get(), avCodec, nullptr);
166183
TORCH_CHECK(
167184
status == AVSUCCESS,
168185
"avcodec_open2 failed: ",

src/torchcodec/_core/Encoder.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,21 @@ class AudioEncoder {
2020
// match this, and that's up to the user. If sample rates don't match,
2121
// encoding will still work but audio will be distorted.
2222
int sampleRate,
23-
std::optional<std::string_view> fileName,
24-
std::optional<std::string_view> formatName,
23+
std::string_view fileName,
24+
std::optional<int64_t> bitRate = std::nullopt);
25+
AudioEncoder(
26+
const torch::Tensor wf,
27+
int sampleRate,
28+
std::string_view formatName,
29+
std::unique_ptr<AVIOToTensorContext> avioContextHolder,
2530
std::optional<int64_t> bitRate = std::nullopt);
2631
void encode();
2732
torch::Tensor encodeToTensor();
2833

2934
private:
35+
void initializeEncoder(
36+
int sampleRate,
37+
std::optional<int64_t> bitRate = std::nullopt);
3038
void encodeInnerLoop(
3139
AutoAVPacket& autoAVPacket,
3240
const UniqueAVFrame& srcAVFrame);

src/torchcodec/_core/custom_ops.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,7 @@ void encode_audio_to_file(
390390
int64_t sample_rate,
391391
std::string_view file_name,
392392
std::optional<int64_t> bit_rate = std::nullopt) {
393-
AudioEncoder(
394-
wf, validateSampleRate(sample_rate), file_name, std::nullopt, bit_rate)
393+
AudioEncoder(wf, validateSampleRate(sample_rate), file_name, bit_rate)
395394
.encode();
396395
}
397396

@@ -402,11 +401,12 @@ at::Tensor encode_audio_to_tensor(
402401
int64_t sample_rate,
403402
std::string_view format,
404403
std::optional<int64_t> bit_rate = std::nullopt) {
404+
auto avioContextHolder = std::make_unique<AVIOToTensorContext>();
405405
return AudioEncoder(
406406
wf,
407407
validateSampleRate(sample_rate),
408-
std::nullopt,
409408
format,
409+
std::move(avioContextHolder),
410410
bit_rate)
411411
.encodeToTensor();
412412
}

0 commit comments

Comments
 (0)