Skip to content

Commit 4465875

Browse files
committed
update pytorch object detector test cases
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent b70cf34 commit 4465875

File tree

1 file changed

+131
-142
lines changed

1 file changed

+131
-142
lines changed

tests/estimators/object_detection/test_pytorch_object_detector.py

Lines changed: 131 additions & 142 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,30 @@
3030
@pytest.mark.only_with_platform("pytorch")
3131
def test_predict(art_warning, get_pytorch_object_detector):
3232
try:
33-
object_detector, x_test, _ = get_pytorch_object_detector
33+
from art.utils import non_maximum_suppression
3434

35-
result = object_detector.predict(x_test)
36-
assert list(result[0].keys()) == ["boxes", "labels", "scores"]
35+
object_detector, x_test, _ = get_pytorch_object_detector
3736

38-
assert result[0]["boxes"].shape == (7, 4)
39-
expected_detection_boxes = np.asarray([4.4017954, 6.3090835, 22.128296, 27.570665])
40-
np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)
37+
preds = object_detector.predict(x_test)
38+
result = non_maximum_suppression(preds[0], iou_threshold=0.4, confidence_threshold=0.3)
39+
assert list(result.keys()) == ["boxes", "labels", "scores"]
4140

42-
assert result[0]["scores"].shape == (7,)
43-
expected_detection_scores = np.asarray(
44-
[0.3314798, 0.14125851, 0.13928168, 0.0996184, 0.08550017, 0.06690315, 0.05359321]
41+
assert result["boxes"].shape == (2, 4)
42+
expected_detection_boxes = np.asarray(
43+
[
44+
[6.136914, 22.481018, 413.05814, 346.08746],
45+
[0.000000, 24.181173, 406.47644, 342.62213],
46+
]
4547
)
46-
np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=6)
48+
np.testing.assert_array_almost_equal(result["boxes"], expected_detection_boxes, decimal=3)
49+
50+
assert result["scores"].shape == (2,)
51+
expected_detection_scores = np.asarray([0.4237412, 0.35696018])
52+
np.testing.assert_array_almost_equal(result["scores"], expected_detection_scores, decimal=3)
4753

48-
assert result[0]["labels"].shape == (7,)
49-
expected_detection_classes = np.asarray([72, 79, 1, 72, 78, 72, 82])
50-
np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=6)
54+
assert result["labels"].shape == (2,)
55+
expected_detection_classes = np.asarray([21, 18])
56+
np.testing.assert_array_equal(result["labels"], expected_detection_classes)
5157

5258
except ARTTestException as e:
5359
art_warning(e)
@@ -56,22 +62,30 @@ def test_predict(art_warning, get_pytorch_object_detector):
5662
@pytest.mark.only_with_platform("pytorch")
5763
def test_predict_mask(art_warning, get_pytorch_object_detector_mask):
5864
try:
65+
from art.utils import non_maximum_suppression
66+
5967
object_detector, x_test, _ = get_pytorch_object_detector_mask
6068

61-
result = object_detector.predict(x_test)
62-
assert list(result[0].keys()) == ["boxes", "labels", "scores", "masks"]
69+
preds = object_detector.predict(x_test)
70+
result = non_maximum_suppression(preds[0], iou_threshold=0.4, confidence_threshold=0.3)
71+
assert list(result.keys()) == ["boxes", "labels", "scores"]
6372

64-
assert result[0]["boxes"].shape == (4, 4)
65-
expected_detection_boxes = np.asarray([8.62889, 11.735134, 16.353355, 27.565004])
66-
np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)
73+
assert result["boxes"].shape == (2, 4)
74+
expected_detection_boxes = np.asarray(
75+
[
76+
[44.097942, 22.865257, 415.32070, 294.20483],
77+
[25.739365, 33.178577, 416.00000, 338.51460],
78+
]
79+
)
80+
np.testing.assert_array_almost_equal(result["boxes"], expected_detection_boxes, decimal=3)
6781

68-
assert result[0]["scores"].shape == (4,)
69-
expected_detection_scores = np.asarray([0.45197296, 0.12707493, 0.082677, 0.05386855])
70-
np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=4)
82+
assert result["scores"].shape == (2,)
83+
expected_detection_scores = np.asarray([0.67316836, 0.5686724])
84+
np.testing.assert_array_almost_equal(result["scores"], expected_detection_scores, decimal=3)
7185

72-
assert result[0]["labels"].shape == (4,)
73-
expected_detection_classes = np.asarray([72, 72, 1, 1])
74-
np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=6)
86+
assert result["labels"].shape == (2,)
87+
expected_detection_classes = np.asarray([18, 21])
88+
np.testing.assert_array_equal(result["labels"], expected_detection_classes)
7589

