Skip to content

Commit 1a62aa0

Browse files
committed
fix test cases for pytorch detection transformer
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent e50e5e5 commit 1a62aa0

File tree

1 file changed

+88
-106
lines changed

1 file changed

+88
-106
lines changed

tests/estimators/object_detection/test_pytorch_detection_transformer.py

Lines changed: 88 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -30,36 +30,31 @@
3030
@pytest.mark.only_with_platform("pytorch")
3131
def test_predict(art_warning, get_pytorch_detr):
3232
try:
33-
object_detector, x_test, _ = get_pytorch_detr
34-
35-
result = object_detector.predict(x=x_test)
33+
from art.utils import non_maximum_suppression
3634

37-
assert list(result[0].keys()) == ["boxes", "labels", "scores"]
35+
object_detector, x_test, _ = get_pytorch_detr
3836

39-
assert result[0]["boxes"].shape == (100, 4)
40-
expected_detection_boxes = np.asarray([-0.12423098, 361.80136, 82.385345, 795.50305])
41-
np.testing.assert_array_almost_equal(result[0]["boxes"][2, :], expected_detection_boxes, decimal=1)
37+
preds = object_detector.predict(x_test)
38+
result = non_maximum_suppression(preds[0], iou_threshold=0.4, confidence_threshold=0.3)
39+
assert list(result.keys()) == ["boxes", "labels", "scores"]
4240

43-
assert result[0]["scores"].shape == (100,)
44-
expected_detection_scores = np.asarray(
41+
assert result["boxes"].shape == (3, 4)
42+
expected_detection_boxes = np.asarray(
4543
[
46-
0.00105285,
47-
0.00261505,
48-
0.00060220,
49-
0.00121928,
50-
0.00154554,
51-
0.00021678,
52-
0.00077083,
53-
0.00045684,
54-
0.00180561,
55-
0.00067704,
44+
[1.0126123, 25.658852, 412.70746, 379.12537],
45+
[-0.089400, 272.08664, 415.90994, 416.25930],
46+
[0.1522941, 75.882440, 99.139565, 335.11273],
5647
]
5748
)
58-
np.testing.assert_array_almost_equal(result[0]["scores"][:10], expected_detection_scores, decimal=1)
49+
np.testing.assert_array_almost_equal(result["boxes"], expected_detection_boxes, decimal=3)
50+
51+
assert result["scores"].shape == (3,)
52+
expected_detection_scores = np.asarray([0.8424455, 0.7796526, 0.35387915])
53+
np.testing.assert_array_almost_equal(result["scores"], expected_detection_scores, decimal=3)
5954

60-
assert result[0]["labels"].shape == (100,)
61-
expected_detection_classes = np.asarray([1, 23, 23, 1, 1, 23, 23, 23, 1, 1])
62-
np.testing.assert_array_almost_equal(result[0]["labels"][:10], expected_detection_classes, decimal=1)
55+
assert result["labels"].shape == (3,)
56+
expected_detection_classes = np.asarray([17, 65, 17])
57+
np.testing.assert_array_equal(result["labels"], expected_detection_classes)
6358

6459
except ARTTestException as e:
6560
art_warning(e)
@@ -68,15 +63,8 @@ def test_predict(art_warning, get_pytorch_detr):
6863
@pytest.mark.only_with_platform("pytorch")
6964
def test_fit(art_warning, get_pytorch_detr):
7065
try:
71-
import torch
72-
7366
object_detector, x_test, y_test = get_pytorch_detr
7467

75-
# Create optimizer
76-
params = [p for p in object_detector.model.parameters() if p.requires_grad]
77-
optimizer = torch.optim.SGD(params, lr=0.01)
78-
object_detector.set_params(optimizer=optimizer)
79-
8068
# Compute loss before training
8169
loss1 = object_detector.compute_loss(x=x_test, y=y_test)
8270

@@ -99,84 +87,83 @@ def test_loss_gradient(art_warning, get_pytorch_detr):
9987

10088
grads = object_detector.loss_gradient(x=x_test, y=y_test)
10189

102-
assert grads.shape == (2, 3, 800, 800)
90+
assert grads.shape == (1, 3, 416, 416)
10391

10492
expected_gradients1 = np.asarray(
10593
[
106-
-0.00757495,
107-
-0.00101332,
108-
0.00368362,
109-
0.00283334,
110-
-0.00096027,
111-
0.00873749,
112-
0.00546095,
113-
-0.00823532,
114-
-0.00710872,
115-
0.00389713,
116-
-0.00966289,
117-
0.00448294,
118-
0.00754991,
119-
-0.00934104,
120-
-0.00350194,
121-
-0.00541577,
122-
-0.00395624,
123-
0.00147651,
124-
0.0105616,
125-
0.01231265,
126-
-0.00148831,
127-
-0.0043609,
128-
0.00093031,
129-
0.00884939,
130-
-0.00356749,
131-
0.00093475,
132-
-0.00353712,
133-
-0.0060132,
134-
-0.00067899,
135-
-0.00886974,
136-
0.00108483,
137-
-0.00052412,
94+
0.02891439,
95+
0.0055933,
96+
-0.00687808,
97+
0.0095074,
98+
0.00247894,
99+
0.00122704,
100+
-0.00482378,
101+
-0.00924361,
102+
-0.02870164,
103+
-0.00683936,
104+
0.00904205,
105+
-0.01315971,
106+
-0.0151937,
107+
-0.00156442,
108+
0.00775309,
109+
0.01946152,
110+
0.00523211,
111+
-0.01682214,
112+
0.00079588,
113+
0.01627164,
114+
-0.01347653,
115+
-0.00512358,
116+
0.00610363,
117+
0.02831643,
118+
0.00742467,
119+
0.00293561,
120+
0.01380033,
121+
0.02112359,
122+
0.01725711,
123+
-0.00431877,
124+
-0.01007722,
125+
-0.00526983,
138126
]
139127
)
140-
141-
np.testing.assert_array_almost_equal(grads[0, 0, 10, :32], expected_gradients1, decimal=1)
128+
np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=1)
142129

