Skip to content

Commit adbf151

Browse files
author
Daniel Flores
committed
test input correctness
1 parent 1cca109 commit adbf151

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/torchcodec/_core/Encoder.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ namespace {
517517
torch::Tensor validateFrames(const torch::Tensor& frames) {
518518
TORCH_CHECK(
519519
frames.dtype() == torch::kUInt8,
520-
"frames must have kUInt8 dtype, got ",
520+
"frames must have uint8 dtype, got ",
521521
frames.dtype());
522522
TORCH_CHECK(
523523
frames.dim() == 4,
@@ -527,7 +527,6 @@ torch::Tensor validateFrames(const torch::Tensor& frames) {
527527
frames.sizes()[1] == 3,
528528
"frame must have 3 channels (R, G, B), got ",
529529
frames.sizes()[1]);
530-
// TODO-VideoEncoder: Add tests for above validations
531530
// TODO-VideoEncoder: Investigate if non-contiguous frames can be returned
532531
return frames.contiguous();
533532
}

test/test_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,6 +1228,36 @@ def test_bad_input(self, tmp_path):
12281228

12291229
class TestVideoEncoderOps:
12301230

1231+
def test_bad_input(self, tmp_path):
1232+
output_file = str(tmp_path / ".mp3")
1233+
1234+
with pytest.raises(
1235+
RuntimeError, match="frames must have uint8 dtype, got float"
1236+
):
1237+
encode_video_to_file(
1238+
frames=torch.rand((10, 3, 60, 60), dtype=torch.float),
1239+
frame_rate=10,
1240+
filename=output_file,
1241+
)
1242+
1243+
with pytest.raises(
1244+
RuntimeError, match=r"frames must have 4 dimensions \(N, C, H, W\), got 3"
1245+
):
1246+
encode_video_to_file(
1247+
frames=torch.randint(high=1, size=(3, 60, 60), dtype=torch.uint8),
1248+
frame_rate=10,
1249+
filename=output_file,
1250+
)
1251+
1252+
with pytest.raises(
1253+
RuntimeError, match=r"frame must have 3 channels \(R, G, B\), got 2"
1254+
):
1255+
encode_video_to_file(
1256+
frames=torch.randint(high=1, size=(10, 2, 60, 60), dtype=torch.uint8),
1257+
frame_rate=10,
1258+
filename=output_file,
1259+
)
1260+
12311261
def decode(self, file_path) -> torch.Tensor:
12321262
decoder = create_from_file(str(file_path), seek_mode="approximate")
12331263
add_video_stream(decoder)

0 commit comments

Comments
 (0)