30
30
@pytest .mark .only_with_platform ("pytorch" )
31
31
def test_predict (art_warning , get_pytorch_object_detector ):
32
32
try :
33
- object_detector , x_test , _ = get_pytorch_object_detector
33
+ from art . utils import non_maximum_suppression
34
34
35
- result = object_detector .predict (x_test )
36
- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
35
+ object_detector , x_test , _ = get_pytorch_object_detector
37
36
38
- assert result [ 0 ][ "boxes" ]. shape == ( 7 , 4 )
39
- expected_detection_boxes = np . asarray ([ 4.4017954 , 6.3090835 , 22.128296 , 27.570665 ] )
40
- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 3 )
37
+ preds = object_detector . predict ( x_test )
38
+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39
+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
41
40
42
- assert result [0 ]["scores" ].shape == (7 ,)
43
- expected_detection_scores = np .asarray (
44
- [0.3314798 , 0.14125851 , 0.13928168 , 0.0996184 , 0.08550017 , 0.06690315 , 0.05359321 ]
41
+ assert result ["boxes" ].shape == (2 , 4 )
42
+ expected_detection_boxes = np .asarray (
43
+ [
44
+ [6.136914 , 22.481018 , 413.05814 , 346.08746 ],
45
+ [0.000000 , 24.181173 , 406.47644 , 342.62213 ],
46
+ ]
45
47
)
46
- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
48
+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
49
+
50
+ assert result ["scores" ].shape == (2 ,)
51
+ expected_detection_scores = np .asarray ([0.4237412 , 0.35696018 ])
52
+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
47
53
48
- assert result [0 ][ "labels" ].shape == (7 ,)
49
- expected_detection_classes = np .asarray ([72 , 79 , 1 , 72 , 78 , 72 , 82 ])
50
- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
54
+ assert result ["labels" ].shape == (2 ,)
55
+ expected_detection_classes = np .asarray ([21 , 18 ])
56
+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
51
57
52
58
except ARTTestException as e :
53
59
art_warning (e )
@@ -56,22 +62,30 @@ def test_predict(art_warning, get_pytorch_object_detector):
56
62
@pytest .mark .only_with_platform ("pytorch" )
57
63
def test_predict_mask (art_warning , get_pytorch_object_detector_mask ):
58
64
try :
65
+ from art .utils import non_maximum_suppression
66
+
59
67
object_detector , x_test , _ = get_pytorch_object_detector_mask
60
68
61
- result = object_detector .predict (x_test )
62
- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" , "masks" ]
69
+ preds = object_detector .predict (x_test )
70
+ result = non_maximum_suppression (preds [0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
71
+ assert list (result .keys ()) == ["boxes" , "labels" , "scores" ]
63
72
64
- assert result [0 ]["boxes" ].shape == (4 , 4 )
65
- expected_detection_boxes = np .asarray ([8.62889 , 11.735134 , 16.353355 , 27.565004 ])
66
- np .testing .assert_array_almost_equal (result [0 ]["boxes" ][2 , :], expected_detection_boxes , decimal = 3 )
73
+ assert result ["boxes" ].shape == (2 , 4 )
74
+ expected_detection_boxes = np .asarray (
75
+ [
76
+ [44.097942 , 22.865257 , 415.32070 , 294.20483 ],
77
+ [25.739365 , 33.178577 , 416.00000 , 338.51460 ],
78
+ ]
79
+ )
80
+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
67
81
68
- assert result [0 ][ "scores" ].shape == (4 ,)
69
- expected_detection_scores = np .asarray ([0.45197296 , 0.12707493 , 0.082677 , 0.05386855 ])
70
- np .testing .assert_array_almost_equal (result [0 ][ "scores" ][: 10 ] , expected_detection_scores , decimal = 4 )
82
+ assert result ["scores" ].shape == (2 ,)
83
+ expected_detection_scores = np .asarray ([0.67316836 , 0.5686724 ])
84
+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
71
85
72
- assert result [0 ][ "labels" ].shape == (4 ,)
73
- expected_detection_classes = np .asarray ([72 , 72 , 1 , 1 ])
74
- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
86
+ assert result ["labels" ].shape == (2 ,)
87
+ expected_detection_classes = np .asarray ([18 , 21 ])
88
+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
75
89
76
90
except ARTTestException as e :
77
91
art_warning (e )
@@ -124,75 +138,59 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):
124
138
125
139
# Compute gradients
126
140
grads = object_detector .loss_gradient (x_test , y_test )
127
- assert grads .shape == (2 , 28 , 28 , 3 )
141
+ assert grads .shape == (1 , 3 , 416 , 416 )
128
142
129
143
expected_gradients1 = np .asarray (
130
144
[
131
- [4.6265591e-04 , 1.2323459e-03 , 1.3915040e-03 ],
132
- [- 3.2658060e-04 , - 3.6941725e-03 , - 4.5638453e-04 ],
133
- [7.8702159e-04 , - 3.3072452e-03 , 3.0583731e-04 ],
134
- [1.0381485e-03 , - 2.0846087e-03 , 2.3015277e-04 ],
135
- [2.1460971e-03 , - 1.3157589e-03 , 3.5176644e-04 ],
136
- [3.3839934e-03 , 1.3083456e-03 , 1.6155940e-03 ],
137
- [3.8621046e-03 , 1.6645766e-03 , 1.8313043e-03 ],
138
- [3.0887076e-03 , 1.4632678e-03 , 1.1174511e-03 ],
139
- [3.3404885e-03 , 2.0578136e-03 , 9.6874911e-04 ],
140
- [3.2202434e-03 , 7.2660763e-04 , 8.9162006e-04 ],
141
- [3.5761783e-03 , 2.3615893e-03 , 8.8510796e-04 ],
142
- [3.4721815e-03 , 1.9500104e-03 , 9.2907902e-04 ],
143
- [3.4767685e-03 , 2.1154548e-03 , 5.5654044e-04 ],
144
- [3.9492580e-03 , 3.5505455e-03 , 6.5863604e-04 ],
145
- [3.9963769e-03 , 4.0338552e-03 , 3.9539216e-04 ],
146
- [2.2312226e-03 , 5.1399925e-06 , - 1.0743635e-03 ],
147
- [2.3955442e-03 , 6.7116896e-04 , - 1.2389944e-03 ],
148
- [1.9969011e-03 , - 4.5717746e-04 , - 1.5225793e-03 ],
149
- [1.8131963e-03 , - 7.7948131e-04 , - 1.6078206e-03 ],
150
- [1.4277012e-03 , - 7.7973347e-04 , - 1.3463887e-03 ],
151
- [7.3705515e-04 , - 1.1704378e-03 , - 9.8979671e-04 ],
152
- [1.0899740e-04 , - 1.2144407e-03 , - 1.1339665e-03 ],
153
- [1.2254890e-04 , - 4.7438752e-04 , - 8.8673591e-04 ],
154
- [7.0695346e-04 , 7.2568876e-04 , - 2.5591519e-04 ],
155
- [5.0835893e-04 , 2.6866698e-04 , 2.2731400e-04 ],
156
- [- 5.9932750e-04 , - 1.1667561e-03 , - 4.8044650e-04 ],
157
- [4.0421321e-04 , 3.1692928e-04 , - 8.3296909e-05 ],
158
- [4.0506107e-05 , - 3.1728629e-04 , - 4.4132984e-04 ],
145
+ - 2.7225273e-05 ,
146
+ - 2.7225284e-05 ,
147
+ - 3.2535860e-05 ,
148
+ - 9.3287526e-06 ,
149
+ - 1.1088990e-05 ,
150
+ - 3.4527478e-05 ,
151
+ 5.7807661e-06 ,
152
+ 1.1616970e-05 ,
153
+ 2.9732121e-06 ,
154
+ 1.1190044e-05 ,
155
+ - 6.4673945e-06 ,
156
+ - 1.6562306e-05 ,
157
+ - 1.5946282e-05 ,
158
+ - 1.8079168e-06 ,
159
+ - 9.7664342e-06 ,
160
+ 6.2075532e-07 ,
161
+ - 8.9023115e-06 ,
162
+ - 1.5546989e-06 ,
163
+ - 7.2730008e-06 ,
164
+ - 7.5181362e-07 ,
159
165
]
160
166
)
161
- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
167
+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
162
168
163
169
expected_gradients2 = np .asarray (
164
170
[
165
- [4.7986404e-04 , 7.7701372e-04 , 1.1786318e-03 ],
166
- [7.3503907e-04 , - 2.3474507e-03 , - 3.9008856e-04 ],
167
- [4.1874062e-04 , - 2.5707064e-03 , - 1.1054531e-03 ],
168
- [- 1.7942721e-03 , - 3.3968450e-03 , - 1.4989552e-03 ],
169
- [- 2.9697213e-03 , - 4.6922294e-03 , - 1.3162185e-03 ],
170
- [- 3.1759157e-03 , - 9.8660104e-03 , - 4.7163852e-03 ],
171
- [1.8666144e-03 , - 2.8793041e-03 , - 3.1324378e-03 ],
172
- [1.0555880e-02 , 7.6373261e-03 , 5.3013843e-03 ],
173
- [8.9815725e-04 , - 1.0321697e-02 , 1.4192325e-03 ],
174
- [8.5643278e-03 , 3.0152409e-03 , 2.0114987e-03 ],
175
- [- 2.7870361e-03 , - 1.1686913e-02 , - 7.0649502e-03 ],
176
- [- 7.7482774e-03 , - 1.3334424e-03 , - 9.1927368e-03 ],
177
- [- 8.1487820e-03 , - 3.8133820e-03 , - 4.3300558e-03 ],
178
- [- 7.7006700e-03 , - 1.2594147e-02 , - 3.9680018e-03 ],
179
- [- 9.5743872e-03 , - 2.1007264e-02 , - 9.1963671e-03 ],
180
- [- 8.6777220e-03 , - 1.7278835e-02 , - 1.3328674e-02 ],
181
- [- 1.7368209e-02 , - 2.3461722e-02 , - 1.1538444e-02 ],
182
- [- 4.6307812e-03 , - 5.7058665e-03 , 1.3555109e-03 ],
183
- [4.8570461e-03 , - 5.8050654e-03 , 8.1082489e-03 ],
184
- [6.4304657e-03 , 2.8407066e-03 , 8.7463465e-03 ],
185
- [5.0593228e-03 , 1.4102085e-03 , 5.2116364e-03 ],
186
- [2.5003455e-03 , - 6.0178695e-04 , 2.0183939e-03 ],
187
- [2.1247163e-03 , 4.7659015e-04 , 7.5940741e-04 ],
188
- [1.3499497e-03 , 6.2203623e-04 , 1.2288829e-04 ],
189
- [2.8991612e-04 , - 4.0216290e-04 , - 7.2287643e-05 ],
190
- [6.6898909e-05 , - 6.3778006e-04 , - 3.6294860e-04 ],
191
- [5.3613615e-04 , 9.9137833e-05 , - 1.6657988e-05 ],
192
- [- 3.9828232e-05 , - 3.8453130e-04 , - 2.3702848e-04 ],
171
+ - 2.7307957e-05 ,
172
+ - 1.9417710e-05 ,
173
+ - 2.0928457e-05 ,
174
+ - 2.1384752e-05 ,
175
+ - 2.5035972e-05 ,
176
+ - 3.6572790e-05 ,
177
+ - 8.2444545e-05 ,
178
+ - 7.3255811e-05 ,
179
+ - 4.5060227e-05 ,
180
+ - 1.9829258e-05 ,
181
+ - 2.2043951e-05 ,
182
+ - 3.6746951e-05 ,
183
+ - 4.2588043e-05 ,
184
+ - 3.1833035e-05 ,
185
+ - 1.5923406e-05 ,
186
+ - 3.5026955e-05 ,
187
+ - 4.4511849e-05 ,
188
+ - 3.3867167e-05 ,
189
+ - 1.8569792e-05 ,
190
+ - 3.5141209e-05 ,
193
191
]
194
192
)
195
- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
193
+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
196
194
197
195
except ARTTestException as e :
198
196
art_warning (e )
@@ -205,75 +203,66 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
205
203
206
204
# Compute gradients
207
205
grads = object_detector .loss_gradient (x_test , y_test )
208
- assert grads .shape == (2 , 28 , 28 , 3 )
206
+ assert grads .shape == (1 , 3 , 416 , 416 )
207
+
208
+ import pprint
209
+
210
+ print ()
211
+ pprint .pprint (grads [0 , 0 , 0 , :20 ])
212
+ print ()
213
+ pprint .pprint (grads [0 , 0 , :20 , 0 ])
209
214
210
215
expected_gradients1 = np .asarray (
211
216
[
212
- [1.2062087e-03 , 6.7400718e-03 , 9.5682510e-04 ],
213
- [- 3.6111937e-03 , - 5.3175041e-03 , - 3.2421902e-03 ],
214
- [1.4717830e-03 , 1.0347518e-03 , 1.7675158e-04 ],
215
- [2.9278828e-03 , 5.0933827e-03 , 3.5095078e-04 ],
216
- [- 3.1896026e-04 , 3.6363016e-04 , - 6.6032895e-04 ],
217
- [- 3.8130947e-03 , - 5.5106943e-03 , - 2.3003859e-03 ],
218
- [- 4.1348115e-03 , - 6.5722968e-03 , - 1.5899740e-03 ],
219
- [- 2.4562061e-03 , - 4.1960045e-03 , - 1.7881666e-03 ],
220
- [2.2911791e-04 , - 6.4335053e-04 , - 1.6564501e-03 ],
221
- [- 1.2582233e-03 , - 1.5607923e-03 , - 2.2904854e-03 ],
222
- [- 1.8436739e-03 , - 2.7200577e-03 , - 2.9125123e-03 ],
223
- [- 1.5151387e-03 , - 4.4148900e-03 , - 1.7429549e-03 ],
224
- [5.4955669e-03 , 8.1859864e-03 , 1.6560742e-03 ],
225
- [3.1721895e-03 , 2.4013112e-03 , - 1.9453048e-04 ],
226
- [5.1122587e-03 , 7.4281446e-03 , 2.4133435e-04 ],
227
- [2.7988979e-03 , 4.4798232e-03 , - 1.2488490e-03 ],
228
- [3.1651880e-03 , 4.5040119e-03 , - 1.6507130e-03 ],
229
- [8.5774017e-04 , 9.9022139e-04 , - 3.1324981e-03 ],
230
- [3.8568545e-04 , 4.7918499e-04 , - 2.4925626e-03 ],
231
- [- 1.8368122e-03 , - 3.9491002e-03 , - 3.9275796e-03 ],
232
- [1.6731160e-03 , 1.5304115e-03 , - 1.4627117e-03 ],
233
- [1.4445755e-03 , 1.4263670e-03 , - 2.0084691e-03 ],
234
- [2.0193408e-04 , 7.2605687e-04 , - 1.8740210e-03 ],
235
- [- 1.3681910e-03 , 1.7499415e-05 , - 2.4952039e-03 ],
236
- [1.3475126e-04 , 3.0096075e-03 , - 2.4493274e-04 ],
237
- [- 6.2653446e-03 , - 9.5283017e-03 , - 2.9458744e-03 ],
238
- [- 2.6554640e-03 , - 1.4588287e-03 , - 3.2393888e-03 ],
239
- [- 6.4712246e-03 , - 7.2136321e-03 , - 5.4933843e-03 ],
217
+ - 4.2168313e-06 ,
218
+ - 4.4972450e-05 ,
219
+ - 3.6137710e-05 ,
220
+ - 1.2499937e-06 ,
221
+ 1.2728384e-05 ,
222
+ - 1.7352231e-05 ,
223
+ 5.6671047e-06 ,
224
+ 1.4085637e-05 ,
225
+ 5.9047998e-06 ,
226
+ 1.0826078e-05 ,
227
+ 2.2078505e-06 ,
228
+ - 1.3319310e-05 ,
229
+ - 2.4521427e-05 ,
230
+ - 1.8251436e-05 ,
231
+ - 1.9938851e-05 ,
232
+ - 3.6778667e-07 ,
233
+ 1.1899039e-05 ,
234
+ 1.9246204e-06 ,
235
+ - 2.7922330e-05 ,
236
+ - 3.2529952e-06 ,
240
237
]
241
238
)
242
- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
239
+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
243
240
244
241
expected_gradients2 = np .asarray (
245
242
[
246
- [- 2.0123991e-04 , - 9.0955076e-04 , - 2.2947363e-04 ],
247
- [3.0414842e-04 , 3.4150464e-04 , 2.1101040e-04 ],
248
- [6.6070761e-06 , - 1.8034373e-04 , 1.3608378e-05 ],
249
- [- 1.3393547e-05 , - 3.2230929e-04 , - 5.5581659e-05 ],
250
- [- 1.0353983e-04 , - 2.7751207e-04 , - 2.3205159e-04 ],
251
- [- 5.3371373e-04 , - 1.1550108e-03 , - 2.6975147e-04 ],
252
- [- 2.6593581e-04 , - 7.3971582e-04 , - 7.4292002e-05 ],
253
- [- 9.3046663e-05 , - 4.0410538e-04 , - 1.4271366e-04 ],
254
- [- 1.3833238e-04 , - 5.6283473e-04 , - 8.4650565e-05 ],
255
- [- 8.0315210e-04 , - 1.4300735e-03 , - 9.3330207e-05 ],
256
- [2.7694018e-04 , 6.8307301e-04 , 5.5274006e-04 ],
257
- [3.1839000e-04 , 9.7277382e-04 , 4.6252453e-04 ],
258
- [2.8279822e-04 , 6.2632316e-04 , 3.3778447e-04 ],
259
- [4.0508871e-04 , 1.2438387e-03 , 3.6151547e-04 ],
260
- [- 7.5090391e-04 , - 2.6640363e-04 , - 2.6418429e-04 ],
261
- [- 2.3455340e-03 , - 4.9932003e-03 , - 8.0432917e-04 ],
262
- [4.1711782e-03 , 5.3390805e-03 , 2.4412808e-03 ],
263
- [5.1162727e-03 , 5.2886135e-03 , 3.6190096e-03 ],
264
- [6.9976337e-03 , 9.7018024e-03 , 3.8526775e-03 ],
265
- [4.5005931e-03 , 4.3762275e-03 , 1.7228650e-03 ],
266
- [6.3695023e-03 , 8.4943371e-03 , 1.7638379e-03 ],
267
- [3.0587378e-03 , 3.9485283e-03 , 4.9000646e-05 ],
268
- [- 3.2190280e-04 , - 6.6311209e-04 , - 9.8086358e-04 ],
269
- [8.3606638e-04 , 2.0184387e-03 , - 3.5464868e-04 ],
270
- [- 1.8979331e-04 , 3.1042210e-04 , - 4.2471994e-04 ],
271
- [- 8.8790455e-04 , - 1.4127755e-03 , - 4.4270226e-04 ],
272
- [4.1172301e-04 , 2.9453568e-04 , 2.1122720e-04 ],
273
- [1.6500468e-04 , 3.7142841e-04 , - 4.5339554e-04 ],
243
+ - 4.2168313e-06 ,
244
+ - 9.3028730e-06 ,
245
+ 1.5900954e-06 ,
246
+ - 9.7032771e-06 ,
247
+ - 7.9553565e-06 ,
248
+ - 1.9485701e-06 ,
249
+ - 1.3360468e-05 ,
250
+ - 2.7804586e-05 ,
251
+ - 4.2667002e-06 ,
252
+ - 6.1407286e-06 ,
253
+ - 6.6768125e-06 ,
254
+ - 1.6444834e-06 ,
255
+ 4.7967392e-06 ,
256
+ 2.4288647e-06 ,
257
+ 1.0280205e-05 ,
258
+ 4.2001102e-06 ,
259
+ 2.9494076e-05 ,
260
+ 1.4654281e-05 ,
261
+ 2.5580388e-05 ,
262
+ 3.0241908e-05 ,
274
263
]
275
264
)
276
- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
265
+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
277
266
278
267
except ARTTestException as e :
279
268
art_warning (e )
0 commit comments