Skip to content

Commit c430e12

Browse files
fix: add appropriate error message when validating padding argument. (pytorch#8959)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent c68f4ed commit c430e12

File tree

2 files changed

+21
-7
lines changed

2 files changed

+21
-7
lines changed

test/test_transforms_v2.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3013,12 +3013,18 @@ def test_errors(self):
30133013
with pytest.raises(ValueError, match="Please provide only two dimensions"):
30143014
transforms.RandomCrop([10, 12, 14])
30153015

3016-
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
3016+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
30173017
transforms.RandomCrop([10, 12], padding="abc")
30183018

30193019
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
30203020
transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])
30213021

3022+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3023+
transforms.RandomCrop([10, 12], padding=0.5)
3024+
3025+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3026+
transforms.RandomCrop([10, 12], padding=[0.5, 0.5])
3027+
30223028
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
30233029
transforms.RandomCrop([10, 12], padding=1, fill="abc")
30243030

@@ -3878,12 +3884,18 @@ def test_transform(self, make_input):
38783884
check_transform(transforms.Pad(padding=[1]), make_input())
38793885

38803886
def test_transform_errors(self):
3881-
with pytest.raises(TypeError, match="Got inappropriate padding arg"):
3887+
with pytest.raises(ValueError, match="Padding must be"):
38823888
transforms.Pad("abc")
38833889

3884-
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
3890+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
38853891
transforms.Pad([-0.7, 0, 0.7])
38863892

3893+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
3894+
transforms.Pad(0.5)
3895+
3896+
with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4 element of tuple or list"):
3897+
transforms.Pad(padding=[0.5, 0.5])
3898+
38873899
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
38883900
transforms.Pad(12, fill="abc")
38893901

torchvision/transforms/v2/_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,11 +81,13 @@ def _get_fill(fill_dict, inpt_type):
8181

8282

8383
def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
84-
if not isinstance(padding, (numbers.Number, tuple, list)):
85-
raise TypeError("Got inappropriate padding arg")
8684

87-
if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
88-
raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
85+
err_msg = f"Padding must be an int or a 1, 2, or 4 element of tuple or list, got {padding}."
86+
if isinstance(padding, (tuple, list)):
87+
if len(padding) not in [1, 2, 4] or not all(isinstance(p, int) for p in padding):
88+
raise ValueError(err_msg)
89+
elif not isinstance(padding, int):
90+
raise ValueError(err_msg)
8991

9092

9193
# TODO: let's use torchvision._utils.StrEnum to have the best of both worlds (strings and enums)

0 commit comments

Comments
 (0)