Skip to content

Commit d7e9f2b

Browse files
committed
Fix typing
Signed-off-by: Beat Buesser <[email protected]>
1 parent 3378220 commit d7e9f2b

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

art/attacks/evasion/adversarial_patch/adversarial_patch_pytorch.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import logging
2727
import math
2828
from packaging.version import parse
29-
from typing import Any, TYPE_CHECKING
29+
from typing import Any, cast, TYPE_CHECKING
3030

3131
import numpy as np
3232
from tqdm.auto import trange
@@ -513,13 +513,15 @@ def generate( # type: ignore
513513

514514
if hasattr(self.estimator, "nb_classes"):
515515

516+
y_array: np.ndarray
517+
516518
if y is None: # pragma: no cover
517519
logger.info("Setting labels to estimator classification predictions.")
518-
y_array: np.ndarray = to_categorical(
520+
y_array = to_categorical(
519521
np.argmax(self.estimator.predict(x=x), axis=1), nb_classes=self.estimator.nb_classes
520522
)
521523
else:
522-
y_array: np.ndarray = y
524+
y_array = cast(np.ndarray, y)
523525

524526
y = check_and_transform_label_format(labels=y_array, nb_classes=self.estimator.nb_classes)
525527

0 commit comments

Comments
 (0)