# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
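"""
Test for the BEYOND evasion detector (`BeyondDetectorPyTorch`). The detector pairs the target classifier with a
self-supervised SimSiam model and flags inputs whose randomly augmented "neighbors" are inconsistent with them in
predicted label and representation similarity.
"""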
from __future__ import absolute_import, division, print_function, unicode_literals

import pytest
import numpy as np

from art.attacks.evasion.fast_gradient import FastGradientMethod
from art.defences.detector.evasion import BeyondDetectorPyTorch
from art.estimators.classification import PyTorchClassifier
from tests.utils import ARTTestException


def get_ssl_model(weights_path):
    """
    Loads the SSL model (SimSiamWithCls).
    """
    import torch
    import torch.nn as nn

    class SimSiamWithCls(nn.Module):
        """
        SimSiam with Classifier
        """

        def __init__(self, arch="resnet18", feat_dim=2048, num_proj_layers=2):
            # Note: `arch` and `num_proj_layers` are currently unused; the backbone and projector are fixed below.
            from torchvision import models

            super(SimSiamWithCls, self).__init__()
            self.backbone = models.resnet18()
            # Input width of the original fully connected head (512 for ResNet-18)
            out_dim = self.backbone.fc.weight.shape[1]
            # Adapt the ImageNet ResNet-18 stem to 32x32 CIFAR-10 inputs and replace the classification head
            # with Identity so the backbone returns raw 512-d features.
            self.backbone.conv1 = nn.Conv2d(
                in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=2, bias=False
            )
            self.backbone.maxpool = nn.Identity()
            self.backbone.fc = nn.Identity()
            self.classifier = nn.Linear(out_dim, out_features=10)

            pred_hidden_dim = int(feat_dim / 4)

            # Three-layer projection MLP; the final BatchNorm has no affine parameters
            self.projector = nn.Sequential(
                nn.Linear(out_dim, feat_dim, bias=False),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
                nn.Linear(feat_dim, feat_dim, bias=False),
                nn.BatchNorm1d(feat_dim),
                nn.ReLU(),
                nn.Linear(feat_dim, feat_dim),
                nn.BatchNorm1d(feat_dim, affine=False),
            )
            # The bias of the last projector Linear is redundant (its output is immediately normalised), so freeze it
            self.projector[6].bias.requires_grad = False

            # Two-layer prediction MLP used only for the SimSiam training objective
            self.predictor = nn.Sequential(
                nn.Linear(feat_dim, pred_hidden_dim, bias=False),
                nn.BatchNorm1d(pred_hidden_dim),
                nn.ReLU(),
                nn.Linear(pred_hidden_dim, feat_dim),
            )

        def forward(self, img, im_aug1=None, im_aug2=None):
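            """
            With only `img`, return classifier logits ("cls") and the projected representation ("rep");
            with two augmented views, return the SimSiam projections (z1, z2) and their predictions (p1, p2) instead.
            """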
            r_ori = self.backbone(img)
            if im_aug1 is None and im_aug2 is None:
                cls = self.classifier(r_ori)
                rep = self.projector(r_ori)
                return {"cls": cls, "rep": rep}
            else:
                r1 = self.backbone(im_aug1)
                r2 = self.backbone(im_aug2)

                z1 = self.projector(r1)
                z2 = self.projector(r2)

                p1 = self.predictor(z1)
                p2 = self.predictor(z2)

                return {"z1": z1, "z2": z2, "p1": p1, "p2": p2}

    model = SimSiamWithCls()
    # Load on CPU so the test does not require a GPU even if the checkpoint was saved from one
    model.load_state_dict(torch.load(weights_path, map_location="cpu"))
    return model


@pytest.mark.only_with_platform("pytorch")
def test_beyond_detector(art_warning, get_default_cifar10_subset):
    try:
        import torch
        from torchvision import models, transforms

        # Load CIFAR10 data
        (x_train, y_train), (x_test, _) = get_default_cifar10_subset

        x_train = x_train[0:100]
        y_train = y_train[0:100]
        x_test = x_test[0:100]

        # Load models
        # Download pretrained weights from
        # https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
        target_model = models.resnet18()
        # target_model.load_state_dict(
        #     torch.load("./utils/resources/models/resnet_c10.pth", map_location=torch.device('cpu'))
        # )
        ssl_model = get_ssl_model(weights_path="./utils/resources/models/simsiam_c10.pth")
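        # Note: the SimSiam checkpoint above is required for this test; the target classifier keeps its randomly
        # initialised weights unless the commented-out load above is enabled.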

        target_classifier = PyTorchClassifier(
            model=target_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss()
        )
        ssl_classifier = PyTorchClassifier(
            model=ssl_model, nb_classes=10, input_shape=(3, 32, 32), loss=torch.nn.CrossEntropyLoss()
        )

        # Generate adversarial samples
        attack = FastGradientMethod(estimator=target_classifier, eps=0.05)
        x_test_adv = attack.generate(x_test)
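        # The augmentations below produce the random "neighbors" BEYOND compares each input against. They follow a
        # SimSiam-style CIFAR-10 recipe; the Normalize statistics are the standard CIFAR-10 per-channel mean/std.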

        img_augmentations = transforms.Compose(
            [
                transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),  # not strengthened
                transforms.RandomGrayscale(p=0.2),
                transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
            ]
        )

        # Initialize BeyondDetector
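        # `aug_num` sets how many augmented neighbors are drawn per input; see the BeyondDetectorPyTorch
        # documentation for the exact semantics of `alpha`, `var_K` and `percentile`.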
        detector = BeyondDetectorPyTorch(
            target_classifier=target_classifier,
            ssl_classifier=ssl_classifier,
            augmentations=img_augmentations,
            aug_num=50,
            alpha=0.8,
            var_K=20,
            percentile=5,
        )

        # Fit the detector
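        # Fitting on clean training data calibrates the detector's decision threshold (controlled by `percentile`)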
        detector.fit(x_train, y_train, batch_size=128)

        # Apply detector on clean and adversarial test data
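        # detect() returns (report, decisions); the second element marks which samples are flagged as adversarial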
        _, test_detection = detector.detect(x_test)
        _, test_adv_detection = detector.detect(x_test_adv)

        # Assert there is at least one true positive and one true negative
        nb_true_positives = np.sum(test_adv_detection)
        nb_true_negatives = len(test_detection) - np.sum(test_detection)

        assert nb_true_positives > 0
        assert nb_true_negatives > 0

        # Fraction of clean samples not flagged and fraction of adversarial samples flagged
        clean_accuracy = 1 - np.mean(test_detection)
        adv_accuracy = np.mean(test_adv_detection)

        assert clean_accuracy > 0.0
        assert adv_accuracy > 0.0

    except ARTTestException as e:
        art_warning(e)


if __name__ == "__main__":
    # The test relies on pytest fixtures (art_warning, get_default_cifar10_subset), so run it through pytest
    # rather than calling the function directly.
    pytest.main([__file__])