Skip to content

Commit 52c4d54

Browse files
committed
more tests
1 parent 42f5160 commit 52c4d54

File tree

2 files changed

+41
-3
lines changed

2 files changed

+41
-3
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@ Encoder::Encoder(
1717
int sampleRate,
1818
std::string_view fileName)
1919
: wf_(wf), sampleRate_(sampleRate) {
20+
TORCH_CHECK(
21+
wf_.dtype() == torch::kFloat32,
22+
"waveform must have float32 dtype, got ",
23+
wf_.dtype());
24+
TORCH_CHECK(
25+
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
2026
AVFormatContext* avFormatContext = nullptr;
2127
avformat_alloc_output_context2(
2228
&avFormatContext, nullptr, nullptr, fileName.data());
23-
TORCH_CHECK(avFormatContext != nullptr, "Couldn't allocate AVFormatContext.");
29+
TORCH_CHECK(
30+
avFormatContext != nullptr,
31+
"Couldn't allocate AVFormatContext. ",
32+
"Check the desired extension?");
2433
avFormatContext_.reset(avFormatContext);
2534

2635
// TODO-ENCODING: Should also support encoding into bytes (use
@@ -51,8 +60,6 @@ Encoder::Encoder(
5160
// TODO-ENCODING Should also let user choose for compressed formats like mp3.
5261
avCodecContext_->bit_rate = 0;
5362

54-
// FFmpeg will raise a reasonably informative error if the desired sample rate
55-
// isn't supported by the encoder.
5663
avCodecContext_->sample_rate = sampleRate_;
5764

5865
// Note: This is the format of the **input** waveform. This doesn't determine

test/decoders/test_ops.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1011,6 +1011,37 @@ def decode(self, source) -> torch.Tensor:
10111011
)
10121012
return frames
10131013

1014+
def test_bad_input(self, tmp_path):
1015+
1016+
valid_output_file = str(tmp_path / ".mp3")
1017+
1018+
with pytest.raises(RuntimeError, match="must have float32 dtype, got int"):
1019+
create_encoder(
1020+
wf=torch.arange(10, dtype=torch.int),
1021+
sample_rate=10,
1022+
filename=valid_output_file,
1023+
)
1024+
with pytest.raises(RuntimeError, match="must have 2 dimensions, got 1"):
1025+
create_encoder(wf=torch.rand(3), sample_rate=10, filename=valid_output_file)
1026+
1027+
with pytest.raises(RuntimeError, match="No such file or directory"):
1028+
create_encoder(
1029+
wf=torch.rand(10, 10), sample_rate=10, filename="./bad/path.mp3"
1030+
)
1031+
with pytest.raises(RuntimeError, match="Check the desired extension"):
1032+
create_encoder(
1033+
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
1034+
)
1035+
1036+
# TODO-ENCODING: raise more informative error message when sample rate
1037+
# isn't supported
1038+
with pytest.raises(RuntimeError, match="Invalid argument"):
1039+
create_encoder(
1040+
wf=self.decode(NASA_AUDIO_MP3),
1041+
sample_rate=10,
1042+
filename=valid_output_file,
1043+
)
1044+
10141045
def test_round_trip(self, tmp_path):
10151046
# Check that decode(encode(samples)) == samples
10161047
asset = NASA_AUDIO_MP3

0 commit comments

Comments
 (0)