7690
except ARTTestException as e:
7791
art_warning(e)
@@ -124,75 +138,59 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):
124138

125139
# Compute gradients
126140
grads = object_detector.loss_gradient(x_test, y_test)
127-
assert grads.shape == (2, 28, 28, 3)
141+
assert grads.shape == (1, 3, 416, 416)
128142

129143
expected_gradients1 = np.asarray(
130144
[
131-
[4.6265591e-04, 1.2323459e-03, 1.3915040e-03],
132-
[-3.2658060e-04, -3.6941725e-03, -4.5638453e-04],
133-
[7.8702159e-04, -3.3072452e-03, 3.0583731e-04],
134-
[1.0381485e-03, -2.0846087e-03, 2.3015277e-04],
135-
[2.1460971e-03, -1.3157589e-03, 3.5176644e-04],
136-
[3.3839934e-03, 1.3083456e-03, 1.6155940e-03],
137-
[3.8621046e-03, 1.6645766e-03, 1.8313043e-03],
138-
[3.0887076e-03, 1.4632678e-03, 1.1174511e-03],
139-
[3.3404885e-03, 2.0578136e-03, 9.6874911e-04],
140-
[3.2202434e-03, 7.2660763e-04, 8.9162006e-04],
141-
[3.5761783e-03, 2.3615893e-03, 8.8510796e-04],
142-
[3.4721815e-03, 1.9500104e-03, 9.2907902e-04],
143-
[3.4767685e-03, 2.1154548e-03, 5.5654044e-04],
144-
[3.9492580e-03, 3.5505455e-03, 6.5863604e-04],
145-
[3.9963769e-03, 4.0338552e-03, 3.9539216e-04],
146-
[2.2312226e-03, 5.1399925e-06, -1.0743635e-03],
147-
[2.3955442e-03, 6.7116896e-04, -1.2389944e-03],
148-
[1.9969011e-03, -4.5717746e-04, -1.5225793e-03],
149-
[1.8131963e-03, -7.7948131e-04, -1.6078206e-03],
150-
[1.4277012e-03, -7.7973347e-04, -1.3463887e-03],
151-
[7.3705515e-04, -1.1704378e-03, -9.8979671e-04],
152-
[1.0899740e-04, -1.2144407e-03, -1.1339665e-03],
153-
[1.2254890e-04, -4.7438752e-04, -8.8673591e-04],
154-
[7.0695346e-04, 7.2568876e-04, -2.5591519e-04],
155-
[5.0835893e-04, 2.6866698e-04, 2.2731400e-04],
156-
[-5.9932750e-04, -1.1667561e-03, -4.8044650e-04],
157-
[4.0421321e-04, 3.1692928e-04, -8.3296909e-05],
158-
[4.0506107e-05, -3.1728629e-04, -4.4132984e-04],
145+
-2.7225273e-05,
146+
-2.7225284e-05,
147+
-3.2535860e-05,
148+
-9.3287526e-06,
149+
-1.1088990e-05,
150+
-3.4527478e-05,
151+
5.7807661e-06,
152+
1.1616970e-05,
153+
2.9732121e-06,
154+
1.1190044e-05,
155+
-6.4673945e-06,
156+
-1.6562306e-05,
157+
-1.5946282e-05,
158+
-1.8079168e-06,
159+
-9.7664342e-06,
160+
6.2075532e-07,
161+
-8.9023115e-06,
162+
-1.5546989e-06,
163+
-7.2730008e-06,
164+
-7.5181362e-07,
159165
]
160166
)
161-
np.testing.assert_array_almost_equal(grads[0, 0, :, :], expected_gradients1, decimal=2)
167+
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
162168

