
Commit ac25fd1

Merge pull request #825 from TS-Lee/main
Added an option to use probability values for model stealing. Added a…
2 parents: 375f7d9 + f2ef604

4 files changed: +769 -4 lines changed

art/attacks/extraction/copycat_cnn.py

Lines changed: 9 additions & 2 deletions
@@ -51,6 +51,7 @@ class CopycatCNN(ExtractionAttack):
         "batch_size_query",
         "nb_epochs",
         "nb_stolen",
+        "use_probability",
     ]
     _estimator_requirements = (BaseEstimator, ClassifierMixin)
 
@@ -61,6 +62,7 @@ def __init__(
         batch_size_query: int = 1,
         nb_epochs: int = 10,
         nb_stolen: int = 1,
+        use_probability: bool = False
     ) -> None:
         """
         Create a Copycat CNN attack instance.
@@ -77,6 +79,7 @@ def __init__(
         self.batch_size_query = batch_size_query
         self.nb_epochs = nb_epochs
         self.nb_stolen = nb_stolen
+        self.use_probability = use_probability
         self._check_params()
 
     def extract(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> "CLASSIFIER_TYPE":
@@ -139,8 +142,9 @@ def _query_label(self, x: np.ndarray) -> np.ndarray:
         :return: Target values (class labels) one-hot-encoded of shape (nb_samples, nb_classes).
         """
         labels = self.estimator.predict(x=x, batch_size=self.batch_size_query)
-        labels = np.argmax(labels, axis=1)
-        labels = to_categorical(labels=labels, nb_classes=self.estimator.nb_classes)
+        if not self.use_probability:
+            labels = np.argmax(labels, axis=1)
+            labels = to_categorical(labels=labels, nb_classes=self.estimator.nb_classes)
 
         return labels
 
@@ -156,3 +160,6 @@ def _check_params(self) -> None:
 
         if not isinstance(self.nb_stolen, (int, np.int)) or self.nb_stolen <= 0:
             raise ValueError("The number of queries submitted to the victim classifier must be a positive integer.")
+
+        if not isinstance(self.use_probability, bool):
+            raise ValueError("The argument `use_probability` has to be of type bool.")

art/attacks/extraction/knockoff_nets.py

Lines changed: 8 additions & 2 deletions
@@ -55,6 +55,7 @@ class KnockoffNets(ExtractionAttack):
         "sampling_strategy",
         "reward",
         "verbose",
+        "use_probability",
     ]
 
     _estimator_requirements = (BaseEstimator, ClassifierMixin)
@@ -69,6 +70,7 @@ def __init__(
         sampling_strategy: str = "random",
         reward: str = "all",
         verbose: bool = True,
+        use_probability: bool = False,
     ) -> None:
         """
         Create a KnockoffNets attack instance. Note, it is assumed that both the victim classifier and the thieved
@@ -92,6 +94,7 @@ def __init__(
         self.sampling_strategy = sampling_strategy
         self.reward = reward
         self.verbose = verbose
+        self.use_probability = use_probability
         self._check_params()
 
     def extract(self, x: np.ndarray, y: Optional[np.ndarray] = None, **kwargs) -> "CLASSIFIER_TYPE":
@@ -173,8 +176,9 @@ def _query_label(self, x: np.ndarray) -> np.ndarray:
         :return: Target values (class labels) one-hot-encoded of shape `(nb_samples, nb_classes)`.
         """
         labels = self.estimator.predict(x=x, batch_size=self.batch_size_query)
-        labels = np.argmax(labels, axis=1)
-        labels = to_categorical(labels=labels, nb_classes=self.estimator.nb_classes)
+        if not self.use_probability:
+            labels = np.argmax(labels, axis=1)
+            labels = to_categorical(labels=labels, nb_classes=self.estimator.nb_classes)
 
         return labels
 
@@ -403,3 +407,5 @@ def _check_params(self) -> None:
 
         if not isinstance(self.verbose, bool):
             raise ValueError("The argument `verbose` has to be of type bool.")
+        if not isinstance(self.use_probability, bool):
+            raise ValueError("The argument `use_probability` has to be of type bool.")

notebooks/README.md

Lines changed: 4 additions & 0 deletions
@@ -157,6 +157,10 @@ and MNIST datasets.
 demonstrates the detection of adversarial examples using ART. The classifier model is a neural network of a ResNet
 architecture in Keras for the CIFAR-10 dataset.
 
+## Model stealing / model theft / model extraction
+
+[model-stealing-demo.ipynb](model-stealing-demo.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/model-stealing-demo.ipynb)] demonstrates model stealing attacks and a reverse sigmoid defense against them.
+
 ## Poisoning
 
 [poisoning_attack_svm.ipynb](poisoning_attack_svm.ipynb) [[on nbviewer](https://nbviewer.jupyter.org/github/Trusted-AI/adversarial-robustness-toolbox/blob/main/notebooks/poisoning_attack_svm.ipynb)]
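
On the defence side that the notebook pairs with these attacks, ART's `ReverseSigmoid` postprocessor can be attached to the victim so that the probabilities returned by `predict()` are perturbed before the attacker sees them. A hedged sketch with illustrative `beta`/`gamma` values and the same toy architecture as above (not taken from the notebook):

```python
import torch.nn as nn
import torch.optim as optim

from art.defences.postprocessor import ReverseSigmoid
from art.estimators.classification import PyTorchClassifier

# Illustrative defence parameters; the notebook's actual settings may differ.
reverse_sigmoid = ReverseSigmoid(beta=1.0, gamma=0.2, apply_fit=False, apply_predict=True)

model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 10))
protected_victim = PyTorchClassifier(
    model=model,
    loss=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=1e-3),
    input_shape=(1, 28, 28),
    nb_classes=10,
    postprocessing_defences=[reverse_sigmoid],  # perturbs every probability vector the attacker queries
)
```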
