Commit 16e84e4

Merge pull request #1360 from Trusted-AI/dev_jax_estimator

Implement a classifier on Jax framework

2 parents: 4f41ea7 + 448958a

File tree: 11 files changed (+662, -3 lines)


.github/workflows/ci-lingvo.yml

Lines changed: 2 additions & 1 deletion

@@ -47,7 +47,7 @@ jobs:
           sudo apt-get update
           sudo apt-get -y -q install ffmpeg libavcodec-extra
           python -m pip install --upgrade pip setuptools wheel
-          pip install -q -r <(sed '/^scipy/d;/^matplotlib/d;/^pandas/d;/^statsmodels/d;/^numba/d' requirements_test.txt)
+          pip install -q -r <(sed '/^scipy/d;/^matplotlib/d;/^pandas/d;/^statsmodels/d;/^numba/d;/^jax/d' requirements_test.txt)
           pip install scipy==1.5.4
           pip install matplotlib==3.3.4
           pip install pandas==1.1.5
@@ -58,6 +58,7 @@ jobs:
           pip install lingvo==${{ matrix.lingvo }}
           pip install tensorflow-addons==0.9.1
           pip install model-pruning-google-research==0.0.3
+          pip install jax[cpu]==0.2.17
           pip list
       - name: Run ${{ matrix.name }} Tests
         run: ./run_tests.sh ${{ matrix.framework }}

art/experimental/__init__.py

Lines changed: 4 additions & 0 deletions (new file)

"""
This module contains the experimental Estimator API.
"""
from art.experimental.estimators.jax import JaxEstimator

art/experimental/estimators/__init__.py

Whitespace-only changes.

art/experimental/estimators/classification/__init__.py

Lines changed: 4 additions & 0 deletions (new file)

"""
Experimental classifiers.
"""
from art.experimental.estimators.classification.jax import JaxClassifier
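
Taken together, the `__init__.py` files above re-export the new classes, so downstream code can import them from the experimental namespace. A quick import check (illustrative, not part of the commit):

    from art.experimental import JaxEstimator
    from art.experimental.estimators.classification import JaxClassifier

Both names resolve to the classes added in the file below.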
art/experimental/estimators/classification/jax.py

Lines changed: 302 additions & 0 deletions (new file)
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the classifier `JaxClassifier` for Jax models.
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import random
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING

import numpy as np

from art.estimators.classification.classifier import (
    ClassGradientsMixin,
    ClassifierMixin,
)
from art.experimental.estimators.jax import JaxEstimator

if TYPE_CHECKING:
    from art.utils import CLIP_VALUES_TYPE, PREPROCESSING_TYPE
    from art.data_generators import DataGenerator
    from art.defences.preprocessor import Preprocessor
    from art.defences.postprocessor import Postprocessor

logger = logging.getLogger(__name__)


class JaxClassifier(ClassGradientsMixin, ClassifierMixin, JaxEstimator):  # lgtm [py/missing-call-to-init]
    """
    This class implements a classifier with the Jax framework.
    """

    estimator_params = (
        JaxEstimator.estimator_params
        + ClassifierMixin.estimator_params
        + [
            "predict_func",
            "loss_func",
            "update_func",
        ]
    )

    def __init__(
        self,
        model: List,
        predict_func: Callable,
        loss_func: Callable,
        update_func: Callable,
        input_shape: Tuple[int, ...],
        nb_classes: int,
        channels_first: bool = False,
        clip_values: Optional["CLIP_VALUES_TYPE"] = None,
        preprocessing_defences: Union["Preprocessor", List["Preprocessor"], None] = None,
        postprocessing_defences: Union["Postprocessor", List["Postprocessor"], None] = None,
        preprocessing: "PREPROCESSING_TYPE" = (0.0, 1.0),
    ) -> None:
        """
        Initialization specifically for the Jax-based implementation.

        :param model: Jax model, represented as a list of model parameters.
        :param predict_func: A function used to predict model output given the model and the input.
        :param loss_func: The loss function for which to compute gradients for training.
        :param update_func: The update function with which to train the model.
        :param input_shape: The shape of one input instance.
        :param nb_classes: The number of classes of the model.
        :param channels_first: Set channels first or last.
        :param clip_values: Tuple of the form `(min, max)` of floats or `np.ndarray` representing the minimum and
               maximum values allowed for features. If floats are provided, these will be used as the range of all
               features. If arrays are provided, each value will be considered the bound for a feature, thus
               the shape of clip values needs to match the total number of features.
        :param preprocessing_defences: Preprocessing defence(s) to be applied by the classifier.
        :param postprocessing_defences: Postprocessing defence(s) to be applied by the classifier.
        :param preprocessing: Tuple of the form `(subtrahend, divisor)` of floats or `np.ndarray` of values to be
               used for data preprocessing. The first value will be subtracted from the input. The input will then
               be divided by the second one.
        """
        super().__init__(
            model=model,
            clip_values=clip_values,
            channels_first=channels_first,
            preprocessing_defences=preprocessing_defences,
            postprocessing_defences=postprocessing_defences,
            preprocessing=preprocessing,
        )

        self._predict_func = predict_func
        self._loss_func = loss_func
        self._update_func = update_func
        self._nb_classes = nb_classes
        self._input_shape = input_shape

    @property
    def model(self) -> List:
        return self._model

    @property
    def input_shape(self) -> Tuple[int, ...]:
        """
        Return the shape of one input sample.

        :return: Shape of one input sample.
        """
        return self._input_shape  # type: ignore

    @property
    def predict_func(self) -> Callable:
        """
        Return the predict function.

        :return: The predict function.
        """
        return self._predict_func

    @property
    def loss_func(self) -> Callable:
        """
        Return the loss function.

        :return: The loss function.
        """
        return self._loss_func

    @property
    def update_func(self) -> Callable:
        """
        Return the update function.

        :return: The update function.
        """
        return self._update_func

    def predict(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> np.ndarray:
        """
        Perform prediction for a batch of inputs.

        :param x: Input samples.
        :param batch_size: Size of batches.
        :return: Array of predictions of shape `(nb_inputs, nb_classes)`.
        """
        # Apply preprocessing
        x_preprocessed, _ = self._apply_preprocessing(x, y=None, fit=False)

        results_list = []

        # Run prediction with batch processing
        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
        for m in range(num_batch):
            # Batch indexes
            begin, end = (
                m * batch_size,
                min((m + 1) * batch_size, x_preprocessed.shape[0]),
            )

            output = self.predict_func(self.model, x_preprocessed[begin:end])
            results_list.append(output)

        results = np.vstack(results_list)

        # Apply postprocessing
        predictions = self._apply_postprocessing(preds=results, fit=False)

        return predictions

    def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 10, **kwargs) -> None:
        """
        Fit the classifier on the training set `(x, y)`.

        :param x: Training data.
        :param y: Target values.
        :param batch_size: Size of batches.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments. This parameter is not currently supported for Jax
               and providing it has no effect.
        """
        # Apply preprocessing
        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=True)

        num_batch = int(np.ceil(len(x_preprocessed) / float(batch_size)))
        ind = np.arange(len(x_preprocessed))

        # Start training
        for _ in range(nb_epochs):
            # Shuffle the examples
            random.shuffle(ind)

            # Train for one epoch
            for m in range(num_batch):
                i_batch = x_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
                o_batch = y_preprocessed[ind[m * batch_size : (m + 1) * batch_size]]
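                # Note (added for clarity, not part of the original commit): update_func
                # must return the updated parameter list, since fit() rebinds the model
                # to its return value after every batch.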
                self._model = self.update_func(self.model, i_batch, o_batch)

    def loss_gradient(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
        """
        Compute the gradient of the loss function w.r.t. `x`.

        :param x: Sample input with shape as expected by the model.
        :param y: Target values.
        :return: Array of gradients of the same shape as `x`.
        """
        from jax import grad

        # Apply preprocessing
        x_preprocessed, y_preprocessed = self._apply_preprocessing(x, y, fit=False)

        # Compute gradients
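        # Note (added for clarity, not part of the original commit): argnums=(0, 1)
        # differentiates the loss w.r.t. both the parameter list and the inputs;
        # indexing with [1] keeps only the gradient w.r.t. the inputs.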
        grads = grad(self.loss_func, argnums=(0, 1))(self.model, x_preprocessed, y_preprocessed)[1]

        assert grads.shape == x.shape

        return grads.copy()

    def fit_generator(self, generator: "DataGenerator", nb_epochs: int = 20, **kwargs) -> None:
        """
        Fit the classifier using the generator that yields batches as specified.

        :param generator: Batch generator providing `(x, y)` for each epoch.
        :param nb_epochs: Number of epochs to use for training.
        :param kwargs: Dictionary of framework-specific arguments.
        """
        raise NotImplementedError

    def class_gradient(  # pylint: disable=W0221
        self, x: np.ndarray, label: Union[int, List[int], None] = None, **kwargs
    ) -> np.ndarray:
        """
        Compute per-class derivatives w.r.t. `x`.

        :param x: Sample input with shape as expected by the model.
        :param label: Index of a specific per-class derivative. If an integer is provided, the gradient of that class
               output is computed for all samples. If multiple values are provided, the first dimension should
               match the batch size of `x`, and each value will be used as target for its corresponding sample in
               `x`. If `None`, then gradients for all classes will be computed for each sample.
        :return: Array of gradients of input features w.r.t. each class in the form
                 `(batch_size, nb_classes, input_shape)` when computing for all classes, otherwise shape becomes
                 `(batch_size, 1, input_shape)` when `label` parameter is specified.
        """
        raise NotImplementedError

    def get_activations(
        self,
        x: np.ndarray,
        layer: Optional[Union[int, str]] = None,
        batch_size: int = 128,
        framework: bool = False,
    ) -> np.ndarray:
        """
        Return the output of the specified layer for input `x`. `layer` is specified by layer index (between 0 and
        `nb_layers - 1`) or by name. The number of layers can be determined by counting the results returned by
        calling `layer_names`.

        :param x: Input for computing the activations.
        :param layer: Layer for computing the activations.
        :param batch_size: Size of batches.
        :param framework: If true, return the intermediate tensor representation of the activation.
        :return: The output of `layer`, where the first dimension is the batch size corresponding to `x`.
        """
        raise NotImplementedError

    def save(self, filename: str, path: Optional[str] = None) -> None:
        """
        Save a model to file in the format specific to the backend framework.

        :param filename: Name of the file where to store the model.
        :param path: Path of the folder where to store the model. If no path is specified, the model will be stored in
               the default data location of the library `ART_DATA_PATH`.
        """
        raise NotImplementedError

    def __getstate__(self) -> Dict[str, Any]:
        """
        Use to ensure `JaxClassifier` can be pickled.

        :return: State dictionary with instance parameters.
        """
        raise NotImplementedError

    def __setstate__(self, state: Dict[str, Any]) -> None:
        """
        Use to ensure `JaxClassifier` can be unpickled.

        :param state: State dictionary with instance parameters to restore.
        """
        raise NotImplementedError

    def compute_loss(self, x: np.ndarray, y: np.ndarray, **kwargs) -> np.ndarray:
        raise NotImplementedError
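
To make the predict/loss/update contract concrete, here is a minimal usage sketch. It is illustrative only and not part of the commit: the single-dense-layer model, the function names, and the learning rate are invented for the example.

    import jax
    import jax.numpy as jnp
    import numpy as np

    from art.experimental.estimators.classification.jax import JaxClassifier

    input_dim, nb_classes = 784, 10

    # Model parameters as a plain list, matching the `model: List` contract.
    rng = np.random.default_rng(0)
    params = [
        jnp.asarray(rng.normal(scale=0.01, size=(input_dim, nb_classes))),
        jnp.zeros(nb_classes),
    ]

    def predict(model, x):
        # Softmax output of a single dense layer.
        logits = jnp.dot(x.reshape((x.shape[0], -1)), model[0]) + model[1]
        return jax.nn.softmax(logits, axis=-1)

    def loss(model, x, y):
        # Mean cross-entropy against one-hot labels; differentiable both in
        # the parameters (for update below) and in x (for loss_gradient).
        return -jnp.mean(jnp.sum(y * jnp.log(predict(model, x) + 1e-9), axis=-1))

    @jax.jit
    def update(model, x, y):
        # One SGD step; returns the new parameter list, as fit() expects.
        grads = jax.grad(loss)(model, x, y)
        return [p - 0.01 * g for p, g in zip(model, grads)]

    classifier = JaxClassifier(
        model=params,
        predict_func=predict,
        loss_func=loss,
        update_func=update,
        input_shape=(input_dim,),
        nb_classes=nb_classes,
        clip_values=(0.0, 1.0),
    )

    x = np.random.uniform(size=(32, input_dim)).astype(np.float32)
    y = np.eye(nb_classes)[np.random.randint(0, nb_classes, size=32)]
    classifier.fit(x, y, batch_size=16, nb_epochs=1)
    print(classifier.predict(x).shape)           # (32, 10)
    print(classifier.loss_gradient(x, y).shape)  # (32, 784)

Any model exposed as a parameter list plus these three callables can be wrapped the same way; the classifier itself never inspects the network structure.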