163169
expected_gradients2 = np.asarray(
164170
[
165-
[4.7986404e-04, 7.7701372e-04, 1.1786318e-03],
166-
[7.3503907e-04, -2.3474507e-03, -3.9008856e-04],
167-
[4.1874062e-04, -2.5707064e-03, -1.1054531e-03],
168-
[-1.7942721e-03, -3.3968450e-03, -1.4989552e-03],
169-
[-2.9697213e-03, -4.6922294e-03, -1.3162185e-03],
170-
[-3.1759157e-03, -9.8660104e-03, -4.7163852e-03],
171-
[1.8666144e-03, -2.8793041e-03, -3.1324378e-03],
172-
[1.0555880e-02, 7.6373261e-03, 5.3013843e-03],
173-
[8.9815725e-04, -1.0321697e-02, 1.4192325e-03],
174-
[8.5643278e-03, 3.0152409e-03, 2.0114987e-03],
175-
[-2.7870361e-03, -1.1686913e-02, -7.0649502e-03],
176-
[-7.7482774e-03, -1.3334424e-03, -9.1927368e-03],
177-
[-8.1487820e-03, -3.8133820e-03, -4.3300558e-03],
178-
[-7.7006700e-03, -1.2594147e-02, -3.9680018e-03],
179-
[-9.5743872e-03, -2.1007264e-02, -9.1963671e-03],
180-
[-8.6777220e-03, -1.7278835e-02, -1.3328674e-02],
181-
[-1.7368209e-02, -2.3461722e-02, -1.1538444e-02],
182-
[-4.6307812e-03, -5.7058665e-03, 1.3555109e-03],
183-
[4.8570461e-03, -5.8050654e-03, 8.1082489e-03],
184-
[6.4304657e-03, 2.8407066e-03, 8.7463465e-03],
185-
[5.0593228e-03, 1.4102085e-03, 5.2116364e-03],
186-
[2.5003455e-03, -6.0178695e-04, 2.0183939e-03],
187-
[2.1247163e-03, 4.7659015e-04, 7.5940741e-04],
188-
[1.3499497e-03, 6.2203623e-04, 1.2288829e-04],
189-
[2.8991612e-04, -4.0216290e-04, -7.2287643e-05],
190-
[6.6898909e-05, -6.3778006e-04, -3.6294860e-04],
191-
[5.3613615e-04, 9.9137833e-05, -1.6657988e-05],
192-
[-3.9828232e-05, -3.8453130e-04, -2.3702848e-04],
171+
-2.7307957e-05,
172+
-1.9417710e-05,
173+
-2.0928457e-05,
174+
-2.1384752e-05,
175+
-2.5035972e-05,
176+
-3.6572790e-05,
177+
-8.2444545e-05,
178+
-7.3255811e-05,
179+
-4.5060227e-05,
180+
-1.9829258e-05,
181+
-2.2043951e-05,
182+
-3.6746951e-05,
183+
-4.2588043e-05,
184+
-3.1833035e-05,
185+
-1.5923406e-05,
186+
-3.5026955e-05,
187+
-4.4511849e-05,
188+
-3.3867167e-05,
189+
-1.8569792e-05,
190+
-3.5141209e-05,
193191
]
194192
)
195-
np.testing.assert_array_almost_equal(grads[1, 0, :, :], expected_gradients2, decimal=2)
193+
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
196194

197195
except ARTTestException as e:
198196
art_warning(e)
@@ -205,75 +203,66 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
205203

206204
# Compute gradients
207205
grads = object_detector.loss_gradient(x_test, y_test)
208-
assert grads.shape == (2, 28, 28, 3)
206+
assert grads.shape == (1, 3, 416, 416)
207+
208+
import pprint
209+
210+
print()
211+
pprint.pprint(grads[0, 0, 0, :20])
212+
print()
213+
pprint.pprint(grads[0, 0, :20, 0])
209214

210215
expected_gradients1 = np.asarray(
211216
[
212-
[1.2062087e-03, 6.7400718e-03, 9.5682510e-04],
213-
[-3.6111937e-03, -5.3175041e-03, -3.2421902e-03],
214-
[1.4717830e-03, 1.0347518e-03, 1.7675158e-04],
215-
[2.9278828e-03, 5.0933827e-03, 3.5095078e-04],
216-
[-3.1896026e-04, 3.6363016e-04, -6.6032895e-04],
217-
[-3.8130947e-03, -5.5106943e-03, -2.3003859e-03],
218-
[-4.1348115e-03, -6.5722968e-03, -1.5899740e-03],
219-
[-2.4562061e-03, -4.1960045e-03, -1.7881666e-03],
220-
[2.2911791e-04, -6.4335053e-04, -1.6564501e-03],
221-
[-1.2582233e-03, -1.5607923e-03, -2.2904854e-03],
222-
[-1.8436739e-03, -2.7200577e-03, -2.9125123e-03],
223-
[-1.5151387e-03, -4.4148900e-03, -1.7429549e-03],
224-
[5.4955669e-03, 8.1859864e-03, 1.6560742e-03],
225-
[3.1721895e-03, 2.4013112e-03, -1.9453048e-04],
226-
[5.1122587e-03, 7.4281446e-03, 2.4133435e-04],
227-
[2.7988979e-03, 4.4798232e-03, -1.2488490e-03],
228-
[3.1651880e-03, 4.5040119e-03, -1.6507130e-03],
229-
[8.5774017e-04, 9.9022139e-04, -3.1324981e-03],
230-
[3.8568545e-04, 4.7918499e-04, -2.4925626e-03],
231-
[-1.8368122e-03, -3.9491002e-03, -3.9275796e-03],
232-
[1.6731160e-03, 1.5304115e-03, -1.4627117e-03],
233-
[1.4445755e-03, 1.4263670e-03, -2.0084691e-03],
234-
[2.0193408e-04, 7.2605687e-04, -1.8740210e-03],
235-
[-1.3681910e-03, 1.7499415e-05, -2.4952039e-03],
236-
[1.3475126e-04, 3.0096075e-03, -2.4493274e-04],
237-
[-6.2653446e-03, -9.5283017e-03, -2.9458744e-03],
238-
[-2.6554640e-03, -1.4588287e-03, -3.2393888e-03],
239-
[-6.4712246e-03, -7.2136321e-03, -5.4933843e-03],
217+
-4.2168313e-06,
218+
-4.4972450e-05,
219+
-3.6137710e-05,
220+
-1.2499937e-06,
221+
1.2728384e-05,
222+
-1.7352231e-05,
223+
5.6671047e-06,
224+
1.4085637e-05,
225+
5.9047998e-06,
226+
1.0826078e-05,
227+
2.2078505e-06,
228+
-1.3319310e-05,
229+
-2.4521427e-05,
230+
-1.8251436e-05,
231+
-1.9938851e-05,
232+
-3.6778667e-07,
233+
1.1899039e-05,
234+
1.9246204e-06,
235+
-2.7922330e-05,
236+
-3.2529952e-06,
240237
]
241238
)
242-
np.testing.assert_array_almost_equal(grads[0, 0, :, :], expected_gradients1, decimal=2)
239+
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
243240

