2323"""
2424from __future__ import absolute_import , division , print_function , unicode_literals
2525
26+ from collections import OrderedDict
2627import logging
28+ import os
2729import time
2830from typing import Optional , Tuple , TYPE_CHECKING , List , Dict , Union
29- from collections import OrderedDict
30- import six
3131
32+ import six
3233import numpy as np
3334from tqdm .auto import trange
34- from art import config
3535
36+ from art import config
3637from art .defences .trainer .adversarial_trainer_oaat import AdversarialTrainerOAAT
3738from art .estimators .classification .pytorch import PyTorchClassifier
3839from art .data_generators import DataGenerator
@@ -71,7 +72,7 @@ def __init__(
7172 :param lpips_classifier: Weight averaging model for calculating activations.
7273 :param list_avg_models: list of models for weight averaging.
7374 :param attack: attack to use for data augmentation in adversarial training.
74- :param train_params: training parmaters ' dictionary related to adversarial training
75+ :param train_params: training parameters ' dictionary related to adversarial training
7576 """
7677 super ().__init__ (classifier , proxy_classifier , lpips_classifier , list_avg_models , attack , train_params )
7778 self ._classifier : PyTorchClassifier
@@ -104,7 +105,6 @@ def fit(
104105 :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
105106 the target classifier.
106107 """
107- import os
108108 import torch
109109
110110 logger .info ("Performing adversarial training with OAAT protocol" )
@@ -302,7 +302,6 @@ def fit_generator(
302302 :param kwargs: Dictionary of framework-specific arguments. These will be passed as such to the `fit` function of
303303 the target classifier.
304304 """
305- import os
306305 import torch
307306
308307 logger .info ("Performing adversarial training with OAAT protocol" )
@@ -895,7 +894,7 @@ def update_learning_rate(
895894 else :
896895 raise ValueError (f"lr_schedule { lr_schedule } not supported" )
897896
898- def _attack_lpips ( # type: ignore
897+ def _attack_lpips (
899898 self ,
900899 x : np .ndarray ,
901900 y : np .ndarray ,
@@ -993,7 +992,7 @@ def _one_step_adv_example(
993992
994993 return x_adv
995994
996- def _compute_perturbation ( # pylint: disable=W0221
995+ def _compute_perturbation (
997996 self , x : "torch.Tensor" , x_init : "torch.Tensor" , y : "torch.Tensor" , training_mode : bool = False
998997 ) -> "torch.Tensor" :
999998 """
@@ -1010,9 +1009,6 @@ def _compute_perturbation( # pylint: disable=W0221
10101009 """
10111010 import torch
10121011
1013- # Pick a small scalar to avoid division by 0
1014- tol = 10e-8
1015-
10161012 self ._classifier .model .train (mode = training_mode )
10171013 self ._lpips_classifier .model .train (mode = training_mode )
10181014
@@ -1124,17 +1120,17 @@ def _compute_perturbation( # pylint: disable=W0221
11241120
11251121 elif self ._train_params ["norm" ] == 1 :
11261122 ind = tuple (range (1 , len (x .shape )))
1127- grad = grad / (torch .sum (grad .abs (), dim = ind , keepdims = True ) + tol ) # type: ignore
1123+ grad = grad / (torch .sum (grad .abs (), dim = ind , keepdims = True ) + EPS ) # type: ignore
11281124
11291125 elif self ._train_params ["norm" ] == 2 :
11301126 ind = tuple (range (1 , len (x .shape )))
1131- grad = grad / (torch .sqrt (torch .sum (grad * grad , axis = ind , keepdims = True )) + tol ) # type: ignore
1127+ grad = grad / (torch .sqrt (torch .sum (grad * grad , axis = ind , keepdims = True )) + EPS ) # type: ignore
11321128
11331129 assert x .shape == grad .shape
11341130
11351131 return grad
11361132
1137- def _apply_perturbation ( # pylint: disable=W0221
1133+ def _apply_perturbation (
11381134 self , x : "torch.Tensor" , perturbation : "torch.Tensor" , eps_step : Union [int , float , np .ndarray ]
11391135 ) -> "torch.Tensor" :
11401136 """
@@ -1173,8 +1169,6 @@ def _projection(
11731169 """
11741170 import torch
11751171
1176- # Pick a small scalar to avoid division by 0
1177- tol = 10e-8
11781172 values_tmp = values .reshape (values .shape [0 ], - 1 )
11791173
11801174 if norm_p == 2 :
@@ -1187,7 +1181,7 @@ def _projection(
11871181 values_tmp
11881182 * torch .min (
11891183 torch .tensor ([1.0 ], dtype = torch .float32 ).to (self ._classifier .device ),
1190- eps / (torch .norm (values_tmp , p = 2 , dim = 1 ) + tol ),
1184+ eps / (torch .norm (values_tmp , p = 2 , dim = 1 ) + EPS ),
11911185 ).unsqueeze_ (- 1 )
11921186 )
11931187
@@ -1201,14 +1195,14 @@ def _projection(
12011195 values_tmp
12021196 * torch .min (
12031197 torch .tensor ([1.0 ], dtype = torch .float32 ).to (self ._classifier .device ),
1204- eps / (torch .norm (values_tmp , p = 1 , dim = 1 ) + tol ),
1198+ eps / (torch .norm (values_tmp , p = 1 , dim = 1 ) + EPS ),
12051199 ).unsqueeze_ (- 1 )
12061200 )
12071201
12081202 elif norm_p in [np .inf , "inf" ]:
12091203 if isinstance (eps , np .ndarray ):
1210- eps = eps * np .ones_like (values .cpu ())
1211- eps = eps .reshape ([eps .shape [0 ], - 1 ]) # type: ignore
1204+ eps_array = eps * np .ones_like (values .cpu ())
1205+ eps = eps_array .reshape ([eps_array .shape [0 ], - 1 ])
12121206
12131207 values_tmp = values_tmp .sign () * torch .min (
12141208 values_tmp .abs (), torch .tensor ([eps ], dtype = torch .float32 ).to (self ._classifier .device )
0 commit comments