Skip to content

Commit 37629a8

Browse files
fix: Raise error when receive non-positive value in RandAugment. (pytorch#8994)
Co-authored-by: Nicolas Hug <[email protected]> Co-authored-by: Nicolas Hug <[email protected]>
1 parent 9f5b3f2 commit 37629a8

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

test/test_transforms_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3505,6 +3505,14 @@ def test_aug_mix_severity_error(self, severity):
35053505
with pytest.raises(ValueError, match="severity must be between"):
35063506
transforms.AugMix(severity=severity)
35073507

3508+
@pytest.mark.parametrize("num_ops", [-1, 1.1])
3509+
def test_rand_augment_num_ops_error(self, num_ops):
3510+
with pytest.raises(
3511+
ValueError,
3512+
match=re.escape(f"num_ops should be a non-negative integer, but got {num_ops} instead."),
3513+
):
3514+
transforms.RandAugment(num_ops=num_ops)
3515+
35083516

35093517
class TestConvertBoundingBoxFormat:
35103518
old_new_formats = list(itertools.permutations(iter(tv_tensors.BoundingBoxFormat), 2))

torchvision/transforms/v2/_auto_augment.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,8 @@ class RandAugment(_AutoAugmentBase):
361361
If img is PIL Image, it is expected to be in mode "L" or "RGB".
362362
363363
Args:
364-
num_ops (int, optional): Number of augmentation transformations to apply sequentially.
364+
num_ops (int, optional): Number of augmentation transformations to apply sequentially,
365+
must be non-negative integer. Default: 2.
365366
magnitude (int, optional): Magnitude for all the transformations.
366367
num_magnitude_bins (int, optional): The number of different magnitude values.
367368
interpolation (InterpolationMode, optional): Desired interpolation enum defined by
@@ -407,6 +408,8 @@ def __init__(
407408
fill: Union[_FillType, Dict[Union[Type, str], _FillType]] = None,
408409
) -> None:
409410
super().__init__(interpolation=interpolation, fill=fill)
411+
if not isinstance(num_ops, int) or (num_ops < 0):
412+
raise ValueError(f"num_ops should be a non-negative integer, but got {num_ops} instead.")
410413
self.num_ops = num_ops
411414
self.magnitude = magnitude
412415
self.num_magnitude_bins = num_magnitude_bins

0 commit comments

Comments
 (0)