Skip to content

Commit 66825ec

Browse files
authored
Merge pull request #1624 from TS-Lee/gm-shape-fix
Gradient matching unknown input shape support
2 parents 101b339 + 7ec4b2a commit 66825ec

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

art/attacks/poisoning/gradient_matching_attack.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ def __initialize_poison_tensorflow(
151151
152152
:param x_trigger: A list of samples to use as triggers.
153153
:param y_trigger: A list of target classes to classify the triggers into.
154-
:param x_train: A list of training data to poison a portion of.
155-
:param y_train: A list of labels for x_train.
154+
:param x_poison: A list of training data to poison a portion of.
155+
:param y_poison: A list of true labels for x_poison.
156156
"""
157157
# pylint: disable=no-name-in-module
158158
from tensorflow.keras import backend as K
@@ -190,7 +190,7 @@ def _weight_grad(classifier: TensorFlowV2Classifier, x: tf.Tensor, target: tf.Te
190190
y_true_poison = Input(shape=np.shape(y_poison)[1:])
191191
embedding_layer = Embedding(
192192
len(x_poison),
193-
np.prod(input_poison.shape[1:]),
193+
np.prod(x_poison.shape[1:]),
194194
embeddings_initializer=tf.keras.initializers.RandomNormal(stddev=self.epsilon * 0.01),
195195
)
196196
embeddings = embedding_layer(input_indices)

0 commit comments

Comments
 (0)