143130
expected_gradients2 = np.asarray(
144131
[
145-
-0.00757495,
146-
-0.00101332,
147-
0.00368362,
148-
0.00283334,
149-
-0.00096027,
150-
0.00873749,
151-
0.00546095,
152-
-0.00823532,
153-
-0.00710872,
154-
0.00389713,
155-
-0.00966289,
156-
0.00448294,
157-
0.00754991,
158-
-0.00934104,
159-
-0.00350194,
160-
-0.00541577,
161-
-0.00395624,
162-
0.00147651,
163-
0.0105616,
164-
0.01231265,
165-
-0.00148831,
166-
-0.0043609,
167-
0.00093031,
168-
0.00884939,
169-
-0.00356749,
170-
0.00093475,
171-
-0.00353712,
172-
-0.0060132,
173-
-0.00067899,
174-
-0.00886974,
175-
0.00108483,
176-
-0.00052412,
132+
-0.00549417,
133+
-0.01592844,
134+
-0.01073932,
135+
-0.00443333,
136+
-0.00780143,
137+
-0.02033146,
138+
-0.0191503,
139+
0.01227987,
140+
0.019971,
141+
0.01034214,
142+
-0.00918145,
143+
-0.02458049,
144+
-0.00708776,
145+
-0.00826812,
146+
-0.01284431,
147+
-0.00195021,
148+
0.00523211,
149+
0.00661678,
150+
0.00851441,
151+
0.01157211,
152+
-0.00324841,
153+
-0.00395823,
154+
0.00756641,
155+
0.00405913,
156+
-0.00055517,
157+
0.00221484,
158+
-0.02415526,
159+
-0.02096599,
160+
0.00980014,
161+
0.00174731,
162+
-0.01008899,
163+
0.00305779,
177164
]
178165
)
179-
np.testing.assert_array_almost_equal(grads[1, 0, 10, :32], expected_gradients2, decimal=1)
166+
np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=1)
180167

181168
except ARTTestException as e:
182169
art_warning(e)
@@ -239,18 +226,13 @@ def test_preprocessing_defences(art_warning, get_pytorch_detr):
239226
"boxes": result[0]["boxes"],
240227
"labels": result[0]["labels"],
241228
"scores": np.ones_like(result[0]["labels"]),
242-
},
243-
{
244-
"boxes": result[1]["boxes"],
245-
"labels": result[1]["labels"],
246-
"scores": np.ones_like(result[1]["labels"]),
247-
},
229+
}
248230
]
249231

250232
# Compute gradients
251233
grads = object_detector.loss_gradient(x=x_test, y=y)
252234

253-
assert grads.shape == (2, 3, 800, 800)
235+
assert grads.shape == (1, 3, 416, 416)
254236

255237
except ARTTestException as e:
256238
art_warning(e)
@@ -275,7 +257,7 @@ def test_compute_loss(art_warning, get_pytorch_detr):
275257
# Compute loss
276258
loss = object_detector.compute_loss(x=x_test, y=y_test)
277259

278-
assert pytest.approx(6.7767677, abs=0.1) == float(loss)
260+
assert pytest.approx(2.172381, abs=0.1) == float(loss)
279261

280262
except ARTTestException as e:
281263
art_warning(e)

0 commit comments

Comments
 (0)