Skip to content

Commit 7143f15

Browse files
committed
Better error checking
1 parent c8d546e commit 7143f15

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

benchmarks/decoders/gpu_benchmark.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,7 @@ def decode_full_video(video_path, decode_device_string, resize_device_string):
3939
stream_index=-1,
4040
device=decode_device_string,
4141
num_threads=num_threads,
42-
width=width,
43-
height=height,
42+
transform_specs=f"resize, {height}, {width}",
4443
)
4544

4645
start_time = time.time()

src/torchcodec/_core/custom_ops.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,13 +183,26 @@ SingleStreamDecoder::SeekMode seekModeFromString(std::string_view seekMode) {
183183
}
184184
}
185185

186+
int checkedToPositiveInt(const std::string& str) {
187+
int ret = 0;
188+
try {
189+
ret = std::stoi(str);
190+
} catch (const std::invalid_argument&) {
191+
TORCH_CHECK(false, "String cannot be converted to an int:" + str);
192+
} catch (const std::out_of_range&) {
193+
TORCH_CHECK(false, "String would become integer out of range:" + str);
194+
}
195+
TORCH_CHECK(ret > 0, "String must be a positive integer:" + str);
196+
return ret;
197+
}
198+
186199
Transform* makeResizeTransform(
187200
const std::vector<std::string>& resizeTransformSpec) {
188201
TORCH_CHECK(
189202
resizeTransformSpec.size() == 3,
190203
"resizeTransformSpec must have 3 elements including its name");
191-
int height = std::stoi(resizeTransformSpec[1]);
192-
int width = std::stoi(resizeTransformSpec[2]);
204+
int height = checkedToPositiveInt(resizeTransformSpec[1]);
205+
int width = checkedToPositiveInt(resizeTransformSpec[2]);
193206
return new ResizeTransform(FrameDims(height, width));
194207
}
195208

test/test_ops.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,52 @@ def test_scaling_on_cuda_fails(self):
655655
):
656656
add_video_stream(decoder, device="cuda", transform_specs="resize, 100, 100")
657657

658+
def test_transform_fails(self):
659+
decoder = create_from_file(str(NASA_VIDEO.path))
660+
with pytest.raises(
661+
RuntimeError,
662+
match="Invalid transform spec",
663+
):
664+
add_video_stream(decoder, transform_specs=";")
665+
666+
with pytest.raises(
667+
RuntimeError,
668+
match="Invalid transform name",
669+
):
670+
add_video_stream(decoder, transform_specs="invalid, 1, 2")
671+
672+
def test_resize_transform_fails(self):
673+
decoder = create_from_file(str(NASA_VIDEO.path))
674+
with pytest.raises(
675+
RuntimeError,
676+
match="must have 3 elements",
677+
):
678+
add_video_stream(decoder, transform_specs="resize, 100, 100, 100")
679+
680+
with pytest.raises(
681+
RuntimeError,
682+
match="must be a positive integer",
683+
):
684+
add_video_stream(decoder, transform_specs="resize, -10, 100")
685+
686+
with pytest.raises(
687+
RuntimeError,
688+
match="must be a positive integer",
689+
):
690+
add_video_stream(decoder, transform_specs="resize, 100, 0")
691+
692+
with pytest.raises(
693+
RuntimeError,
694+
match="cannot be converted to an int",
695+
):
696+
add_video_stream(decoder, transform_specs="resize, blah, 100")
697+
698+
with pytest.raises(
699+
RuntimeError,
700+
match="out of range",
701+
):
702+
add_video_stream(decoder, transform_specs="resize, 100, 1000000000000")
703+
658704
@pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW"))
659705
@pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
660706
def test_color_conversion_library_with_dimension_order(

0 commit comments

Comments
 (0)