3030@pytest .mark .only_with_platform ("pytorch" )
3131def test_predict (art_warning , get_pytorch_yolo ):
3232 try :
33- object_detector , x_test , _ = get_pytorch_yolo
33+ from art . utils import non_maximum_suppression
3434
35- result = object_detector . predict ( x = x_test )
35+ object_detector , x_test , _ = get_pytorch_yolo
3636
37- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
37+ preds = object_detector .predict (x_test )
38+ result = non_maximum_suppression (preds [0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39+ assert list (result .keys ()) == ["boxes" , "labels" , "scores" ]
3840
39- assert result [0 ][ "boxes" ].shape == (10647 , 4 )
40- expected_detection_boxes = np .asarray ([0.0000000e00 , 0.0000000e00 , 1.6367816e02 , 4.4342079e01 ])
41- np .testing .assert_array_almost_equal (result [0 ][ "boxes" ][ 2 , : ], expected_detection_boxes , decimal = 3 )
41+ assert result ["boxes" ].shape == (1 , 4 )
42+ expected_detection_boxes = np .asarray ([[ 19.709427 , 39.02864 , 402.08032 , 383.65576 ] ])
43+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
4244
43- assert result [0 ]["scores" ].shape == (10647 ,)
44- expected_detection_scores = np .asarray (
45- [
46- 4.3653536e-08 ,
47- 3.3987994e-06 ,
48- 2.5681820e-06 ,
49- 3.9782722e-06 ,
50- 2.1766680e-05 ,
51- 2.6138965e-05 ,
52- 6.3377396e-05 ,
53- 7.6248516e-06 ,
54- 4.3447722e-06 ,
55- 3.6515078e-06 ,
56- ]
57- )
58- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
45+ assert result ["scores" ].shape == (1 ,)
46+ expected_detection_scores = np .asarray ([0.40862876 ])
47+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
5948
60- assert result [0 ][ "labels" ].shape == (10647 ,)
61- expected_detection_classes = np .asarray ([0 , 0 , 14 , 14 , 14 , 14 , 14 , 14 , 14 , 0 ])
62- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
49+ assert result ["labels" ].shape == (1 ,)
50+ expected_detection_classes = np .asarray ([23 ])
51+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
6352
6453 except ARTTestException as e :
6554 art_warning (e )
@@ -92,120 +81,83 @@ def test_loss_gradient(art_warning, get_pytorch_yolo):
9281
9382 grads = object_detector .loss_gradient (x = x_test , y = y_test )
9483
95- assert grads .shape == (2 , 3 , 416 , 416 )
84+ assert grads .shape == (1 , 3 , 416 , 416 )
9685
9786 expected_gradients1 = np .asarray (
9887 [
99- 0.012576922 ,
100- - 0.005133151 ,
101- - 0.0028872574 ,
102- - 0.0029357928 ,
103- - 0.008929219 ,
104- 0.012767567 ,
105- - 0.00715934 ,
106- 0.00987368 ,
107- - 0.0014089097 ,
108- - 0.004765472 ,
109- - 0.007845592 ,
110- - 0.0065127434 ,
111- - 0.00047654763 ,
112- - 0.018194549 ,
113- 0.00025652442 ,
114- - 0.01420591 ,
115- 0.03873131 ,
116- 0.080963746 ,
117- - 0.009225381 ,
118- 0.026824722 ,
119- 0.005942673 ,
120- - 0.025760904 ,
121- 0.008754236 ,
122- - 0.037260942 ,
123- 0.027838552 ,
124- 0.0485742 ,
125- 0.020763855 ,
126- - 0.013568859 ,
127- - 0.0071423287 ,
128- 0.000802512 ,
129- 0.012983642 ,
130- 0.006466129 ,
131- 0.0025194373 ,
132- - 0.012298459 ,
133- - 0.01168492 ,
134- - 0.0013298508 ,
135- - 0.007176587 ,
136- 0.01996972 ,
137- - 0.004173076 ,
138- 0.029163878 ,
139- 0.022482246 ,
140- 0.008151911 ,
141- 0.025543496 ,
142- 0.0007374112 ,
143- 0.0008220682 ,
144- - 0.005740379 ,
145- 0.009537468 ,
146- - 0.01116704 ,
147- 0.0010225883 ,
148- 0.00026052812 ,
88+ - 0.00033619 ,
89+ 0.00458546 ,
90+ - 0.00084969 ,
91+ - 0.00095304 ,
92+ - 0.00403843 ,
93+ 0.00225406 ,
94+ - 0.00369539 ,
95+ - 0.0099816 ,
96+ - 0.01046214 ,
97+ - 0.00290693 ,
98+ 0.00075546 ,
99+ - 0.0002135 ,
100+ - 0.00659937 ,
101+ - 0.00380152 ,
102+ - 0.00593928 ,
103+ - 0.00179838 ,
104+ - 0.00213012 ,
105+ 0.00526429 ,
106+ 0.00332446 ,
107+ 0.00543861 ,
108+ 0.00284291 ,
109+ 0.00426832 ,
110+ - 0.00586808 ,
111+ - 0.0017767 ,
112+ - 0.00231807 ,
113+ - 0.01142277 ,
114+ - 0.00021731 ,
115+ 0.00076714 ,
116+ 0.00289533 ,
117+ 0.00993828 ,
118+ 0.00472939 ,
119+ 0.00232432 ,
149120 ]
150121 )
151-
152- np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 175 :225 ], expected_gradients1 , decimal = 1 )
122+ np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 192 :224 ], expected_gradients1 , decimal = 1 )
153123
154124 expected_gradients2 = np .asarray (
155125 [
156- 0.0049910736 ,
157- - 0.008941505 ,
158- - 0.013645802 ,
159- 0.0060615 ,
160- 0.0021073571 ,
161- - 0.0022195925 ,
162- - 0.006654369 ,
163- 0.010533731 ,
164- 0.0013077373 ,
165- - 0.010422451 ,
166- - 0.00034834983 ,
167- - 0.0040517827 ,
168- - 0.0001514384 ,
169- - 0.031307846 ,
170- - 0.008412821 ,
171- - 0.044170827 ,
172- 0.055609763 ,
173- 0.0220191 ,
174- - 0.019813634 ,
175- - 0.035893522 ,
176- 0.023970673 ,
177- - 0.08727841 ,
178- 0.0411198 ,
179- 0.0072751334 ,
180- 0.01716753 ,
181- 0.0391037 ,
182- 0.020182624 ,
183- 0.021557821 ,
184- 0.011461802 ,
185- 0.0046976856 ,
186- - 0.00304008 ,
187- - 0.010215744 ,
188- - 0.0074639097 ,
189- - 0.020115864 ,
190- - 0.05325762 ,
191- - 0.006238129 ,
192- - 0.006486116 ,
193- 0.09806269 ,
194- 0.03115965 ,
195- 0.066279344 ,
196- 0.05367205 ,
197- - 0.042338565 ,
198- 0.04456845 ,
199- 0.040167376 ,
200- 0.03357561 ,
201- 0.01510548 ,
202- 0.0006220075 ,
203- - 0.027102726 ,
204- - 0.020182101 ,
205- - 0.04347762 ,
126+ 0.00079487 ,
127+ 0.00426403 ,
128+ - 0.00151893 ,
129+ 0.00798506 ,
130+ 0.00937666 ,
131+ 0.01206836 ,
132+ - 0.00319753 ,
133+ 0.00506421 ,
134+ 0.00291614 ,
135+ - 0.00053876 ,
136+ 0.00281978 ,
137+ - 0.0027451 ,
138+ 0.00319698 ,
139+ 0.00287863 ,
140+ 0.00370754 ,
141+ 0.004611 ,
142+ - 0.00213012 ,
143+ 0.00440465 ,
144+ - 0.00077857 ,
145+ 0.00023536 ,
146+ 0.0035248 ,
147+ - 0.00810297 ,
148+ 0.00698602 ,
149+ 0.00877033 ,
150+ 0.01452724 ,
151+ 0.00161957 ,
152+ 0.02649526 ,
153+ - 0.0071549 ,
154+ 0.02670361 ,
155+ - 0.00759722 ,
156+ - 0.02353876 ,
157+ 0.00860081 ,
206158 ]
207159 )
208- np .testing .assert_array_almost_equal (grads [1 , 0 , 208 , 175 : 225 ], expected_gradients2 , decimal = 1 )
160+ np .testing .assert_array_almost_equal (grads [0 , 0 , 192 : 224 , 208 ], expected_gradients2 , decimal = 1 )
209161
210162 except ARTTestException as e :
211163 art_warning (e )
@@ -265,7 +217,7 @@ def test_preprocessing_defences(art_warning, get_pytorch_yolo):
265217 # Compute gradients
266218 grads = object_detector .loss_gradient (x = x_test , y = y_test )
267219
268- assert grads .shape == (2 , 3 , 416 , 416 )
220+ assert grads .shape == (1 , 3 , 416 , 416 )
269221
270222 except ARTTestException as e :
271223 art_warning (e )
@@ -290,7 +242,7 @@ def test_compute_loss(art_warning, get_pytorch_yolo):
290242 # Compute loss
291243 loss = object_detector .compute_loss (x = x_test , y = y_test )
292244
293- assert pytest .approx (11.20741 , abs = 1.5 ) == float (loss )
245+ assert pytest .approx (0.0920641 , abs = 0.05 ) == float (loss )
294246
295247 except ARTTestException as e :
296248 art_warning (e )
@@ -354,16 +306,16 @@ def test_patch(art_warning, get_pytorch_yolo):
354306 assert result [0 ]["scores" ].shape == (10647 ,)
355307 expected_detection_scores = np .asarray (
356308 [
357- 4.3653536e -08 ,
358- 3.3987994e -06 ,
359- 2.5681820e-06 ,
360- 3.9782722e -06 ,
361- 2.1766680e -05 ,
362- 2.6138965e -05 ,
363- 6.3377396e -05 ,
364- 7.6248516e -06 ,
365- 4.3447722e -06 ,
366- 3.6515078e -06 ,
309+ 2.0061936e -08 ,
310+ 8.2958641e -06 ,
311+ 1.5368976e-05 ,
312+ 8.5753290e -06 ,
313+ 1.5901747e -05 ,
314+ 3.8245958e -05 ,
315+ 4.6325898e -05 ,
316+ 7.1730128e -06 ,
317+ 4.3095843e -06 ,
318+ 1.0766385e -06 ,
367319 ]
368320 )
369321 np .testing .assert_raises (
0 commit comments