Skip to content

Commit 66e0e5a

Browse files
committed
fix test cases for pytorch yolo
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 1a62aa0 commit 66e0e5a

File tree

2 files changed: +94 additions, −141 deletions

tests/estimators/object_detection/conftest.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ def get_pytorch_object_detector_data(get_default_cifar10_subset):
         {
             "boxes": np.asarray([[6, 22, 413, 346], [0, 24, 406, 342]], dtype=np.float32),
             "labels": np.asarray([21, 18]),
+            "scores": np.asarray([1, 1], dtype=np.float32),
             "masks": np.ones((1, 416, 416)) / 2,
         }
     ]

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 93 additions & 141 deletions
@@ -30,36 +30,25 @@
 @pytest.mark.only_with_platform("pytorch")
 def test_predict(art_warning, get_pytorch_yolo):
     try:
-        object_detector, x_test, _ = get_pytorch_yolo
+        from art.utils import non_maximum_suppression
 
-        result = object_detector.predict(x=x_test)
+        object_detector, x_test, _ = get_pytorch_yolo
 
-        assert list(result[0].keys()) == ["boxes", "labels", "scores"]
+        preds = object_detector.predict(x_test)
+        result = non_maximum_suppression(preds[0], iou_threshold=0.4, confidence_threshold=0.3)
+        assert list(result.keys()) == ["boxes", "labels", "scores"]
 
-        assert result[0]["boxes"].shape == (10647, 4)
-        expected_detection_boxes = np.asarray([0.0000000e00, 0.0000000e00, 1.6367816e02, 4.4342079e01])
-        np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=3)
+        assert result["boxes"].shape == (1, 4)
+        expected_detection_boxes = np.asarray([[19.709427, 39.02864, 402.08032, 383.65576]])
+        np.testing.assert_array_almost_equal(result["boxes"], expected_detection_boxes, decimal=3)
 
-        assert result[0]["scores"].shape == (10647,)
-        expected_detection_scores = np.asarray(
-            [
-                4.3653536e-08,
-                3.3987994e-06,
-                2.5681820e-06,
-                3.9782722e-06,
-                2.1766680e-05,
-                2.6138965e-05,
-                6.3377396e-05,
-                7.6248516e-06,
-                4.3447722e-06,
-                3.6515078e-06,
-            ]
-        )
-        np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=6)
+        assert result["scores"].shape == (1,)
+        expected_detection_scores = np.asarray([0.40862876])
+        np.testing.assert_array_almost_equal(result["scores"], expected_detection_scores, decimal=3)
 
-        assert result[0]["labels"].shape == (10647,)
-        expected_detection_classes = np.asarray([0, 0, 14, 14, 14, 14, 14, 14, 14, 0])
-        np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=6)
+        assert result["labels"].shape == (1,)
+        expected_detection_classes = np.asarray([23])
+        np.testing.assert_array_equal(result["labels"], expected_detection_classes)
 
     except ARTTestException as e:
         art_warning(e)
@@ -92,120 +81,83 @@ def test_loss_gradient(art_warning, get_pytorch_yolo):
 
         grads = object_detector.loss_gradient(x=x_test, y=y_test)
 
-        assert grads.shape == (2, 3, 416, 416)
+        assert grads.shape == (1, 3, 416, 416)
 
         expected_gradients1 = np.asarray(
             [
-                0.012576922,
-                -0.005133151,
-                -0.0028872574,
-                -0.0029357928,
-                -0.008929219,
-                0.012767567,
-                -0.00715934,
-                0.00987368,
-                -0.0014089097,
-                -0.004765472,
-                -0.007845592,
-                -0.0065127434,
-                -0.00047654763,
-                -0.018194549,
-                0.00025652442,
-                -0.01420591,
-                0.03873131,
-                0.080963746,
-                -0.009225381,
-                0.026824722,
-                0.005942673,
-                -0.025760904,
-                0.008754236,
-                -0.037260942,
-                0.027838552,
-                0.0485742,
-                0.020763855,
-                -0.013568859,
-                -0.0071423287,
-                0.000802512,
-                0.012983642,
-                0.006466129,
-                0.0025194373,
-                -0.012298459,
-                -0.01168492,
-                -0.0013298508,
-                -0.007176587,
-                0.01996972,
-                -0.004173076,
-                0.029163878,
-                0.022482246,
-                0.008151911,
-                0.025543496,
-                0.0007374112,
-                0.0008220682,
-                -0.005740379,
-                0.009537468,
-                -0.01116704,
-                0.0010225883,
-                0.00026052812,
+                -0.00033619,
+                0.00458546,
+                -0.00084969,
+                -0.00095304,
+                -0.00403843,
+                0.00225406,
+                -0.00369539,
+                -0.0099816,
+                -0.01046214,
+                -0.00290693,
+                0.00075546,
+                -0.0002135,
+                -0.00659937,
+                -0.00380152,
+                -0.00593928,
+                -0.00179838,
+                -0.00213012,
+                0.00526429,
+                0.00332446,
+                0.00543861,
+                0.00284291,
+                0.00426832,
+                -0.00586808,
+                -0.0017767,
+                -0.00231807,
+                -0.01142277,
+                -0.00021731,
+                0.00076714,
+                0.00289533,
+                0.00993828,
+                0.00472939,
+                0.00232432,
             ]
         )