244241
expected_gradients2 = np.asarray(
245242
[
246-
[-2.0123991e-04, -9.0955076e-04, -2.2947363e-04],
247-
[3.0414842e-04, 3.4150464e-04, 2.1101040e-04],
248-
[6.6070761e-06, -1.8034373e-04, 1.3608378e-05],
249-
[-1.3393547e-05, -3.2230929e-04, -5.5581659e-05],
250-
[-1.0353983e-04, -2.7751207e-04, -2.3205159e-04],
251-
[-5.3371373e-04, -1.1550108e-03, -2.6975147e-04],
252-
[-2.6593581e-04, -7.3971582e-04, -7.4292002e-05],
253-
[-9.3046663e-05, -4.0410538e-04, -1.4271366e-04],
254-
[-1.3833238e-04, -5.6283473e-04, -8.4650565e-05],
255-
[-8.0315210e-04, -1.4300735e-03, -9.3330207e-05],
256-
[2.7694018e-04, 6.8307301e-04, 5.5274006e-04],
257-
[3.1839000e-04, 9.7277382e-04, 4.6252453e-04],
258-
[2.8279822e-04, 6.2632316e-04, 3.3778447e-04],
259-
[4.0508871e-04, 1.2438387e-03, 3.6151547e-04],
260-
[-7.5090391e-04, -2.6640363e-04, -2.6418429e-04],
261-
[-2.3455340e-03, -4.9932003e-03, -8.0432917e-04],
262-
[4.1711782e-03, 5.3390805e-03, 2.4412808e-03],
263-
[5.1162727e-03, 5.2886135e-03, 3.6190096e-03],
264-
[6.9976337e-03, 9.7018024e-03, 3.8526775e-03],
265-
[4.5005931e-03, 4.3762275e-03, 1.7228650e-03],
266-
[6.3695023e-03, 8.4943371e-03, 1.7638379e-03],
267-
[3.0587378e-03, 3.9485283e-03, 4.9000646e-05],
268-
[-3.2190280e-04, -6.6311209e-04, -9.8086358e-04],
269-
[8.3606638e-04, 2.0184387e-03, -3.5464868e-04],
270-
[-1.8979331e-04, 3.1042210e-04, -4.2471994e-04],
271-
[-8.8790455e-04, -1.4127755e-03, -4.4270226e-04],
272-
[4.1172301e-04, 2.9453568e-04, 2.1122720e-04],
273-
[1.6500468e-04, 3.7142841e-04, -4.5339554e-04],
243+
-4.2168313e-06,
244+
-9.3028730e-06,
245+
1.5900954e-06,
246+
-9.7032771e-06,
247+
-7.9553565e-06,
248+
-1.9485701e-06,
249+
-1.3360468e-05,
250+
-2.7804586e-05,
251+
-4.2667002e-06,
252+
-6.1407286e-06,
253+
-6.6768125e-06,
254+
-1.6444834e-06,
255+
4.7967392e-06,
256+
2.4288647e-06,
257+
1.0280205e-05,
258+
4.2001102e-06,
259+
2.9494076e-05,
260+
1.4654281e-05,
261+
2.5580388e-05,
262+
3.0241908e-05,
274263
]
275264
)
276-
np.testing.assert_array_almost_equal(grads[1, 0, :, :], expected_gradients2, decimal=2)
265+
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
277266

278267
except ARTTestException as e:
279268
art_warning(e)

0 commit comments

Comments
 (0)