Skip to content

Commit 7b3847f

Browse files
committed
Merge branch 'main' of github.com:pytorch/torchcodec into avio
2 parents ee7a217 + 2c137e7 commit 7b3847f

File tree

3 files changed

+44
-44
lines changed

3 files changed

+44
-44
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
3434
supportedRates.str());
3535
}
3636

37+
// Output sample formats ordered from most to least preferred.
// findBestOutputSampleFormat() walks this list front to back and returns the
// first entry that the encoder supports, so the order here is the selection
// policy: float first (the format of the input waveform), then the remaining
// formats by decreasing resolution. Each planar ("P") variant is listed right
// before its interleaved counterpart.
static const std::vector<AVSampleFormat> preferredFormatsOrder{
    // 32-bit float.
    AV_SAMPLE_FMT_FLTP,
    AV_SAMPLE_FMT_FLT,
    // 64-bit float.
    AV_SAMPLE_FMT_DBLP,
    AV_SAMPLE_FMT_DBL,
    // Signed 64-bit integer.
    AV_SAMPLE_FMT_S64P,
    AV_SAMPLE_FMT_S64,
    // Signed 32-bit integer.
    AV_SAMPLE_FMT_S32P,
    AV_SAMPLE_FMT_S32,
    // Signed 16-bit integer.
    AV_SAMPLE_FMT_S16P,
    AV_SAMPLE_FMT_S16,
    // Unsigned 8-bit integer.
    AV_SAMPLE_FMT_U8P,
    AV_SAMPLE_FMT_U8,
};
50+
51+
// Chooses the sample format the encoder will be configured to output.
//
// The input waveform is FLT[P], so that is the first choice: when the encoder
// accepts it, no sample-format conversion of the AVFrame is needed. Otherwise
// we pick the first entry of preferredFormatsOrder that the encoder supports,
// i.e. the supported format with the highest resolution.
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
  const AVSampleFormat* const supportedFormats = avCodec.sample_fmts;
  if (supportedFormats == nullptr) {
    // The encoder doesn't advertise its formats, so there is nothing to
    // validate against. Optimistically go with FLTP; if the encoder can't
    // handle it, FFmpeg will raise.
    return AV_SAMPLE_FMT_FLTP;
  }

  // The encoder's list is terminated by AV_SAMPLE_FMT_NONE (-1).
  const auto encoderSupports = [supportedFormats](AVSampleFormat candidate) {
    for (const AVSampleFormat* fmt = supportedFormats;
         *fmt != AV_SAMPLE_FMT_NONE;
         ++fmt) {
      if (*fmt == candidate) {
        return true;
      }
    }
    return false;
  };

  for (const AVSampleFormat candidate : preferredFormatsOrder) {
    if (encoderSupports(candidate)) {
      return candidate;
    }
  }

  // preferredFormatsOrder is expected to cover every format, so one of the
  // returns above should always be hit. Should a future FFmpeg version add a
  // sample format that isn't listed there, fall back to the encoder's first
  // advertised format.
  return supportedFormats[0];
}
74+
3775
} // namespace
3876

