Skip to content

Commit fa9e85f

Browse files
committed
Validate encoder sample rate
1 parent 8b19f45 commit fa9e85f

File tree

3 files changed

+39
-14
lines changed

3 files changed

+39
-14
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,45 @@
33

44
namespace facebook::torchcodec {
55

6+
namespace {
7+
8+
// Checks that `sampleRate` is one of the sample rates listed by `avCodec`.
//
// - If the codec exposes no list (`supported_samplerates` is null), there is
//   nothing to validate against and the function returns silently.
// - If `sampleRate` appears in the list, the function returns silently.
// - Otherwise a TORCH_CHECK failure is raised whose message contains the
//   rejected value and the comma-separated list of supported values.
//
// The list is read up to (and excluding) its first 0 entry.
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
  const auto* const rates = avCodec.supported_samplerates;
  if (rates == nullptr) {
    return;
  }

  // First pass: accept as soon as we find a matching rate, so that the error
  // string below is only ever built on the failure path.
  for (const auto* rate = rates; *rate != 0; ++rate) {
    if (*rate == sampleRate) {
      return;
    }
  }

  // Second pass (failure path only): render "r0, r1, r2, ..." for the message.
  std::string supportedRates;
  for (const auto* rate = rates; *rate != 0; ++rate) {
    if (!supportedRates.empty()) {
      supportedRates += ", ";
    }
    supportedRates += std::to_string(*rate);
  }

  TORCH_CHECK(
      false,
      "invalid sample rate=",
      sampleRate,
      ". Supported sample rate values are: ",
      supportedRates);
}
33+
34+
} // namespace
35+
636
// Out-of-line destructor; it performs no custom teardown of its own.
AudioEncoder::~AudioEncoder() = default;
737

838
// TODO-ENCODING: disable ffmpeg logs by default
939

1040
AudioEncoder::AudioEncoder(
1141
const torch::Tensor wf,
1242
int sampleRate,
13-
std::string_view fileName)
14-
: wf_(wf), sampleRate_(sampleRate) {
43+
std::string_view fileName,
44+
: wf_(wf) {
1545
TORCH_CHECK(
1646
wf_.dtype() == torch::kFloat32,
1747
"waveform must have float32 dtype, got ",
@@ -55,7 +85,8 @@ AudioEncoder::AudioEncoder(
5585
// TODO-ENCODING Should also let user choose for compressed formats like mp3.
5686
avCodecContext_->bit_rate = 0;
5787

58-
avCodecContext_->sample_rate = sampleRate_;
88+
validateSampleRate(*avCodec, sampleRate);
89+
avCodecContext_->sample_rate = sampleRate;
5990

6091
// Note: This is the format of the **input** waveform. This doesn't determine
6192
// the output.

src/torchcodec/_core/Encoder.h

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ class AudioEncoder {
99

1010
AudioEncoder(
1111
const torch::Tensor wf,
12+
// The *output* sample rate. We can't really decide for the user what it
13+
// should be. Particularly, the sample rate of the input waveform should
14+
// match this, and that's up to the user. If sample rates don't match,
15+
// encoding will still work but audio will be distorted.
1216
int sampleRate,
1317
std::string_view fileName);
1418
void encode();
@@ -24,13 +28,5 @@ class AudioEncoder {
2428
int streamIndex_;
2529

2630
const torch::Tensor wf_;
27-
// The *output* sample rate. We can't really decide for the user what it
28-
// should be. Particularly, the sample rate of the input waveform should match
29-
// this, and that's up to the user. If sample rates don't match, encoding will
30-
// still work but audio will be distorted.
31-
// We technically could let the user also specify the input sample rate, and
32-
// resample the waveform internally to match them, but that's not in scope for
33-
// an initial version (if at all).
34-
int sampleRate_;
3531
};
3632
} // namespace facebook::torchcodec

test/test_ops.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1107,9 +1107,7 @@ def test_bad_input(self, tmp_path):
11071107
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
11081108
)
11091109

1110-
# TODO-ENCODING: raise more informative error message when sample rate
1111-
# isn't supported
1112-
with pytest.raises(RuntimeError, match="Invalid argument"):
1110+
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
11131111
create_audio_encoder(
11141112
wf=self.decode(NASA_AUDIO_MP3),
11151113
sample_rate=10,

0 commit comments

Comments
 (0)