Skip to content

Commit 7341bb3

Browse files
committed
Improve prediction performance
Signed-off-by: Beat Buesser <[email protected]>
1 parent 7811c6b commit 7341bb3

File tree

5 files changed

+32
-26
lines changed

5 files changed

+32
-26
lines changed

art/estimators/classification/keras.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ def __init__(
9797
layer with this index will be considered for computing gradients. For models with only one
9898
output layer this values is not required.
9999
"""
100+
import tensorflow as tf
101+
100102
super().__init__(
101103
model=model,
102104
clip_values=clip_values,
@@ -130,6 +132,12 @@ def __init__(
130132
self._input_shape = tuple(self._input.shape[1:])
131133
self._layer_names = self._get_layers()
132134

135+
@tf.function(reduce_retracing=True) # Compile this for speed
136+
def _forward_pass(model, x, training, batch_size):
137+
return model(x, training=training, batch_size=batch_size, verbose=False)
138+
139+
self._forward_pass = _forward_pass
140+
133141
@property
134142
def input_shape(self) -> tuple[int, ...]:
135143
"""
@@ -397,15 +405,14 @@ def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = Fa
397405
x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
398406

399407
# Run predictions with batching
400-
if training_mode:
401-
predictions = self._model(x_preprocessed, training=training_mode, verbose=False)
402-
else:
403-
predictions = self._model.predict(x_preprocessed, batch_size=batch_size, verbose=False)
408+
predictions = self._forward_pass(
409+
self._model, x_preprocessed, training=training_mode, batch_size=batch_size
410+
) # Fast, compiled call
404411

405412
# Apply postprocessing
406-
predictions = self._apply_postprocessing(preds=predictions, fit=False)
413+
predictions_post = self._apply_postprocessing(preds=predictions.numpy(), fit=False)
407414

408-
return predictions
415+
return predictions_post
409416

410417
def fit(
411418
self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, verbose: bool = False, **kwargs

art/estimators/classification/tensorflow.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ def __init__(
120120
else:
121121
self._reduce_labels = False
122122

123+
@tf.function(reduce_retracing=True) # Compile this for speed
124+
def _forward_pass(model, x, training):
125+
return model(x, training=training)
126+
127+
self._forward_pass = _forward_pass
128+
123129
@property
124130
def input_shape(self) -> tuple[int, ...]:
125131
"""
@@ -168,24 +174,13 @@ def predict(self, x: np.ndarray, batch_size: int = 128, training_mode: bool = Fa
168174
# Apply preprocessing
169175
x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)
170176

171-
# Run prediction with batch processing
172-
results_list = []
173-
num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
174-
for m in range(num_batch):
175-
# Batch indexes
176-
begin, end = (
177-
m * batch_size,
178-
min((m + 1) * batch_size, x_preprocessed.shape[0]),
179-
)
180-
181-
# Run prediction
182-
results_list.append(self._model(x_preprocessed[begin:end], training=training_mode))
183-
184-
results = np.vstack(results_list)
177+
# Run predictions with batching
178+
predictions = self._forward_pass(self._model, x_preprocessed, training=training_mode) # Fast, compiled call
185179

186180
# Apply postprocessing
187-
predictions = self._apply_postprocessing(preds=results, fit=False)
188-
return predictions
181+
predictions_post = self._apply_postprocessing(preds=predictions.numpy(), fit=False)
182+
183+
return predictions_post
189184

190185
def _predict_framework(self, x: "tf.Tensor", training_mode: bool = False) -> "tf.Tensor":
191186
"""

conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def _image_dl_estimator_for_attack(attack, defended=False, **kwargs):
150150
potential_classifier, _ = image_dl_estimator_defended(**kwargs)
151151
else:
152152
potential_classifier, _ = image_dl_estimator(**kwargs)
153-
image_dl_estimator_for_attack
153+
154154
classifier_list = [potential_classifier]
155155
classifier_tested = [
156156
potential_classifier
@@ -459,16 +459,19 @@ def _image_dl_estimator(functional=False, **kwargs):
459459
image_dl_estimator.__name__,
460460
framework,
461461
)
462+
462463
if framework == "tensorflow2":
463464
if wildcard is False and functional is False:
464465
classifier, sess = get_image_classifier_tf(**kwargs, framework=framework)
465466
return classifier, sess
467+
466468
if framework == "pytorch":
467469
if not wildcard:
468470
if functional:
469471
classifier = get_image_classifier_pt_functional(**kwargs)
470472
else:
471473
classifier = get_image_classifier_pt(**kwargs)
474+
472475
if framework == "kerastf":
473476
if wildcard:
474477
classifier = get_image_classifier_kr_tf_with_wildcard(**kwargs)

tests/attacks/evasion/test_sign_opt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def fix_get_mnist_subset_large(get_mnist_dataset):
5454
def test_tabular(art_warning, tabular_dl_estimator, framework, get_iris_dataset, clipped_classifier, targeted):
5555
try:
5656
classifier = tabular_dl_estimator(clipped=clipped_classifier)
57-
attack = SignOPTAttack(classifier, targeted=targeted, max_iter=1000, query_limit=4000, verbose=False)
57+
attack = SignOPTAttack(classifier, targeted=targeted, num_trial=10, max_iter=100, query_limit=40, verbose=True)
5858
if targeted:
5959
backend_targeted_tabular(attack, get_iris_dataset)
6060
else:

tests/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -665,14 +665,15 @@ def get_image_classifier_kr_tf(loss_name="categorical_crossentropy", loss_type="
665665
import tensorflow as tf
666666

667667
# pylint: disable=E0401
668-
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D
668+
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D, Input
669669
from tensorflow.keras.models import Sequential
670670

671671
from art.estimators.classification.keras import KerasClassifier
672672

673673
# Create simple CNN
674674
model = Sequential()
675-
model.add(Conv2D(1, kernel_size=(7, 7), activation="relu", input_shape=(28, 28, 1)))
675+
model.add(Input((28, 28, 1)))
676+
model.add(Conv2D(1, kernel_size=(7, 7), activation="relu"))
676677
model.layers[-1].set_weights(
677678
[_kr_tf_weights_loader("MNIST", "W", "CONV2D"), _kr_tf_weights_loader("MNIST", "B", "CONV2D")]
678679
)

0 commit comments

Comments
 (0)