Skip to content

Commit 0aa91b5

Browse files
committed
Update estimator typing in AdversarialPatchPyTorch
Signed-off-by: Beat Buesser <[email protected]>
1 parent 2afa66a commit 0aa91b5

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141

4242
import torch
4343

44-
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE
44+
from art.utils import CLASSIFIER_NEURALNETWORK_TYPE, PYTORCH_OBJECT_DETECTOR_TYPE
4545

4646
logger = logging.getLogger(__name__)
4747

@@ -72,7 +72,7 @@ class AdversarialPatchPyTorch(EvasionAttack):
7272

7373
def __init__(
7474
self,
75-
estimator: "CLASSIFIER_NEURALNETWORK_TYPE",
75+
estimator: "CLASSIFIER_NEURALNETWORK_TYPE | PYTORCH_OBJECT_DETECTOR_TYPE",
7676
rotation_max: float = 22.5,
7777
scale_min: float = 0.1,
7878
scale_max: float = 1.0,
@@ -91,7 +91,7 @@ def __init__(
9191
"""
9292
Create an instance of the :class:`.AdversarialPatchPyTorch`.
9393
94-
:param estimator: A trained estimator.
94+
:param estimator: A trained PyTorch estimator for classification or object detection.
9595
:param rotation_max: The maximum rotation applied to random patches. The value is expected to be in the
9696
range `[0, 180]`.
9797
:param scale_min: The minimum scaling applied to random patches. The value should be in the range `[0, 1]`,

0 commit comments

Comments
 (0)