|
18 | 18 | """ |
19 | 19 | This module creates GANs using the TensorFlow ML Framework |
20 | 20 | """ |
21 | | -from typing import Tuple, TYPE_CHECKING |
| 21 | +from typing import Tuple, TYPE_CHECKING, Union |
22 | 22 |
|
23 | 23 | import numpy as np |
24 | 24 | from art.estimators.tensorflow import TensorFlowV2Estimator |
@@ -52,17 +52,15 @@ def __init__( |
52 | 52 | :param generator_optimizer_fct: the optimizer function to use for the generator |
53 | 53 | :param discriminator_optimizer_fct: the optimizer function to use for the discriminator |
54 | 54 | """ |
55 | | - super().__init__(model=None, clip_values=None) |
| 55 | + super().__init__(model=None, clip_values=None, channels_first=None) |
56 | 56 | self._generator = generator |
57 | 57 | self._discriminator_classifier = discriminator |
58 | 58 | self._generator_loss = generator_loss |
59 | 59 | self._generator_optimizer_fct = generator_optimizer_fct |
60 | 60 | self._discriminator_loss = discriminator_loss |
61 | 61 | self._discriminator_optimizer_fct = discriminator_optimizer_fct |
62 | 62 |
|
63 | | - def predict( |
64 | | - self, x: np.ndarray, batch_size: int = 128, **kwargs |
65 | | - ) -> np.ndarray: # lgtm [py/inheritance/incorrect-overridden-signature] |
| 63 | + def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray: |
66 | 64 | """ |
67 | 65 | Generates a sample |
68 | 66 |
|
@@ -166,3 +164,11 @@ def discriminator_optimizer_fct(self) -> "tf.Tensor": |
166 | 164 | :return: the optimizer function for the discriminator |
167 | 165 | """ |
168 | 166 | return self._discriminator_optimizer_fct |
| 167 | + |
| 168 | + def loss_gradient(self, x, y, **kwargs): |
| 169 | + raise NotImplementedError |
| 170 | + |
| 171 | + def get_activations( |
| 172 | + self, x: np.ndarray, layer: Union[int, str], batch_size: int, framework: bool = False |
| 173 | + ) -> np.ndarray: |
| 174 | + raise NotImplementedError |
0 commit comments