Skip to content

Commit b70cf34

Browse files
committed
update pytorch faster rcnn test cases
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 41a0143 commit b70cf34

File tree

1 file changed

+64
-74
lines changed

1 file changed

+64
-74
lines changed

tests/estimators/object_detection/test_pytorch_faster_rcnn.py

Lines changed: 64 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,30 @@
3030
@pytest.mark.only_with_platform("pytorch")
3131
def test_predict(art_warning, get_pytorch_faster_rcnn):
3232
try:
33-
object_detector, x_test, _ = get_pytorch_faster_rcnn
33+
from art.utils import non_maximum_suppression
3434

35-
result = object_detector.predict(x_test)
36-
assert list(result[0].keys()) == ["boxes", "labels", "scores"]
35+
object_detector, x_test, _ = get_pytorch_faster_rcnn
3736

38-
assert result[0]["boxes"].shape == (7, 4)
39-
expected_detection_boxes = np.asarray([4.4017954, 6.3090835, 22.128296, 27.570665])
40-
np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)
37+
preds = object_detector.predict(x_test)
38+
result = non_maximum_suppression(preds[0], iou_threshold=0.4, confidence_threshold=0.3)
39+
assert list(result.keys()) == ["boxes", "labels", "scores"]
4140

42-
assert result[0]["scores"].shape == (7,)
43-
expected_detection_scores = np.asarray(
44-
[0.3314798, 0.14125851, 0.13928168, 0.0996184, 0.08550017, 0.06690315, 0.05359321]
41+
assert result["boxes"].shape == (2, 4)
42+
expected_detection_boxes = np.asarray(
43+
[
44+
[6.136914, 22.481018, 413.05814, 346.08746],
45+
[0.000000, 24.181173, 406.47644, 342.62213],
46+
]
4547
)
46-
np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=6)
48+
np.testing.assert_array_almost_equal(result["boxes"], expected_detection_boxes, decimal=3)
49+
50+
assert result["scores"].shape == (2,)
51+
expected_detection_scores = np.asarray([0.4237412, 0.35696018])
52+
np.testing.assert_array_almost_equal(result["scores"], expected_detection_scores, decimal=3)
4753

48-
assert result[0]["labels"].shape == (7,)
49-
expected_detection_classes = np.asarray([72, 79, 1, 72, 78, 72, 82])
50-
np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=6)
54+
assert result["labels"].shape == (2,)
55+
expected_detection_classes = np.asarray([21, 18])
56+
np.testing.assert_array_equal(result["labels"], expected_detection_classes)
5157

5258
except ARTTestException as e:
5359
art_warning(e)
@@ -80,75 +86,59 @@ def test_loss_gradient(art_warning, get_pytorch_faster_rcnn):
8086

8187
# Compute gradients
8288
grads = object_detector.loss_gradient(x_test, y_test)
83-
assert grads.shape == (2, 28, 28, 3)
89+
assert grads.shape == (1, 3, 416, 416)
8490

8591
expected_gradients1 = np.asarray(
8692
[
87-
[4.6265591e-04, 1.2323459e-03, 1.3915040e-03],
88-
[-3.2658060e-04, -3.6941725e-03, -4.5638453e-04],
89-
[7.8702159e-04, -3.3072452e-03, 3.0583731e-04],
90-
[1.0381485e-03, -2.0846087e-03, 2.3015277e-04],
91-
[2.1460971e-03, -1.3157589e-03, 3.5176644e-04],
92-
[3.3839934e-03, 1.3083456e-03, 1.6155940e-03],
93-
[3.8621046e-03, 1.6645766e-03, 1.8313043e-03],
94-
[3.0887076e-03, 1.4632678e-03, 1.1174511e-03],
95-
[3.3404885e-03, 2.0578136e-03, 9.6874911e-04],
96-
[3.2202434e-03, 7.2660763e-04, 8.9162006e-04],
97-
[3.5761783e-03, 2.3615893e-03, 8.8510796e-04],
98-
[3.4721815e-03, 1.9500104e-03, 9.2907902e-04],
99-
[3.4767685e-03, 2.1154548e-03, 5.5654044e-04],
100-
[3.9492580e-03, 3.5505455e-03, 6.5863604e-04],
101-
[3.9963769e-03, 4.0338552e-03, 3.9539216e-04],
102-
[2.2312226e-03, 5.1399925e-06, -1.0743635e-03],
103-
[2.3955442e-03, 6.7116896e-04, -1.2389944e-03],
104-
[1.9969011e-03, -4.5717746e-04, -1.5225793e-03],
105-
[1.8131963e-03, -7.7948131e-04, -1.6078206e-03],
106-
[1.4277012e-03, -7.7973347e-04, -1.3463887e-03],
107-
[7.3705515e-04, -1.1704378e-03, -9.8979671e-04],
108-
[1.0899740e-04, -1.2144407e-03, -1.1339665e-03],
109-
[1.2254890e-04, -4.7438752e-04, -8.8673591e-04],
110-
[7.0695346e-04, 7.2568876e-04, -2.5591519e-04],
111-
[5.0835893e-04, 2.6866698e-04, 2.2731400e-04],
112-
[-5.9932750e-04, -1.1667561e-03, -4.8044650e-04],
113-
[4.0421321e-04, 3.1692928e-04, -8.3296909e-05],
114-
[4.0506107e-05, -3.1728629e-04, -4.4132984e-04],
93+
-2.7225273e-05,
94+
-2.7225284e-05,
95+
-3.2535860e-05,
96+
-9.3287526e-06,
97+
-1.1088990e-05,
98+
-3.4527478e-05,
99+
5.7807661e-06,
100+
1.1616970e-05,
101+
2.9732121e-06,
102+
1.1190044e-05,
103+
-6.4673945e-06,
104+
-1.6562306e-05,
105+
-1.5946282e-05,
106+
-1.8079168e-06,
107+
-9.7664342e-06,
108+
6.2075532e-07,
109+
-8.9023115e-06,
110+
-1.5546989e-06,
111+
-7.2730008e-06,
112+
-7.5181362e-07,
115113
]
116114
)
117-
np.testing.assert_array_almost_equal(grads[0, 0, :, :], expected_gradients1, decimal=2)
115+
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
118116

