30
30
@pytest .mark .only_with_platform ("pytorch" )
31
31
def test_predict (art_warning , get_pytorch_detr ):
32
32
try :
33
- object_detector , x_test , _ = get_pytorch_detr
34
-
35
- result = object_detector .predict (x = x_test )
33
+ from art .utils import non_maximum_suppression
36
34
37
- assert list ( result [ 0 ]. keys ()) == [ "boxes" , "labels" , "scores" ]
35
+ object_detector , x_test , _ = get_pytorch_detr
38
36
39
- assert result [ 0 ][ "boxes" ]. shape == ( 100 , 4 )
40
- expected_detection_boxes = np . asarray ([ - 0.12423098 , 361.80136 , 82.385345 , 795.50305 ] )
41
- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 1 )
37
+ preds = object_detector . predict ( x_test )
38
+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39
+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
42
40
43
- assert result [0 ][ "scores " ].shape == (100 , )
44
- expected_detection_scores = np .asarray (
41
+ assert result ["boxes " ].shape == (3 , 4 )
42
+ expected_detection_boxes = np .asarray (
45
43
[
46
- 0.00105285 ,
47
- 0.00261505 ,
48
- 0.00060220 ,
49
- 0.00121928 ,
50
- 0.00154554 ,
51
- 0.00021678 ,
52
- 0.00077083 ,
53
- 0.00045684 ,
54
- 0.00180561 ,
55
- 0.00067704 ,
44
+ [1.0126123 , 25.658852 , 412.70746 , 379.12537 ],
45
+ [- 0.089400 , 272.08664 , 415.90994 , 416.25930 ],
46
+ [0.1522941 , 75.882440 , 99.139565 , 335.11273 ],
56
47
]
57
48
)
58
- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 1 )
49
+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
50
+
51
+ assert result ["scores" ].shape == (3 ,)
52
+ expected_detection_scores = np .asarray ([0.8424455 , 0.7796526 , 0.35387915 ])
53
+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
59
54
60
- assert result [0 ][ "labels" ].shape == (100 ,)
61
- expected_detection_classes = np .asarray ([1 , 23 , 23 , 1 , 1 , 23 , 23 , 23 , 1 , 1 ])
62
- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 1 )
55
+ assert result ["labels" ].shape == (3 ,)
56
+ expected_detection_classes = np .asarray ([17 , 65 , 17 ])
57
+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
63
58
64
59
except ARTTestException as e :
65
60
art_warning (e )
@@ -68,15 +63,8 @@ def test_predict(art_warning, get_pytorch_detr):
68
63
@pytest .mark .only_with_platform ("pytorch" )
69
64
def test_fit (art_warning , get_pytorch_detr ):
70
65
try :
71
- import torch
72
-
73
66
object_detector , x_test , y_test = get_pytorch_detr
74
67
75
- # Create optimizer
76
- params = [p for p in object_detector .model .parameters () if p .requires_grad ]
77
- optimizer = torch .optim .SGD (params , lr = 0.01 )
78
- object_detector .set_params (optimizer = optimizer )
79
-
80
68
# Compute loss before training
81
69
loss1 = object_detector .compute_loss (x = x_test , y = y_test )
82
70
@@ -99,84 +87,83 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
99
87
100
88
grads = object_detector .loss_gradient (x = x_test , y = y_test )
101
89
102
- assert grads .shape == (2 , 3 , 800 , 800 )
90
+ assert grads .shape == (1 , 3 , 416 , 416 )
103
91
104
92
expected_gradients1 = np .asarray (
105
93
[
106
- - 0.00757495 ,
107
- - 0.00101332 ,
108
- 0.00368362 ,
109
- 0.00283334 ,
110
- - 0.00096027 ,
111
- 0.00873749 ,
112
- 0.00546095 ,
113
- - 0.00823532 ,
114
- - 0.00710872 ,
115
- 0.00389713 ,
116
- - 0.00966289 ,
117
- 0.00448294 ,
118
- 0.00754991 ,
119
- - 0.00934104 ,
120
- - 0.00350194 ,
121
- - 0.00541577 ,
122
- - 0.00395624 ,
123
- 0.00147651 ,
124
- 0.0105616 ,
125
- 0.01231265 ,
126
- - 0.00148831 ,
127
- - 0.0043609 ,
128
- 0.00093031 ,
129
- 0.00884939 ,
130
- - 0.00356749 ,
131
- 0.00093475 ,
132
- - 0.00353712 ,
133
- - 0.0060132 ,
134
- - 0.00067899 ,
135
- - 0.00886974 ,
136
- 0.00108483 ,
137
- - 0.00052412 ,
94
+ 0.02891439 ,
95
+ 0.0055933 ,
96
+ - 0.00687808 ,
97
+ 0.0095074 ,
98
+ 0.00247894 ,
99
+ 0.00122704 ,
100
+ - 0.00482378 ,
101
+ - 0.00924361 ,
102
+ - 0.02870164 ,
103
+ - 0.00683936 ,
104
+ 0.00904205 ,
105
+ - 0.01315971 ,
106
+ - 0.0151937 ,
107
+ - 0.00156442 ,
108
+ 0.00775309 ,
109
+ 0.01946152 ,
110
+ 0.00523211 ,
111
+ - 0.01682214 ,
112
+ 0.00079588 ,
113
+ 0.01627164 ,
114
+ - 0.01347653 ,
115
+ - 0.00512358 ,
116
+ 0.00610363 ,
117
+ 0.02831643 ,
118
+ 0.00742467 ,
119
+ 0.00293561 ,
120
+ 0.01380033 ,
121
+ 0.02112359 ,
122
+ 0.01725711 ,
123
+ - 0.00431877 ,
124
+ - 0.01007722 ,
125
+ - 0.00526983 ,
138
126
]
139
127
)
140
-
141
- np .testing .assert_array_almost_equal (grads [0 , 0 , 10 , :32 ], expected_gradients1 , decimal = 1 )
128
+ np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 192 :224 ], expected_gradients1 , decimal = 1 )
142
129
143
130
expected_gradients2 = np .asarray (
144
131
[
145
- - 0.00757495 ,
146
- - 0.00101332 ,
147
- 0.00368362 ,
148
- 0.00283334 ,
149
- - 0.00096027 ,
150
- 0.00873749 ,
151
- 0.00546095 ,
152
- - 0.00823532 ,
153
- - 0.00710872 ,
154
- 0.00389713 ,
155
- - 0.00966289 ,
156
- 0.00448294 ,
157
- 0.00754991 ,
158
- - 0.00934104 ,
159
- - 0.00350194 ,
160
- - 0.00541577 ,
161
- - 0.00395624 ,
162
- 0.00147651 ,
163
- 0.0105616 ,
164
- 0.01231265 ,
165
- - 0.00148831 ,
166
- - 0.0043609 ,
167
- 0.00093031 ,
168
- 0.00884939 ,
169
- - 0.00356749 ,
170
- 0.00093475 ,
171
- - 0.00353712 ,
172
- - 0.0060132 ,
173
- - 0.00067899 ,
174
- - 0.00886974 ,
175
- 0.00108483 ,
176
- - 0.00052412 ,
132
+ - 0.00549417 ,
133
+ - 0.01592844 ,
134
+ - 0.01073932 ,
135
+ - 0.00443333 ,
136
+ - 0.00780143 ,
137
+ - 0.02033146 ,
138
+ - 0.0191503 ,
139
+ 0.01227987 ,
140
+ 0.019971 ,
141
+ 0.01034214 ,
142
+ - 0.00918145 ,
143
+ - 0.02458049 ,
144
+ - 0.00708776 ,
145
+ - 0.00826812 ,
146
+ - 0.01284431 ,
147
+ - 0.00195021 ,
148
+ 0.00523211 ,
149
+ 0.00661678 ,
150
+ 0.00851441 ,
151
+ 0.01157211 ,
152
+ - 0.00324841 ,
153
+ - 0.00395823 ,
154
+ 0.00756641 ,
155
+ 0.00405913 ,
156
+ - 0.00055517 ,
157
+ 0.00221484 ,
158
+ - 0.02415526 ,
159
+ - 0.02096599 ,
160
+ 0.00980014 ,
161
+ 0.00174731 ,
162
+ - 0.01008899 ,
163
+ 0.00305779 ,
177
164
]
178
165
)
179
- np .testing .assert_array_almost_equal (grads [1 , 0 , 10 , : 32 ], expected_gradients2 , decimal = 1 )
166
+ np .testing .assert_array_almost_equal (grads [0 , 0 , 192 : 224 , 208 ], expected_gradients2 , decimal = 1 )
180
167
181
168
except ARTTestException as e :
182
169
art_warning (e )
@@ -239,18 +226,13 @@ def test_preprocessing_defences(art_warning, get_pytorch_detr):
239
226
"boxes" : result [0 ]["boxes" ],
240
227
"labels" : result [0 ]["labels" ],
241
228
"scores" : np .ones_like (result [0 ]["labels" ]),
242
- },
243
- {
244
- "boxes" : result [1 ]["boxes" ],
245
- "labels" : result [1 ]["labels" ],
246
- "scores" : np .ones_like (result [1 ]["labels" ]),
247
- },
229
+ }
248
230
]
249
231
250
232
# Compute gradients
251
233
grads = object_detector .loss_gradient (x = x_test , y = y )
252
234
253
- assert grads .shape == (2 , 3 , 800 , 800 )
235
+ assert grads .shape == (1 , 3 , 416 , 416 )
254
236
255
237
except ARTTestException as e :
256
238
art_warning (e )
@@ -275,7 +257,7 @@ def test_compute_loss(art_warning, get_pytorch_detr):
275
257
# Compute loss
276
258
loss = object_detector .compute_loss (x = x_test , y = y_test )
277
259
278
- assert pytest .approx (6.7767677 , abs = 0.1 ) == float (loss )
260
+ assert pytest .approx (2.172381 , abs = 0.1 ) == float (loss )
279
261
280
262
except ARTTestException as e :
281
263
art_warning (e )
0 commit comments