@@ -34,6 +34,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
3434 supportedRates.str ());
3535}
3636
37+ static const std::vector<AVSampleFormat> preferredFormatsOrder = {
38+ AV_SAMPLE_FMT_FLTP,
39+ AV_SAMPLE_FMT_FLT,
40+ AV_SAMPLE_FMT_DBLP,
41+ AV_SAMPLE_FMT_DBL,
42+ AV_SAMPLE_FMT_S64P,
43+ AV_SAMPLE_FMT_S64,
44+ AV_SAMPLE_FMT_S32P,
45+ AV_SAMPLE_FMT_S32,
46+ AV_SAMPLE_FMT_S16P,
47+ AV_SAMPLE_FMT_S16,
48+ AV_SAMPLE_FMT_U8P,
49+ AV_SAMPLE_FMT_U8};
50+
51+ AVSampleFormat findBestOutputSampleFormat (const AVCodec& avCodec) {
52+ // Find a sample format that the encoder supports. We prefer using FLT[P],
53+ // since this is the format of the input waveform. If FLTP isn't supported
54+ // then we'll need to convert the AVFrame's format. Our heuristic is to encode
55+ // into the format with the highest resolution.
56+ if (avCodec.sample_fmts == nullptr ) {
57+ // Can't really validate anything in this case, best we can do is hope that
58+ // FLTP is supported by the encoder. If not, FFmpeg will raise.
59+ return AV_SAMPLE_FMT_FLTP;
60+ }
61+
62+ for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
63+ for (int i = 0 ; avCodec.sample_fmts [i] != -1 ; ++i) {
64+ if (avCodec.sample_fmts [i] == preferredFormat) {
65+ return preferredFormat;
66+ }
67+ }
68+ }
69+ // We should always find a match in preferredFormatsOrder, so we should always
70+ // return earlier. But in the event that a future FFmpeg version defines an
71+ // additional sample format that isn't in preferredFormatsOrder, we fallback:
72+ return avCodec.sample_fmts [0 ];
73+ }
74+
3775} // namespace
3876
3977AudioEncoder::~AudioEncoder () {}
@@ -43,7 +81,7 @@ AudioEncoder::AudioEncoder(
4381 int sampleRate,
4482 std::optional<std::string_view> fileName,
4583 std::optional<std::string_view> formatName,
46- std::optional<int64_t > bit_rate )
84+ std::optional<int64_t > bitRate )
4785 : wf_(wf) {
4886 TORCH_CHECK (
4987 fileName.has_value () ^ formatName.has_value (),
@@ -96,20 +134,20 @@ AudioEncoder::AudioEncoder(
96134 TORCH_CHECK (avCodecContext != nullptr , " Couldn't allocate codec context." );
97135 avCodecContext_.reset (avCodecContext);
98136
99- if (bit_rate .has_value ()) {
100- TORCH_CHECK (*bit_rate >= 0 , " bit_rate=" , *bit_rate , " must be >= 0." );
137+ if (bitRate .has_value ()) {
138+ TORCH_CHECK (*bitRate >= 0 , " bit_rate=" , *bitRate , " must be >= 0." );
101139 }
102140 // bit_rate=None defaults to 0, which is what the FFmpeg CLI seems to use as
103141 // well when "-b:a" isn't specified.
104- avCodecContext_->bit_rate = bit_rate .value_or (0 );
142+ avCodecContext_->bit_rate = bitRate .value_or (0 );
105143
106144 validateSampleRate (*avCodec, sampleRate);
107145 avCodecContext_->sample_rate = sampleRate;
108146
109147 // Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
110148 // may need to convert the wf into a supported output sample format, which is
111149 // what the `.sample_fmt` defines.
112- avCodecContext_->sample_fmt = findOutputSampleFormat (*avCodec);
150+ avCodecContext_->sample_fmt = findBestOutputSampleFormat (*avCodec);
113151
114152 int numChannels = static_cast <int >(wf_.sizes ()[0 ]);
115153 TORCH_CHECK (
@@ -144,42 +182,6 @@ AudioEncoder::AudioEncoder(
144182 streamIndex_ = avStream->index ;
145183}
146184
147- AVSampleFormat AudioEncoder::findOutputSampleFormat (const AVCodec& avCodec) {
148- // Find a sample format that the encoder supports. We prefer using FLT[P],
149- // since this is the format of the input waveform. If FLTP isn't supported
150- // then we'll need to convert the AVFrame's format. Our heuristic is to encode
151- // into the format with the highest resolution.
152- if (avCodec.sample_fmts == nullptr ) {
153- // Can't really validate anything in this case, best we can do is hope that
154- // FLTP is supported by the encoder. If not, FFmpeg will raise.
155- return AV_SAMPLE_FMT_FLTP;
156- }
157-
158- std::vector<AVSampleFormat> preferredFormatsOrder = {
159- AV_SAMPLE_FMT_FLTP,
160- AV_SAMPLE_FMT_FLT,
161- AV_SAMPLE_FMT_DBLP,
162- AV_SAMPLE_FMT_DBL,
163- AV_SAMPLE_FMT_S64P,
164- AV_SAMPLE_FMT_S64,
165- AV_SAMPLE_FMT_S32P,
166- AV_SAMPLE_FMT_S32,
167- AV_SAMPLE_FMT_S16P,
168- AV_SAMPLE_FMT_S16,
169- AV_SAMPLE_FMT_U8P,
170- AV_SAMPLE_FMT_U8};
171-
172- for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
173- for (auto i = 0 ; avCodec.sample_fmts [i] != -1 ; ++i) {
174- if (avCodec.sample_fmts [i] == preferredFormat) {
175- return preferredFormat;
176- }
177- }
178- }
179- // Should never happen, but just in case
180- return avCodec.sample_fmts [0 ];
181- }
182-
183185torch::Tensor AudioEncoder::encodeToTensor () {
184186 TORCH_CHECK (
185187 avioContextHolder_ != nullptr ,
0 commit comments