Commit 8651b4c

Merge pull request #1210 from Trusted-AI/feature/pytorch-wrapper-mp3-defense
Pytorch wrapper MP3 defense
2 parents 5ed7e20 + 0ca45ab commit 8651b4c

File tree: 3 files changed, +150 / -5 lines

art/defences/preprocessor/__init__.py
art/defences/preprocessor/mp3_compression.py
art/defences/preprocessor/mp3_compression_pytorch.py

art/defences/preprocessor/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -3,10 +3,11 @@
 """
 from art.defences.preprocessor.feature_squeezing import FeatureSqueezing
 from art.defences.preprocessor.gaussian_augmentation import GaussianAugmentation
-from art.defences.preprocessor.inverse_gan import InverseGAN, DefenseGAN
+from art.defences.preprocessor.inverse_gan import DefenseGAN, InverseGAN
 from art.defences.preprocessor.jpeg_compression import JpegCompression
 from art.defences.preprocessor.label_smoothing import LabelSmoothing
 from art.defences.preprocessor.mp3_compression import Mp3Compression
+from art.defences.preprocessor.mp3_compression_pytorch import Mp3CompressionPyTorch
 from art.defences.preprocessor.pixel_defend import PixelDefend
 from art.defences.preprocessor.preprocessor import Preprocessor
 from art.defences.preprocessor.resample import Resample
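With this change the new PyTorch wrapper is exported next to the existing NumPy defence, so both can be imported from the package root. A minimal import sketch (not part of the diff):

# Both the NumPy defence and the new PyTorch wrapper are now importable from the package.
from art.defences.preprocessor import Mp3Compression, Mp3CompressionPyTorch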

art/defences/preprocessor/mp3_compression.py

Lines changed: 17 additions & 4 deletions
@@ -72,7 +72,7 @@ def __call__(self, x: np.ndarray, y: Optional[np.ndarray] = None) -> Tuple[np.nd
         Apply MP3 compression to sample `x`.
 
         :param x: Sample to compress with shape `(batch_size, length, channel)` or an array of sample arrays with shape
-                  (length,) or (length, channel). `x` values are recommended to be of type `np.int16`.
+                  (length,) or (length, channel).
         :param y: Labels of the sample `x`. This function does not affect them in any way.
         :return: Compressed sample.
         """
@@ -84,11 +84,12 @@ def wav_to_mp3(x, sample_rate):
             from pydub import AudioSegment
             from scipy.io.wavfile import write
 
+            x_dtype = x.dtype
             normalized = bool(x.min() >= -1.0 and x.max() <= 1.0)
-            if x.dtype != np.int16 and not normalized:
+            if x_dtype != np.int16 and not normalized:
                 # input is not of type np.int16 and seems to be unnormalized. Therefore casting to np.int16.
                 x = x.astype(np.int16)
-            elif x.dtype != np.int16 and normalized:
+            elif x_dtype != np.int16 and normalized:
                 # x is not of type np.int16 and seems to be normalized. Therefore undoing normalization and
                 # casting to np.int16.
                 x = (x * 2 ** 15).astype(np.int16)
@@ -100,7 +101,19 @@ def wav_to_mp3(x, sample_rate):
             tmp_wav.close()
             tmp_mp3.close()
             x_mp3 = np.array(audio_segment.get_array_of_samples()).reshape((-1, audio_segment.channels))
-            return x_mp3
+
+            # WARNING: Sometimes we *still* need to manually resize x_mp3 to original length.
+            # This should not be the case, e.g. see https://github.com/jiaaro/pydub/issues/474
+            if x.shape[0] != x_mp3.shape[0]:
+                logger.warning(
+                    "Lengths original input and compressed output don't match. Truncating compressed result."
+                )
+                x_mp3 = x_mp3[: x.shape[0]]
+
+            if normalized:
+                # x was normalized. Therefore normalizing x_mp3.
+                x_mp3 = x_mp3 * 2 ** -15
+            return x_mp3.astype(x_dtype)
 
         if x.dtype != np.object and x.ndim != 3:
             raise ValueError("Mp3 compression can only be applied to temporal data across at least one channel.")
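A minimal sketch of the behaviour these changes give `Mp3Compression` (not part of the diff; it assumes pydub and ffmpeg are installed, and the 16 kHz sine input is purely illustrative): the compressed output is truncated back to the input length if pydub returns extra samples, re-normalized if the input was normalized, and cast back to the input dtype instead of always coming back as raw `np.int16` samples.

import numpy as np

from art.defences.preprocessor import Mp3Compression

sample_rate = 16000  # illustrative value
t = np.linspace(0.0, 1.0, sample_rate, endpoint=False, dtype=np.float32)
# Normalized float input in [-1, 1] with shape (batch_size, length, channel).
x = (0.5 * np.sin(2.0 * np.pi * 440.0 * t)).reshape((1, -1, 1))

defence = Mp3Compression(sample_rate=sample_rate)
x_compressed, _ = defence(x)

# With this commit the result keeps the input's normalization and dtype.
print(x_compressed.shape, x_compressed.dtype)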
art/defences/preprocessor/mp3_compression_pytorch.py

Lines changed: 131 additions & 0 deletions
@@ -0,0 +1,131 @@
+# MIT License
+#
+# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
+# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
+# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
+# persons to whom the Software is furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
+# Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
+# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+"""
+This module implements a wrapper for MP3 compression defence.
+
+| Please keep in mind the limitations of defences. For details on how to evaluate classifier security in general,
+    see https://arxiv.org/abs/1902.06705.
+"""
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+from typing import TYPE_CHECKING, Optional, Tuple
+
+import numpy as np
+
+from art.defences.preprocessor.mp3_compression import Mp3Compression
+from art.defences.preprocessor.preprocessor import PreprocessorPyTorch
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    # pylint: disable=C0412
+    import torch
+
+
+class Mp3CompressionPyTorch(PreprocessorPyTorch):
+    """
+    Implement the MP3 compression defense approach.
+    """
+
+    params = ["channels_first", "sample_rate", "verbose"]
+
+    def __init__(
+        self,
+        sample_rate: int,
+        channels_first: bool = False,
+        apply_fit: bool = False,
+        apply_predict: bool = True,
+        verbose: bool = False,
+        device_type: str = "gpu",
+    ):
+        """
+        Create an instance of MP3 compression.
+
+        :param sample_rate: Specifies the sampling rate of sample.
+        :param channels_first: Set channels first or last.
+        :param apply_fit: True if applied during fitting/training.
+        :param apply_predict: True if applied during predicting.
+        :param verbose: Show progress bars.
+        :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`.
+        """
+        import torch  # lgtm [py/repeated-import]
+        from torch.autograd import Function
+
+        super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict)
+        self.channels_first = channels_first
+        self.sample_rate = sample_rate
+        self.verbose = verbose
+        self._check_params()
+
+        # Set device
+        if device_type == "cpu" or not torch.cuda.is_available():
+            self._device = torch.device("cpu")
+        else:
+            cuda_idx = torch.cuda.current_device()
+            self._device = torch.device("cuda:{}".format(cuda_idx))
+
+        self.compression_numpy = Mp3Compression(
+            sample_rate=sample_rate,
+            channels_first=channels_first,
+            apply_fit=apply_fit,
+            apply_predict=apply_predict,
+            verbose=verbose,
+        )
+
+        class CompressionPyTorchNumpy(Function):
+            """
+            Function running Preprocessor.
+            """
+
+            @staticmethod
+            def forward(ctx, input):  # pylint: disable=W0622,W0221
+                numpy_input = input.detach().cpu().numpy()
+                result, _ = self.compression_numpy(numpy_input)
+                return input.new(result)
+
+            @staticmethod
+            def backward(ctx, grad_output):  # pylint: disable=W0221
+                numpy_go = grad_output.cpu().numpy()
+                np.expand_dims(input, axis=[0, 2])
+                result = self.compression_numpy.estimate_gradient(None, numpy_go)
+                result = result.squeeze()
+                return grad_output.new(result)
+
+        self._compression_pytorch_numpy = CompressionPyTorchNumpy
+
+    def forward(
+        self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None
+    ) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]:
+        """
+        Apply MP3 compression to sample `x`.
+
+        :param x: Sample to compress with shape `(length, channel)` or an array of sample arrays with shape
+                  (length,) or (length, channel).
+        :param y: Labels of the sample `x`. This function does not affect them in any way.
+        :return: Compressed sample.
+        """
+        x_compressed = self._compression_pytorch_numpy.apply(x)
+        return x_compressed, y
+
+    def _check_params(self) -> None:
+        if not (isinstance(self.sample_rate, (int, np.int)) and self.sample_rate > 0):
+            raise ValueError("Sample rate be must a positive integer.")
+
+        if not isinstance(self.verbose, bool):
+            raise ValueError("The argument `verbose` has to be of type bool.")
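A minimal usage sketch for the new wrapper (not part of the commit; it assumes torch, pydub and ffmpeg are available, and the batch shape and sample rate are illustrative). The wrapper runs the NumPy `Mp3Compression` defence inside a `torch.autograd.Function`, so it accepts and returns `torch.Tensor` inputs:

import numpy as np
import torch

from art.defences.preprocessor import Mp3CompressionPyTorch

# Illustrative batch: two one-second mono clips at 16 kHz, normalized to [-1, 1],
# with shape (batch_size, length, channel) as expected by the underlying defence.
x_np = np.random.uniform(-1.0, 1.0, size=(2, 16000, 1)).astype(np.float32)
x = torch.from_numpy(x_np)

defence = Mp3CompressionPyTorch(sample_rate=16000, channels_first=False, device_type="cpu")
x_compressed, _ = defence.forward(x)

print(x_compressed.shape, x_compressed.dtype)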
