|
| 1 | +# MIT License |
| 2 | +# |
| 3 | +# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2021 |
| 4 | +# |
| 5 | +# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated |
| 6 | +# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the |
| 7 | +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit |
| 8 | +# persons to whom the Software is furnished to do so, subject to the following conditions: |
| 9 | +# |
| 10 | +# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the |
| 11 | +# Software. |
| 12 | +# |
| 13 | +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE |
| 14 | +# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 15 | +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, |
| 16 | +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 17 | +# SOFTWARE. |
| 18 | +""" |
| 19 | +This module implements a wrapper for MP3 compression defence. |
| 20 | +
|
| 21 | +| Please keep in mind the limitations of defences. For details on how to evaluate classifier security in general, |
| 22 | + see https://arxiv.org/abs/1902.06705. |
| 23 | +""" |
| 24 | +from __future__ import absolute_import, division, print_function, unicode_literals |
| 25 | + |
| 26 | +import logging |
| 27 | +from typing import TYPE_CHECKING, Optional, Tuple |
| 28 | + |
| 29 | +import numpy as np |
| 30 | + |
| 31 | +from art.defences.preprocessor.mp3_compression import Mp3Compression |
| 32 | +from art.defences.preprocessor.preprocessor import PreprocessorPyTorch |
| 33 | + |
| 34 | +logger = logging.getLogger(__name__) |
| 35 | + |
| 36 | +if TYPE_CHECKING: |
| 37 | + # pylint: disable=C0412 |
| 38 | + import torch |
| 39 | + |
| 40 | + |
| 41 | +class Mp3CompressionPyTorch(PreprocessorPyTorch): |
| 42 | + """ |
| 43 | + Implement the MP3 compression defense approach. |
| 44 | + """ |
| 45 | + |
| 46 | + params = ["channels_first", "sample_rate", "verbose"] |
| 47 | + |
| 48 | + def __init__( |
| 49 | + self, |
| 50 | + sample_rate: int, |
| 51 | + channels_first: bool = False, |
| 52 | + apply_fit: bool = False, |
| 53 | + apply_predict: bool = True, |
| 54 | + verbose: bool = False, |
| 55 | + device_type: str = "gpu", |
| 56 | + ): |
| 57 | + """ |
| 58 | + Create an instance of MP3 compression. |
| 59 | +
|
| 60 | + :param sample_rate: Specifies the sampling rate of sample. |
| 61 | + :param channels_first: Set channels first or last. |
| 62 | + :param apply_fit: True if applied during fitting/training. |
| 63 | + :param apply_predict: True if applied during predicting. |
| 64 | + :param verbose: Show progress bars. |
| 65 | + :param device_type: Type of device on which the classifier is run, either `gpu` or `cpu`. |
| 66 | + """ |
| 67 | + import torch # lgtm [py/repeated-import] |
| 68 | + from torch.autograd import Function |
| 69 | + |
| 70 | + super().__init__(is_fitted=True, apply_fit=apply_fit, apply_predict=apply_predict) |
| 71 | + self.channels_first = channels_first |
| 72 | + self.sample_rate = sample_rate |
| 73 | + self.verbose = verbose |
| 74 | + self._check_params() |
| 75 | + |
| 76 | + # Set device |
| 77 | + if device_type == "cpu" or not torch.cuda.is_available(): |
| 78 | + self._device = torch.device("cpu") |
| 79 | + else: |
| 80 | + cuda_idx = torch.cuda.current_device() |
| 81 | + self._device = torch.device("cuda:{}".format(cuda_idx)) |
| 82 | + |
| 83 | + self.compression_numpy = Mp3Compression( |
| 84 | + sample_rate=sample_rate, |
| 85 | + channels_first=channels_first, |
| 86 | + apply_fit=apply_fit, |
| 87 | + apply_predict=apply_predict, |
| 88 | + verbose=verbose, |
| 89 | + ) |
| 90 | + |
| 91 | + class CompressionPyTorchNumpy(Function): |
| 92 | + """ |
| 93 | + Function running Preprocessor. |
| 94 | + """ |
| 95 | + |
| 96 | + @staticmethod |
| 97 | + def forward(ctx, input): # pylint: disable=W0622,W0221 |
| 98 | + numpy_input = input.detach().cpu().numpy() |
| 99 | + result, _ = self.compression_numpy(numpy_input) |
| 100 | + return input.new(result) |
| 101 | + |
| 102 | + @staticmethod |
| 103 | + def backward(ctx, grad_output): # pylint: disable=W0221 |
| 104 | + numpy_go = grad_output.cpu().numpy() |
| 105 | + np.expand_dims(input, axis=[0, 2]) |
| 106 | + result = self.compression_numpy.estimate_gradient(None, numpy_go) |
| 107 | + result = result.squeeze() |
| 108 | + return grad_output.new(result) |
| 109 | + |
| 110 | + self._compression_pytorch_numpy = CompressionPyTorchNumpy |
| 111 | + |
| 112 | + def forward( |
| 113 | + self, x: "torch.Tensor", y: Optional["torch.Tensor"] = None |
| 114 | + ) -> Tuple["torch.Tensor", Optional["torch.Tensor"]]: |
| 115 | + """ |
| 116 | + Apply MP3 compression to sample `x`. |
| 117 | +
|
| 118 | + :param x: Sample to compress with shape `(length, channel)` or an array of sample arrays with shape |
| 119 | + (length,) or (length, channel). |
| 120 | + :param y: Labels of the sample `x`. This function does not affect them in any way. |
| 121 | + :return: Compressed sample. |
| 122 | + """ |
| 123 | + x_compressed = self._compression_pytorch_numpy.apply(x) |
| 124 | + return x_compressed, y |
| 125 | + |
| 126 | + def _check_params(self) -> None: |
| 127 | + if not (isinstance(self.sample_rate, (int, np.int)) and self.sample_rate > 0): |
| 128 | + raise ValueError("Sample rate be must a positive integer.") |
| 129 | + |
| 130 | + if not isinstance(self.verbose, bool): |
| 131 | + raise ValueError("The argument `verbose` has to be of type bool.") |
0 commit comments