119117
expected_gradients2 = np.asarray(
120118
[
121-
[4.7986404e-04, 7.7701372e-04, 1.1786318e-03],
122-
[7.3503907e-04, -2.3474507e-03, -3.9008856e-04],
123-
[4.1874062e-04, -2.5707064e-03, -1.1054531e-03],
124-
[-1.7942721e-03, -3.3968450e-03, -1.4989552e-03],
125-
[-2.9697213e-03, -4.6922294e-03, -1.3162185e-03],
126-
[-3.1759157e-03, -9.8660104e-03, -4.7163852e-03],
127-
[1.8666144e-03, -2.8793041e-03, -3.1324378e-03],
128-
[1.0555880e-02, 7.6373261e-03, 5.3013843e-03],
129-
[8.9815725e-04, -1.0321697e-02, 1.4192325e-03],
130-
[8.5643278e-03, 3.0152409e-03, 2.0114987e-03],
131-
[-2.7870361e-03, -1.1686913e-02, -7.0649502e-03],
132-
[-7.7482774e-03, -1.3334424e-03, -9.1927368e-03],
133-
[-8.1487820e-03, -3.8133820e-03, -4.3300558e-03],
134-
[-7.7006700e-03, -1.2594147e-02, -3.9680018e-03],
135-
[-9.5743872e-03, -2.1007264e-02, -9.1963671e-03],
136-
[-8.6777220e-03, -1.7278835e-02, -1.3328674e-02],
137-
[-1.7368209e-02, -2.3461722e-02, -1.1538444e-02],
138-
[-4.6307812e-03, -5.7058665e-03, 1.3555109e-03],
139-
[4.8570461e-03, -5.8050654e-03, 8.1082489e-03],
140-
[6.4304657e-03, 2.8407066e-03, 8.7463465e-03],
141-
[5.0593228e-03, 1.4102085e-03, 5.2116364e-03],
142-
[2.5003455e-03, -6.0178695e-04, 2.0183939e-03],
143-
[2.1247163e-03, 4.7659015e-04, 7.5940741e-04],
144-
[1.3499497e-03, 6.2203623e-04, 1.2288829e-04],
145-
[2.8991612e-04, -4.0216290e-04, -7.2287643e-05],
146-
[6.6898909e-05, -6.3778006e-04, -3.6294860e-04],
147-
[5.3613615e-04, 9.9137833e-05, -1.6657988e-05],
148-
[-3.9828232e-05, -3.8453130e-04, -2.3702848e-04],
119+
-2.7307957e-05,
120+
-1.9417710e-05,
121+
-2.0928457e-05,
122+
-2.1384752e-05,
123+
-2.5035972e-05,
124+
-3.6572790e-05,
125+
-8.2444545e-05,
126+
-7.3255811e-05,
127+
-4.5060227e-05,
128+
-1.9829258e-05,
129+
-2.2043951e-05,
130+
-3.6746951e-05,
131+
-4.2588043e-05,
132+
-3.1833035e-05,
133+
-1.5923406e-05,
134+
-3.5026955e-05,
135+
-4.4511849e-05,
136+
-3.3867167e-05,
137+
-1.8569792e-05,
138+
-3.5141209e-05,
149139
]
150140
)
151-
np.testing.assert_array_almost_equal(grads[1, 0, :, :], expected_gradients2, decimal=2)
141+
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
152142

153143
except ARTTestException as e:
154144
art_warning(e)
@@ -198,7 +188,7 @@ def test_preprocessing_defences(art_warning, get_pytorch_faster_rcnn):
198188
# Compute gradients
199189
grads = object_detector.loss_gradient(x=x_test, y=y_test)
200190

201-
assert grads.shape == (2, 28, 28, 3)
191+
assert grads.shape == (1, 3, 416, 416)
202192

203193
except ARTTestException as e:
204194
art_warning(e)
@@ -221,7 +211,7 @@ def test_compute_loss(art_warning, get_pytorch_faster_rcnn):
221211
# Compute loss
222212
loss = object_detector.compute_loss(x=x_test, y=y_test)
223213

224-
assert pytest.approx(0.84883332, abs=0.01) == float(loss)
214+
assert pytest.approx(0.0995874, abs=0.05) == float(loss)
225215

226216
except ARTTestException as e:
227217
art_warning(e)

0 commit comments

Comments
 (0)