-
-        np.testing.assert_array_almost_equal(grads[0, 0, 208, 175:225], expected_gradients1, decimal=1)
+        np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=1)
 
         expected_gradients2 = np.asarray(
             [
-                0.0049910736,
-                -0.008941505,
-                -0.013645802,
-                0.0060615,
-                0.0021073571,
-                -0.0022195925,
-                -0.006654369,
-                0.010533731,
-                0.0013077373,
-                -0.010422451,
-                -0.00034834983,
-                -0.0040517827,
-                -0.0001514384,
-                -0.031307846,
-                -0.008412821,
-                -0.044170827,
-                0.055609763,
-                0.0220191,
-                -0.019813634,
-                -0.035893522,
-                0.023970673,
-                -0.08727841,
-                0.0411198,
-                0.0072751334,
-                0.01716753,
-                0.0391037,
-                0.020182624,
-                0.021557821,
-                0.011461802,
-                0.0046976856,
-                -0.00304008,
-                -0.010215744,
-                -0.0074639097,
-                -0.020115864,
-                -0.05325762,
-                -0.006238129,
-                -0.006486116,
-                0.09806269,
-                0.03115965,
-                0.066279344,
-                0.05367205,
-                -0.042338565,
-                0.04456845,
-                0.040167376,
-                0.03357561,
-                0.01510548,
-                0.0006220075,
-                -0.027102726,
-                -0.020182101,
-                -0.04347762,
+                0.00079487,
+                0.00426403,
+                -0.00151893,
+                0.00798506,
+                0.00937666,
+                0.01206836,
+                -0.00319753,
+                0.00506421,
+                0.00291614,
+                -0.00053876,
+                0.00281978,
+                -0.0027451,
+                0.00319698,
+                0.00287863,
+                0.00370754,
+                0.004611,
+                -0.00213012,
+                0.00440465,
+                -0.00077857,
+                0.00023536,
+                0.0035248,
+                -0.00810297,
+                0.00698602,
+                0.00877033,
+                0.01452724,
+                0.00161957,
+                0.02649526,
+                -0.0071549,
+                0.02670361,
+                -0.00759722,
+                -0.02353876,
+                0.00860081,
             ]
         )
-        np.testing.assert_array_almost_equal(grads[1, 0, 208, 175:225], expected_gradients2, decimal=1)
+        np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=1)
 
     except ARTTestException as e:
         art_warning(e)
@@ -265,7 +217,7 @@ def test_preprocessing_defences(art_warning, get_pytorch_yolo):
         # Compute gradients
         grads = object_detector.loss_gradient(x=x_test, y=y_test)
 
-        assert grads.shape == (2, 3, 416, 416)
+        assert grads.shape == (1, 3, 416, 416)
 
     except ARTTestException as e:
         art_warning(e)
@@ -290,7 +242,7 @@ def test_compute_loss(art_warning, get_pytorch_yolo):
         # Compute loss
         loss = object_detector.compute_loss(x=x_test, y=y_test)
 
-        assert pytest.approx(11.20741, abs=1.5) == float(loss)
+        assert pytest.approx(0.0920641, abs=0.05) == float(loss)
 
     except ARTTestException as e:
         art_warning(e)
@@ -354,16 +306,16 @@ def test_patch(art_warning, get_pytorch_yolo):
         assert result[0]["scores"].shape == (10647,)
         expected_detection_scores = np.asarray(
             [
-                4.3653536e-08,
-                3.3987994e-06,
-                2.5681820e-06,
-                3.9782722e-06,
-                2.1766680e-05,
-                2.6138965e-05,
-                6.3377396e-05,
-                7.6248516e-06,
-                4.3447722e-06,
-                3.6515078e-06,
+                2.0061936e-08,
+                8.2958641e-06,
+                1.5368976e-05,
+                8.5753290e-06,
+                1.5901747e-05,
+                3.8245958e-05,
+                4.6325898e-05,
+                7.1730128e-06,
+                4.3095843e-06,
+                1.0766385e-06,
             ]
         )
         np.testing.assert_raises(

0 commit comments

Comments (0)