Skip to content

Commit 36aa351

Browse files
committed
Fix dependencies
Signed-off-by: Beat Buesser <[email protected]>
1 parent f562e36 commit 36aa351

File tree

1 file changed

+42
-28
lines changed

1 file changed

+42
-28
lines changed

tests/attacks/evasion/test_overload_attack.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -35,24 +35,31 @@ def test_generate(art_warning):
3535
from ultralytics.nn.modules import Conv
3636
from ultralytics import YOLO
3737

38-
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
39-
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
40-
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
41-
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
42-
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
43-
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
44-
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
45-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
46-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
47-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
48-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
49-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
50-
torch.serialization.add_safe_globals([Conv])
51-
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
38+
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
39+
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
40+
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
41+
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
42+
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
43+
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
44+
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
45+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
46+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
47+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
48+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
49+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
50+
# torch.serialization.add_safe_globals([Conv])
51+
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
5252
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
5353
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
5454

5555
model = YOLO("yolov5su.pt")
56+
57+
# Collect all unique classes used in the model
58+
all_classes = set(type(module) for module in model.modules())
59+
60+
# Add them to safe_globals
61+
torch.serialization.add_safe_globals(list(all_classes))
62+
5663
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
5764

5865
# Download a sample image
@@ -89,24 +96,31 @@ def test_check_params(art_warning):
8996
from ultralytics import YOLO
9097
from ultralytics.nn.modules import Conv
9198

92-
torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
93-
torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
94-
torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
95-
torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
96-
torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
97-
torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
98-
torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
99-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
100-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
101-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
102-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
103-
torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
104-
torch.serialization.add_safe_globals([Conv])
105-
torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
99+
# torch.serialization.add_safe_globals([torch.nn.modules.container.Sequential])
100+
# torch.serialization.add_safe_globals([torch.nn.modules.container.ModuleList])
101+
# torch.serialization.add_safe_globals([torch.nn.modules.pooling.MaxPool2d])
102+
# torch.serialization.add_safe_globals([torch.nn.modules.batchnorm.BatchNorm2d])
103+
# torch.serialization.add_safe_globals([torch.nn.modules.activation.SiLU])
104+
# torch.serialization.add_safe_globals([torch.nn.modules.conv.Conv2d])
105+
# torch.serialization.add_safe_globals([torch.nn.modules.upsampling.Upsample])
106+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.DetectionModel])
107+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.SPPF])
108+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.C3])
109+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Bottleneck])
110+
# torch.serialization.add_safe_globals([ultralytics.nn.tasks.Detect])
111+
# torch.serialization.add_safe_globals([Conv])
112+
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Conv])
106113
# torch.serialization.add_safe_globals([ultralytics.nn.modules.Concat])
107114
# torch.serialization.add_safe_globals([ultralytics.nn.modules.DFL])
108115

109116
model = YOLO("yolov5su.pt")
117+
118+
# Collect all unique classes used in the model
119+
all_classes = set(type(module) for module in model.modules())
120+
121+
# Add them to safe_globals
122+
torch.serialization.add_safe_globals(list(all_classes))
123+
110124
py_model = PyTorchYolo(model=model, input_shape=(3, 640, 640), channels_first=True)
111125

112126
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)