3977
AudioEncoder::~AudioEncoder() {}
@@ -43,7 +81,7 @@ AudioEncoder::AudioEncoder(
4381
int sampleRate,
4482
std::optional<std::string_view> fileName,
4583
std::optional<std::string_view> formatName,
46-
std::optional<int64_t> bit_rate)
84+
std::optional<int64_t> bitRate)
4785
: wf_(wf) {
4886
TORCH_CHECK(
4987
fileName.has_value() ^ formatName.has_value(),
@@ -96,20 +134,20 @@ AudioEncoder::AudioEncoder(
96134
TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context.");
97135
avCodecContext_.reset(avCodecContext);
98136

99-
if (bit_rate.has_value()) {
100-
TORCH_CHECK(*bit_rate >= 0, "bit_rate=", *bit_rate, " must be >= 0.");
137+
if (bitRate.has_value()) {
138+
TORCH_CHECK(*bitRate >= 0, "bit_rate=", *bitRate, " must be >= 0.");
101139
}
102140
// bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
103141
// well when "-b:a" isn't specified.
104-
avCodecContext_->bit_rate = bit_rate.value_or(0);
142+
avCodecContext_->bit_rate = bitRate.value_or(0);
105143

106144
validateSampleRate(*avCodec, sampleRate);
107145
avCodecContext_->sample_rate = sampleRate;
108146

109147
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
110148
// may need to convert the wf into a supported output sample format, which is
111149
// what the `.sample_fmt` defines.
112-
avCodecContext_->sample_fmt = findOutputSampleFormat(*avCodec);
150+
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
113151

114152
int numChannels = static_cast<int>(wf_.sizes()[0]);
115153
TORCH_CHECK(
@@ -144,42 +182,6 @@ AudioEncoder::AudioEncoder(
144182
streamIndex_ = avStream->index;
145183
}
146184

147-
AVSampleFormat AudioEncoder::findOutputSampleFormat(const AVCodec& avCodec) {
148-
// Find a sample format that the encoder supports. We prefer using FLT[P],
149-
// since this is the format of the input waveform. If FLTP isn't supported
150-
// then we'll need to convert the AVFrame's format. Our heuristic is to encode
151-
// into the format with the highest resolution.
152-
if (avCodec.sample_fmts == nullptr) {
153-
// Can't really validate anything in this case, best we can do is hope that
154-
// FLTP is supported by the encoder. If not, FFmpeg will raise.
155-
return AV_SAMPLE_FMT_FLTP;
156-
}
157-
158-
std::vector<AVSampleFormat> preferredFormatsOrder = {
159-
AV_SAMPLE_FMT_FLTP,
160-
AV_SAMPLE_FMT_FLT,
161-
AV_SAMPLE_FMT_DBLP,
162-
AV_SAMPLE_FMT_DBL,
163-
AV_SAMPLE_FMT_S64P,
164-
AV_SAMPLE_FMT_S64,
165-
AV_SAMPLE_FMT_S32P,
166-
AV_SAMPLE_FMT_S32,
167-
AV_SAMPLE_FMT_S16P,
168-
AV_SAMPLE_FMT_S16,
169-
AV_SAMPLE_FMT_U8P,
170-
AV_SAMPLE_FMT_U8};
171-
172-
for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
173-
for (auto i = 0; avCodec.sample_fmts[i] != -1; ++i) {
174-
if (avCodec.sample_fmts[i] == preferredFormat) {
175-
return preferredFormat;
176-
}
177-
}
178-
}
179-
// Should never happen, but just in case
180-
return avCodec.sample_fmts[0];
181-
}
182-
183185
torch::Tensor AudioEncoder::encodeToTensor() {
184186
TORCH_CHECK(
185187
avioContextHolder_ != nullptr,

src/torchcodec/_core/Encoder.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class AudioEncoder {
2222
int sampleRate,
2323
std::optional<std::string_view> fileName,
2424
std::optional<std::string_view> formatName,
25-
std::optional<int64_t> bit_rate = std::nullopt);
25+
std::optional<int64_t> bitRate = std::nullopt);
2626
void encode();
2727
torch::Tensor encodeToTensor();
2828

@@ -31,7 +31,6 @@ class AudioEncoder {
3131
AutoAVPacket& autoAVPacket,
3232
const UniqueAVFrame& srcAVFrame);
3333
void flushBuffers();
34-
AVSampleFormat findOutputSampleFormat(const AVCodec& avCodec);
3534

3635
UniqueEncodingAVFormatContext avFormatContext_;
3736
UniqueAVCodecContext avCodecContext_;

test/test_ops.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,6 @@ def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11721172
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
11731173

11741174
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
1175-
11761175
subprocess.run(
11771176
["ffmpeg", "-i", str(asset.path)]
11781177
+ (["-b:a", f"{bit_rate}"] if bit_rate is not None else [])

0 commit comments

Comments
 (0)