|
20 | 20 | import os |
21 | 21 | import pytest |
22 | 22 |
|
23 | | -import torch |
24 | 23 | import numpy as np |
25 | | - |
26 | | -from torch import nn |
| 24 | +import torch |
27 | 25 |
|
28 | 26 | from art.utils import load_dataset |
29 | 27 | from art.estimators.certification.deep_z import PytorchDeepZ |
@@ -76,7 +74,7 @@ def test_mnist_certification(art_warning, fix_get_mnist_data): |
76 | 74 | ptc = get_image_classifier_pt(from_logits=True, use_maxpool=False) |
77 | 75 |
|
78 | 76 | zonotope_model = PytorchDeepZ( |
79 | | - model=ptc.model, clip_values=(0, 1), loss=nn.CrossEntropyLoss(), input_shape=(1, 28, 28), nb_classes=10 |
| 77 | + model=ptc.model, clip_values=(0, 1), loss=torch.nn.CrossEntropyLoss(), input_shape=(1, 28, 28), nb_classes=10 |
80 | 78 | ) |
81 | 79 |
|
82 | 80 | correct_upper_bounds = np.asarray( |
@@ -173,7 +171,7 @@ def test_cifar_certification(art_warning, fix_get_cifar10_data): |
173 | 171 |
|
174 | 172 | ptc = get_cifar10_image_classifier_pt(from_logits=True) |
175 | 173 | zonotope_model = PytorchDeepZ( |
176 | | - model=ptc.model, clip_values=(0, 1), loss=nn.CrossEntropyLoss(), input_shape=(3, 32, 32), nb_classes=10 |
| 174 | + model=ptc.model, clip_values=(0, 1), loss=torch.nn.CrossEntropyLoss(), input_shape=(3, 32, 32), nb_classes=10 |
177 | 175 | ) |
178 | 176 |
|
179 | 177 | correct_upper_bounds = np.asarray( |
|
0 commit comments