3030@pytest .mark .only_with_platform ("pytorch" )
3131def test_predict (art_warning , get_pytorch_object_detector ):
3232 try :
33- object_detector , x_test , _ = get_pytorch_object_detector
33+ from art . utils import non_maximum_suppression
3434
35- result = object_detector .predict (x_test )
36- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" ]
35+ object_detector , x_test , _ = get_pytorch_object_detector
3736
38- assert result [ 0 ][ "boxes" ]. shape == ( 7 , 4 )
39- expected_detection_boxes = np . asarray ([ 4.4017954 , 6.3090835 , 22.128296 , 27.570665 ] )
40- np . testing . assert_array_almost_equal (result [ 0 ][ "boxes" ][ 2 , :], expected_detection_boxes , decimal = 3 )
37+ preds = object_detector . predict ( x_test )
38+ result = non_maximum_suppression ( preds [ 0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
39+ assert list (result . keys ()) == [ "boxes" , "labels" , "scores" ]
4140
42- assert result [0 ]["scores" ].shape == (7 ,)
43- expected_detection_scores = np .asarray (
44- [0.3314798 , 0.14125851 , 0.13928168 , 0.0996184 , 0.08550017 , 0.06690315 , 0.05359321 ]
41+ assert result ["boxes" ].shape == (2 , 4 )
42+ expected_detection_boxes = np .asarray (
43+ [
44+ [6.136914 , 22.481018 , 413.05814 , 346.08746 ],
45+ [0.000000 , 24.181173 , 406.47644 , 342.62213 ],
46+ ]
4547 )
46- np .testing .assert_array_almost_equal (result [0 ]["scores" ][:10 ], expected_detection_scores , decimal = 6 )
48+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
49+
50+ assert result ["scores" ].shape == (2 ,)
51+ expected_detection_scores = np .asarray ([0.4237412 , 0.35696018 ])
52+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
4753
48- assert result [0 ][ "labels" ].shape == (7 ,)
49- expected_detection_classes = np .asarray ([72 , 79 , 1 , 72 , 78 , 72 , 82 ])
50- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
54+ assert result ["labels" ].shape == (2 ,)
55+ expected_detection_classes = np .asarray ([21 , 18 ])
56+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
5157
5258 except ARTTestException as e :
5359 art_warning (e )
@@ -56,22 +62,30 @@ def test_predict(art_warning, get_pytorch_object_detector):
5662@pytest .mark .only_with_platform ("pytorch" )
5763def test_predict_mask (art_warning , get_pytorch_object_detector_mask ):
5864 try :
65+ from art .utils import non_maximum_suppression
66+
5967 object_detector , x_test , _ = get_pytorch_object_detector_mask
6068
61- result = object_detector .predict (x_test )
62- assert list (result [0 ].keys ()) == ["boxes" , "labels" , "scores" , "masks" ]
69+ preds = object_detector .predict (x_test )
70+ result = non_maximum_suppression (preds [0 ], iou_threshold = 0.4 , confidence_threshold = 0.3 )
71+ assert list (result .keys ()) == ["boxes" , "labels" , "scores" ]
6372
64- assert result [0 ]["boxes" ].shape == (4 , 4 )
65- expected_detection_boxes = np .asarray ([8.62889 , 11.735134 , 16.353355 , 27.565004 ])
66- np .testing .assert_array_almost_equal (result [0 ]["boxes" ][2 , :], expected_detection_boxes , decimal = 3 )
73+ assert result ["boxes" ].shape == (2 , 4 )
74+ expected_detection_boxes = np .asarray (
75+ [
76+ [44.097942 , 22.865257 , 415.32070 , 294.20483 ],
77+ [25.739365 , 33.178577 , 416.00000 , 338.51460 ],
78+ ]
79+ )
80+ np .testing .assert_array_almost_equal (result ["boxes" ], expected_detection_boxes , decimal = 3 )
6781
68- assert result [0 ][ "scores" ].shape == (4 ,)
69- expected_detection_scores = np .asarray ([0.45197296 , 0.12707493 , 0.082677 , 0.05386855 ])
70- np .testing .assert_array_almost_equal (result [0 ][ "scores" ][: 10 ] , expected_detection_scores , decimal = 4 )
82+ assert result ["scores" ].shape == (2 ,)
83+ expected_detection_scores = np .asarray ([0.67316836 , 0.5686724 ])
84+ np .testing .assert_array_almost_equal (result ["scores" ], expected_detection_scores , decimal = 3 )
7185
72- assert result [0 ][ "labels" ].shape == (4 ,)
73- expected_detection_classes = np .asarray ([72 , 72 , 1 , 1 ])
74- np .testing .assert_array_almost_equal (result [0 ][ "labels" ][: 10 ] , expected_detection_classes , decimal = 6 )
86+ assert result ["labels" ].shape == (2 ,)
87+ expected_detection_classes = np .asarray ([18 , 21 ])
88+ np .testing .assert_array_equal (result ["labels" ], expected_detection_classes )
7589
7690 except ARTTestException as e :
7791 art_warning (e )
@@ -124,75 +138,59 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):
124138
125139 # Compute gradients
126140 grads = object_detector .loss_gradient (x_test , y_test )
127- assert grads .shape == (2 , 28 , 28 , 3 )
141+ assert grads .shape == (1 , 3 , 416 , 416 )
128142
129143 expected_gradients1 = np .asarray (
130144 [
131- [4.6265591e-04 , 1.2323459e-03 , 1.3915040e-03 ],
132- [- 3.2658060e-04 , - 3.6941725e-03 , - 4.5638453e-04 ],
133- [7.8702159e-04 , - 3.3072452e-03 , 3.0583731e-04 ],
134- [1.0381485e-03 , - 2.0846087e-03 , 2.3015277e-04 ],
135- [2.1460971e-03 , - 1.3157589e-03 , 3.5176644e-04 ],
136- [3.3839934e-03 , 1.3083456e-03 , 1.6155940e-03 ],
137- [3.8621046e-03 , 1.6645766e-03 , 1.8313043e-03 ],
138- [3.0887076e-03 , 1.4632678e-03 , 1.1174511e-03 ],
139- [3.3404885e-03 , 2.0578136e-03 , 9.6874911e-04 ],
140- [3.2202434e-03 , 7.2660763e-04 , 8.9162006e-04 ],
141- [3.5761783e-03 , 2.3615893e-03 , 8.8510796e-04 ],
142- [3.4721815e-03 , 1.9500104e-03 , 9.2907902e-04 ],
143- [3.4767685e-03 , 2.1154548e-03 , 5.5654044e-04 ],
144- [3.9492580e-03 , 3.5505455e-03 , 6.5863604e-04 ],
145- [3.9963769e-03 , 4.0338552e-03 , 3.9539216e-04 ],
146- [2.2312226e-03 , 5.1399925e-06 , - 1.0743635e-03 ],
147- [2.3955442e-03 , 6.7116896e-04 , - 1.2389944e-03 ],
148- [1.9969011e-03 , - 4.5717746e-04 , - 1.5225793e-03 ],
149- [1.8131963e-03 , - 7.7948131e-04 , - 1.6078206e-03 ],
150- [1.4277012e-03 , - 7.7973347e-04 , - 1.3463887e-03 ],
151- [7.3705515e-04 , - 1.1704378e-03 , - 9.8979671e-04 ],
152- [1.0899740e-04 , - 1.2144407e-03 , - 1.1339665e-03 ],
153- [1.2254890e-04 , - 4.7438752e-04 , - 8.8673591e-04 ],
154- [7.0695346e-04 , 7.2568876e-04 , - 2.5591519e-04 ],
155- [5.0835893e-04 , 2.6866698e-04 , 2.2731400e-04 ],
156- [- 5.9932750e-04 , - 1.1667561e-03 , - 4.8044650e-04 ],
157- [4.0421321e-04 , 3.1692928e-04 , - 8.3296909e-05 ],
158- [4.0506107e-05 , - 3.1728629e-04 , - 4.4132984e-04 ],
145+ - 2.7225273e-05 ,
146+ - 2.7225284e-05 ,
147+ - 3.2535860e-05 ,
148+ - 9.3287526e-06 ,
149+ - 1.1088990e-05 ,
150+ - 3.4527478e-05 ,
151+ 5.7807661e-06 ,
152+ 1.1616970e-05 ,
153+ 2.9732121e-06 ,
154+ 1.1190044e-05 ,
155+ - 6.4673945e-06 ,
156+ - 1.6562306e-05 ,
157+ - 1.5946282e-05 ,
158+ - 1.8079168e-06 ,
159+ - 9.7664342e-06 ,
160+ 6.2075532e-07 ,
161+ - 8.9023115e-06 ,
162+ - 1.5546989e-06 ,
163+ - 7.2730008e-06 ,
164+ - 7.5181362e-07 ,
159165 ]
160166 )
161- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
167+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
162168
163169 expected_gradients2 = np .asarray (
164170 [
165- [4.7986404e-04 , 7.7701372e-04 , 1.1786318e-03 ],
166- [7.3503907e-04 , - 2.3474507e-03 , - 3.9008856e-04 ],
167- [4.1874062e-04 , - 2.5707064e-03 , - 1.1054531e-03 ],
168- [- 1.7942721e-03 , - 3.3968450e-03 , - 1.4989552e-03 ],
169- [- 2.9697213e-03 , - 4.6922294e-03 , - 1.3162185e-03 ],
170- [- 3.1759157e-03 , - 9.8660104e-03 , - 4.7163852e-03 ],
171- [1.8666144e-03 , - 2.8793041e-03 , - 3.1324378e-03 ],
172- [1.0555880e-02 , 7.6373261e-03 , 5.3013843e-03 ],
173- [8.9815725e-04 , - 1.0321697e-02 , 1.4192325e-03 ],
174- [8.5643278e-03 , 3.0152409e-03 , 2.0114987e-03 ],
175- [- 2.7870361e-03 , - 1.1686913e-02 , - 7.0649502e-03 ],
176- [- 7.7482774e-03 , - 1.3334424e-03 , - 9.1927368e-03 ],
177- [- 8.1487820e-03 , - 3.8133820e-03 , - 4.3300558e-03 ],
178- [- 7.7006700e-03 , - 1.2594147e-02 , - 3.9680018e-03 ],
179- [- 9.5743872e-03 , - 2.1007264e-02 , - 9.1963671e-03 ],
180- [- 8.6777220e-03 , - 1.7278835e-02 , - 1.3328674e-02 ],
181- [- 1.7368209e-02 , - 2.3461722e-02 , - 1.1538444e-02 ],
182- [- 4.6307812e-03 , - 5.7058665e-03 , 1.3555109e-03 ],
183- [4.8570461e-03 , - 5.8050654e-03 , 8.1082489e-03 ],
184- [6.4304657e-03 , 2.8407066e-03 , 8.7463465e-03 ],
185- [5.0593228e-03 , 1.4102085e-03 , 5.2116364e-03 ],
186- [2.5003455e-03 , - 6.0178695e-04 , 2.0183939e-03 ],
187- [2.1247163e-03 , 4.7659015e-04 , 7.5940741e-04 ],
188- [1.3499497e-03 , 6.2203623e-04 , 1.2288829e-04 ],
189- [2.8991612e-04 , - 4.0216290e-04 , - 7.2287643e-05 ],
190- [6.6898909e-05 , - 6.3778006e-04 , - 3.6294860e-04 ],
191- [5.3613615e-04 , 9.9137833e-05 , - 1.6657988e-05 ],
192- [- 3.9828232e-05 , - 3.8453130e-04 , - 2.3702848e-04 ],
171+ - 2.7307957e-05 ,
172+ - 1.9417710e-05 ,
173+ - 2.0928457e-05 ,
174+ - 2.1384752e-05 ,
175+ - 2.5035972e-05 ,
176+ - 3.6572790e-05 ,
177+ - 8.2444545e-05 ,
178+ - 7.3255811e-05 ,
179+ - 4.5060227e-05 ,
180+ - 1.9829258e-05 ,
181+ - 2.2043951e-05 ,
182+ - 3.6746951e-05 ,
183+ - 4.2588043e-05 ,
184+ - 3.1833035e-05 ,
185+ - 1.5923406e-05 ,
186+ - 3.5026955e-05 ,
187+ - 4.4511849e-05 ,
188+ - 3.3867167e-05 ,
189+ - 1.8569792e-05 ,
190+ - 3.5141209e-05 ,
193191 ]
194192 )
195- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
193+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
196194
197195 except ARTTestException as e :
198196 art_warning (e )
@@ -205,75 +203,66 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
205203
206204 # Compute gradients
207205 grads = object_detector .loss_gradient (x_test , y_test )
208- assert grads .shape == (2 , 28 , 28 , 3 )
206+ assert grads .shape == (1 , 3 , 416 , 416 )
207+
208+ import pprint
209+
210+ print ()
211+ pprint .pprint (grads [0 , 0 , 0 , :20 ])
212+ print ()
213+ pprint .pprint (grads [0 , 0 , :20 , 0 ])
209214
210215 expected_gradients1 = np .asarray (
211216 [
212- [1.2062087e-03 , 6.7400718e-03 , 9.5682510e-04 ],
213- [- 3.6111937e-03 , - 5.3175041e-03 , - 3.2421902e-03 ],
214- [1.4717830e-03 , 1.0347518e-03 , 1.7675158e-04 ],
215- [2.9278828e-03 , 5.0933827e-03 , 3.5095078e-04 ],
216- [- 3.1896026e-04 , 3.6363016e-04 , - 6.6032895e-04 ],
217- [- 3.8130947e-03 , - 5.5106943e-03 , - 2.3003859e-03 ],
218- [- 4.1348115e-03 , - 6.5722968e-03 , - 1.5899740e-03 ],
219- [- 2.4562061e-03 , - 4.1960045e-03 , - 1.7881666e-03 ],
220- [2.2911791e-04 , - 6.4335053e-04 , - 1.6564501e-03 ],
221- [- 1.2582233e-03 , - 1.5607923e-03 , - 2.2904854e-03 ],
222- [- 1.8436739e-03 , - 2.7200577e-03 , - 2.9125123e-03 ],
223- [- 1.5151387e-03 , - 4.4148900e-03 , - 1.7429549e-03 ],
224- [5.4955669e-03 , 8.1859864e-03 , 1.6560742e-03 ],
225- [3.1721895e-03 , 2.4013112e-03 , - 1.9453048e-04 ],
226- [5.1122587e-03 , 7.4281446e-03 , 2.4133435e-04 ],
227- [2.7988979e-03 , 4.4798232e-03 , - 1.2488490e-03 ],
228- [3.1651880e-03 , 4.5040119e-03 , - 1.6507130e-03 ],
229- [8.5774017e-04 , 9.9022139e-04 , - 3.1324981e-03 ],
230- [3.8568545e-04 , 4.7918499e-04 , - 2.4925626e-03 ],
231- [- 1.8368122e-03 , - 3.9491002e-03 , - 3.9275796e-03 ],
232- [1.6731160e-03 , 1.5304115e-03 , - 1.4627117e-03 ],
233- [1.4445755e-03 , 1.4263670e-03 , - 2.0084691e-03 ],
234- [2.0193408e-04 , 7.2605687e-04 , - 1.8740210e-03 ],
235- [- 1.3681910e-03 , 1.7499415e-05 , - 2.4952039e-03 ],
236- [1.3475126e-04 , 3.0096075e-03 , - 2.4493274e-04 ],
237- [- 6.2653446e-03 , - 9.5283017e-03 , - 2.9458744e-03 ],
238- [- 2.6554640e-03 , - 1.4588287e-03 , - 3.2393888e-03 ],
239- [- 6.4712246e-03 , - 7.2136321e-03 , - 5.4933843e-03 ],
217+ - 4.2168313e-06 ,
218+ - 4.4972450e-05 ,
219+ - 3.6137710e-05 ,
220+ - 1.2499937e-06 ,
221+ 1.2728384e-05 ,
222+ - 1.7352231e-05 ,
223+ 5.6671047e-06 ,
224+ 1.4085637e-05 ,
225+ 5.9047998e-06 ,
226+ 1.0826078e-05 ,
227+ 2.2078505e-06 ,
228+ - 1.3319310e-05 ,
229+ - 2.4521427e-05 ,
230+ - 1.8251436e-05 ,
231+ - 1.9938851e-05 ,
232+ - 3.6778667e-07 ,
233+ 1.1899039e-05 ,
234+ 1.9246204e-06 ,
235+ - 2.7922330e-05 ,
236+ - 3.2529952e-06 ,
240237 ]
241238 )
242- np .testing .assert_array_almost_equal (grads [0 , 0 , : , :], expected_gradients1 , decimal = 2 )
239+ np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , :20 ], expected_gradients1 , decimal = 2 )
243240
244241 expected_gradients2 = np .asarray (
245242 [
246- [- 2.0123991e-04 , - 9.0955076e-04 , - 2.2947363e-04 ],
247- [3.0414842e-04 , 3.4150464e-04 , 2.1101040e-04 ],
248- [6.6070761e-06 , - 1.8034373e-04 , 1.3608378e-05 ],
249- [- 1.3393547e-05 , - 3.2230929e-04 , - 5.5581659e-05 ],
250- [- 1.0353983e-04 , - 2.7751207e-04 , - 2.3205159e-04 ],
251- [- 5.3371373e-04 , - 1.1550108e-03 , - 2.6975147e-04 ],
252- [- 2.6593581e-04 , - 7.3971582e-04 , - 7.4292002e-05 ],
253- [- 9.3046663e-05 , - 4.0410538e-04 , - 1.4271366e-04 ],
254- [- 1.3833238e-04 , - 5.6283473e-04 , - 8.4650565e-05 ],
255- [- 8.0315210e-04 , - 1.4300735e-03 , - 9.3330207e-05 ],
256- [2.7694018e-04 , 6.8307301e-04 , 5.5274006e-04 ],
257- [3.1839000e-04 , 9.7277382e-04 , 4.6252453e-04 ],
258- [2.8279822e-04 , 6.2632316e-04 , 3.3778447e-04 ],
259- [4.0508871e-04 , 1.2438387e-03 , 3.6151547e-04 ],
260- [- 7.5090391e-04 , - 2.6640363e-04 , - 2.6418429e-04 ],
261- [- 2.3455340e-03 , - 4.9932003e-03 , - 8.0432917e-04 ],
262- [4.1711782e-03 , 5.3390805e-03 , 2.4412808e-03 ],
263- [5.1162727e-03 , 5.2886135e-03 , 3.6190096e-03 ],
264- [6.9976337e-03 , 9.7018024e-03 , 3.8526775e-03 ],
265- [4.5005931e-03 , 4.3762275e-03 , 1.7228650e-03 ],
266- [6.3695023e-03 , 8.4943371e-03 , 1.7638379e-03 ],
267- [3.0587378e-03 , 3.9485283e-03 , 4.9000646e-05 ],
268- [- 3.2190280e-04 , - 6.6311209e-04 , - 9.8086358e-04 ],
269- [8.3606638e-04 , 2.0184387e-03 , - 3.5464868e-04 ],
270- [- 1.8979331e-04 , 3.1042210e-04 , - 4.2471994e-04 ],
271- [- 8.8790455e-04 , - 1.4127755e-03 , - 4.4270226e-04 ],
272- [4.1172301e-04 , 2.9453568e-04 , 2.1122720e-04 ],
273- [1.6500468e-04 , 3.7142841e-04 , - 4.5339554e-04 ],
243+ - 4.2168313e-06 ,
244+ - 9.3028730e-06 ,
245+ 1.5900954e-06 ,
246+ - 9.7032771e-06 ,
247+ - 7.9553565e-06 ,
248+ - 1.9485701e-06 ,
249+ - 1.3360468e-05 ,
250+ - 2.7804586e-05 ,
251+ - 4.2667002e-06 ,
252+ - 6.1407286e-06 ,
253+ - 6.6768125e-06 ,
254+ - 1.6444834e-06 ,
255+ 4.7967392e-06 ,
256+ 2.4288647e-06 ,
257+ 1.0280205e-05 ,
258+ 4.2001102e-06 ,
259+ 2.9494076e-05 ,
260+ 1.4654281e-05 ,
261+ 2.5580388e-05 ,
262+ 3.0241908e-05 ,
274263 ]
275264 )
276- np .testing .assert_array_almost_equal (grads [1 , 0 , :, : ], expected_gradients2 , decimal = 2 )
265+ np .testing .assert_array_almost_equal (grads [0 , 0 , :20 , 0 ], expected_gradients2 , decimal = 2 )
277266
278267 except ARTTestException as e :
279268 art_warning (e )
0 commit comments