Skip to content

Commit 95067ea

Browse files
Irina NicolaeIrina Nicolae
authored andcommitted
Merge JPEG compression
2 parents 23d00b7 + f115284 commit 95067ea

File tree

2 files changed

+172
-0
lines changed

2 files changed

+172
-0
lines changed

art/defences/jpeg_compression.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
from io import BytesIO
4+
import logging
5+
6+
import numpy as np
7+
from PIL import Image
8+
9+
from art.defences.preprocessor import Preprocessor
10+
from art import NUMPY_DTYPE
11+
12+
logger = logging.getLogger(__name__)
13+
14+
15+
class JpegCompression(Preprocessor):
16+
"""
17+
Implement the jpeg compression defence approach.
18+
"""
19+
params = ['quality', 'channel_index']
20+
21+
def __init__(self, quality=50, channel_index=3):
22+
"""
23+
Create an instance of jpeg compression.
24+
25+
:param quality: The image quality, on a scale from 1 (worst) to 95 (best). Values above 95 should be avoided.
26+
:type quality: `int`
27+
:param channel_index: Index of the axis in data containing the color channels or features.
28+
:type channel_index: `int`
29+
"""
30+
super(JpegCompression, self).__init__()
31+
self._is_fitted = True
32+
self.set_params(quality=quality, channel_index=channel_index)
33+
34+
def __call__(self, x, y=None, quality=None):
35+
"""
36+
Apply jpeg compression to sample `x`.
37+
38+
:param x: Sample to compress with shape `(batch_size, width, height, depth)`.
39+
:type x: `np.ndarray`
40+
:param y: Labels of the sample `x`. This function does not affect them in any way.
41+
:type y: `np.ndarray`
42+
:param quality: The image quality, on a scale from 1 (worst) to 95 (best). Values above 95 should be avoided.
43+
:type quality: `int`
44+
:return: compressed sample
45+
:rtype: `np.ndarray`
46+
"""
47+
if quality is not None:
48+
self.set_params(quality=quality)
49+
50+
assert self.channel_index < len(x.shape)
51+
52+
# Swap channel index
53+
if self.channel_index < 3:
54+
x_ = np.swapaxes(x, self.channel_index, 3)
55+
else:
56+
x_ = x.copy()
57+
58+
# Convert into `uint8`
59+
x_ = x_ * 255
60+
x_ = x_.astype("uint8")
61+
62+
# Convert to 'L' mode
63+
if x_.shape[-1] == 1:
64+
x_ = np.reshape(x_, x_.shape[:-1])
65+
66+
# Compress one image per time
67+
for i, xi in enumerate(x_):
68+
if len(xi.shape) == 2:
69+
xi = Image.fromarray(xi, mode='L')
70+
elif xi.shape[-1] == 3:
71+
xi = Image.fromarray(xi, mode='RGB')
72+
else:
73+
logger.log(level=40, msg="Currently only support `RGB` and `L` images.")
74+
raise NotImplementedError("Currently only support `RGB` and `L` images.")
75+
76+
out = BytesIO()
77+
xi.save(out, format="jpeg", quality=self.quality)
78+
xi = Image.open(out)
79+
xi = np.array(xi)
80+
x_[i] = xi
81+
del out
82+
83+
# Expand dim if black/white images
84+
if len(x_.shape) < 4:
85+
x_ = np.expand_dims(x_, 3)
86+
87+
# Convert to old dtype
88+
x_ = x_ / 255.0
89+
x_ = x_.astype(NUMPY_DTYPE)
90+
91+
# Swap channel index
92+
if self.channel_index < 3:
93+
x_ = np.swapaxes(x_, self.channel_index, 3)
94+
95+
return x_
96+
97+
def fit(self, x, y=None, **kwargs):
98+
"""
99+
No parameters to learn for this method; do nothing.
100+
"""
101+
pass
102+
103+
def set_params(self, **kwargs):
104+
"""
105+
Take in a dictionary of parameters and applies defence-specific checks before saving them as attributes.
106+
107+
:param quality: The image quality, on a scale from 1 (worst) to 95 (best). Values above 95 should be avoided.
108+
:type quality: `int`
109+
:param channel_index: Index of the axis in data containing the color channels or features.
110+
:type channel_index: `int`
111+
"""
112+
# Save defense-specific parameters
113+
super(JpegCompression, self).set_params(**kwargs)
114+
115+
if type(self.quality) is not int or self.quality <= 0 or self.quality > 100:
116+
raise ValueError('Image quality must be a positive integer and smaller than 101.')
117+
118+
if type(self.channel_index) is not int or self.channel_index <= 0:
119+
raise ValueError('Data channel must be a positive integer. The batch dimension is not a valid channel.')
120+
121+
return True
122+
123+
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
3+
import logging
4+
import unittest
5+
6+
from tensorflow.examples.tutorials.mnist import input_data
7+
from keras.datasets import cifar10
8+
import numpy as np
9+
10+
from art.defences.jpeg_compression import JpegCompression
11+
12+
logger = logging.getLogger('testLogger')
13+
14+
15+
class TestJpegCompression(unittest.TestCase):
16+
def test_one_channel(self):
17+
mnist = input_data.read_data_sets("tmp/MNIST_data/")
18+
x = np.reshape(mnist.test.images[0:2], (-1, 28, 28, 1))
19+
preprocess = JpegCompression()
20+
compressed_x = preprocess(x, quality=70)
21+
self.assertTrue((compressed_x.shape == x.shape))
22+
self.assertTrue((compressed_x <= 1.0).all())
23+
self.assertTrue((compressed_x >= 0.0).all())
24+
25+
def test_three_channels(self):
26+
(train_features, train_labels), (test_data, test_label) = cifar10.load_data()
27+
x = train_features[:2] / 255.0
28+
preprocess = JpegCompression()
29+
compressed_x = preprocess(x, quality=80)
30+
self.assertTrue((compressed_x.shape == x.shape))
31+
self.assertTrue((compressed_x <= 1.0).all())
32+
self.assertTrue((compressed_x >= 0.0).all())
33+
34+
def test_channel_index(self):
35+
(train_features, train_labels), (test_data, test_label) = cifar10.load_data()
36+
x = train_features[:2] / 255.0
37+
x = np.swapaxes(x, 1, 3)
38+
preprocess = JpegCompression(channel_index=1)
39+
compressed_x = preprocess(x, quality=80)
40+
self.assertTrue((compressed_x.shape == x.shape))
41+
self.assertTrue((compressed_x <= 1.0).all())
42+
self.assertTrue((compressed_x >= 0.0).all())
43+
44+
45+
if __name__ == '__main__':
46+
unittest.main()
47+
48+
49+

0 commit comments

Comments
 (0)