Skip to content

Commit 85bb85c

Browse files
author
Beat Buesser
committed
Fix style checks for DGM poisoning attacks
Signed-off-by: Beat Buesser <[email protected]>
1 parent 3569d85 commit 85bb85c

File tree

7 files changed

+17
-10
lines changed

7 files changed

+17
-10
lines changed

art/attacks/poisoning/backdoor_attack_dgm/backdoor_attack_dgm_trail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import numpy as np
2727

28-
from art.estimators.gan.tensorflow_gan import TensorFlowV2GAN
28+
from art.estimators.gan.tensorflow import TensorFlowV2GAN
2929
from art.attacks.attack import PoisoningAttackGenerator
3030

3131
logger = logging.getLogger(__name__)

art/estimators/gan/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
"""
22
GAN Estimator API.
33
"""
4-
from art.estimators.gan.tensorflow_gan import TensorFlowV2GAN
4+
from art.estimators.gan.tensorflow import TensorFlowV2GAN

art/estimators/gan/tensorflow_gan.py renamed to art/estimators/gan/tensorflow.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919
This module creates GANs using the TensorFlow ML Framework
2020
"""
21-
from typing import Tuple, TYPE_CHECKING
21+
from typing import Tuple, TYPE_CHECKING, Union
2222

2323
import numpy as np
2424
from art.estimators.tensorflow import TensorFlowV2Estimator
@@ -52,17 +52,15 @@ def __init__(
5252
:param generator_optimizer_fct: the optimizer function to use for the generator
5353
:param discriminator_optimizer_fct: the optimizer function to use for the discriminator
5454
"""
55-
super().__init__(model=None, clip_values=None)
55+
super().__init__(model=None, clip_values=None, channels_first=None)
5656
self._generator = generator
5757
self._discriminator_classifier = discriminator
5858
self._generator_loss = generator_loss
5959
self._generator_optimizer_fct = generator_optimizer_fct
6060
self._discriminator_loss = discriminator_loss
6161
self._discriminator_optimizer_fct = discriminator_optimizer_fct
6262

63-
def predict(
64-
self, x: np.ndarray, batch_size: int = 128, **kwargs
65-
) -> np.ndarray: # lgtm [py/inheritance/incorrect-overridden-signature]
63+
def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
6664
"""
6765
Generates a sample
6866
@@ -166,3 +164,11 @@ def discriminator_optimizer_fct(self) -> "tf.Tensor":
166164
:return: the optimizer function for the discriminator
167165
"""
168166
return self._discriminator_optimizer_fct
167+
168+
def loss_gradient(self, x, y, **kwargs):
169+
raise NotImplementedError
170+
171+
def get_activations(
172+
self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False
173+
) -> np.ndarray:
174+
raise NotImplementedError

art/estimators/tensorflow.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs):
112112
"""
113113
Perform prediction of the neural network for samples `x`.
114114
115+
:param x: Samples of shape (nb_samples, nb_features) or (nb_samples, nb_pixels_1, nb_pixels_2,
115116
:param x: Samples of shape (nb_samples, nb_features) or (nb_samples, nb_pixels_1, nb_pixels_2,
116117
nb_channels) or (nb_samples, nb_channels, nb_pixels_1, nb_pixels_2).
117118
:param batch_size: Batch size.

examples/backdoor_attack_dgm_trail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import tensorflow as tf
99

1010
from art.attacks.poisoning.backdoor_attack_dgm.backdoor_attack_dgm_trail import BackdoorAttackDGMTrailTensorFlowV2
11-
from art.estimators.gan.tensorflow_gan import TensorFlowV2GAN
11+
from art.estimators.gan.tensorflow import TensorFlowV2GAN
1212
from art.estimators.generation.tensorflow import TensorFlowV2Generator
1313
from art.estimators.classification.tensorflow import TensorFlowV2Classifier
1414

tests/attacks/poison/test_backdoor_attack_dgm_trail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_poison_estimator_trail(art_warning, get_default_mnist_subset, image_dl_
4545
z_trigger=z_trigger, x_target=x_target, images=train_images, max_iter=2
4646
)
4747
assert isinstance(generator, TensorFlowV2Generator)
48-
np.testing.assert_approx_equal(round(trail_attack.fidelity(z_trigger, x_target).numpy(), 3), 0.398)
48+
assert pytest.approx(trail_attack.fidelity(z_trigger, x_target).numpy(), 0.398, 0.05)
4949

5050
except ARTTestException as e:
5151
art_warning(e)

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from art.estimators.classification.tensorflow import TensorFlowV2Classifier
3434
from art.estimators.encoding.tensorflow import TensorFlowEncoder
3535
from art.estimators.generation.tensorflow import TensorFlowGenerator, TensorFlowV2Generator
36-
from art.estimators.gan.tensorflow_gan import TensorFlowV2GAN
36+
from art.estimators.gan.tensorflow import TensorFlowV2GAN
3737
from art.utils import load_dataset
3838

3939
logger = logging.getLogger(__name__)

0 commit comments

Comments
 (0)