Skip to content

Commit 3cefd41

Browse files
committed
Update unit tests for #2663
Signed-off-by: Beat Buesser <[email protected]>
1 parent fb5c7a4 commit 3cefd41

File tree

1 file changed

+75
-77
lines changed

1 file changed

+75
-77
lines changed

tests/estimators/object_detection/test_pytorch_yolo.py

Lines changed: 75 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -85,76 +85,76 @@ def test_loss_gradient(art_warning, get_pytorch_yolo):
8585

8686
expected_gradients1 = np.asarray(
8787
[
88-
-0.00033619,
89-
0.00458546,
90-
-0.00084969,
91-
-0.00095304,
92-
-0.00403843,
93-
0.00225406,
94-
-0.00369539,
95-
-0.0099816,
96-
-0.01046214,
97-
-0.00290693,
98-
0.00075546,
99-
-0.0002135,
100-
-0.00659937,
101-
-0.00380152,
102-
-0.00593928,
103-
-0.00179838,
104-
-0.00213012,
105-
0.00526429,
106-
0.00332446,
107-
0.00543861,
108-
0.00284291,
109-
0.00426832,
110-
-0.00586808,
111-
-0.0017767,
112-
-0.00231807,
113-
-0.01142277,
114-
-0.00021731,
115-
0.00076714,
116-
0.00289533,
117-
0.00993828,
118-
0.00472939,
119-
0.00232432,
88+
-7.8263599e-04,
89+
-3.2761338e-04,
90+
-1.7732104e-04,
91+
5.0963718e-07,
92+
-1.2021367e-04,
93+
1.5550642e-05,
94+
-1.2371356e-04,
95+
6.0041926e-05,
96+
7.2321229e-05,
97+
2.8970995e-04,
98+
3.2069255e-04,
99+
-9.7214943e-06,
100+
4.1050217e-04,
101+
3.4139317e-04,
102+
3.2144223e-04,
103+
8.0305658e-04,
104+
1.0029323e-03,
105+
5.4904580e-04,
106+
3.4701737e-04,
107+
9.2334412e-05,
108+
4.5694585e-05,
109+
-4.1882982e-04,
110+
-1.1162873e-03,
111+
-1.2383220e-03,
112+
-1.2119032e-03,
113+
-1.3792568e-03,
114+
-1.0219158e-03,
115+
-1.7796915e-04,
116+
1.6578102e-04,
117+
-4.0390861e-04,
118+
5.0578610e-04,
119+
3.2289932e-05,
120120
]
121121
)
122122
np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)
123123

124124
expected_gradients2 = np.asarray(
125125
[
126-
0.00079487,
127-
0.00426403,
128-
-0.00151893,
129-
0.00798506,
130-
0.00937666,
131-
0.01206836,
132-
-0.00319753,
133-
0.00506421,
134-
0.00291614,
135-
-0.00053876,
136-
0.00281978,
137-
-0.0027451,
138-
0.00319698,
139-
0.00287863,
140-
0.00370754,
141-
0.004611,
142-
-0.00213012,
143-
0.00440465,
144-
-0.00077857,
145-
0.00023536,
146-
0.0035248,
147-
-0.00810297,
148-
0.00698602,
149-
0.00877033,
150-
0.01452724,
151-
0.00161957,
152-
0.02649526,
153-
-0.0071549,
154-
0.02670361,
155-
-0.00759722,
156-
-0.02353876,
157-
0.00860081,
126+
4.02656849e-04,
127+
1.32368109e-03,
128+
1.06753211e-03,
129+
1.02746498e-03,
130+
4.34952060e-04,
131+
1.30278734e-03,
132+
1.65620341e-03,
133+
8.48031021e-04,
134+
2.80185544e-04,
135+
2.04326061e-04,
136+
-9.31014947e-05,
137+
-4.90375911e-04,
138+
-3.42604442e-04,
139+
1.36689676e-04,
140+
3.08552640e-04,
141+
3.88148270e-04,
142+
1.00293232e-03,
143+
-1.08163455e-04,
144+
-1.41605944e-03,
145+
-1.96112506e-03,
146+
-6.27453031e-04,
147+
-9.53144976e-04,
148+
-6.66696171e-04,
149+
-5.78872336e-04,
150+
-1.52492896e-04,
151+
-1.06580940e-03,
152+
1.04899483e-03,
153+
5.83183893e-04,
154+
8.98627564e-04,
155+
3.37607635e-04,
156+
8.34865321e-04,
157+
5.12865488e-04,
158158
]
159159
)
160160
np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)
@@ -306,21 +306,19 @@ def test_patch(art_warning, get_pytorch_yolo):
306306
assert result[0]["scores"].shape == (10647,)
307307
expected_detection_scores = np.asarray(
308308
[
309-
2.0061936e-08,
310-
8.2958641e-06,
311-
1.5368976e-05,
312-
8.5753290e-06,
313-
1.5901747e-05,
314-
3.8245958e-05,
315-
4.6325898e-05,
316-
7.1730128e-06,
317-
4.3095843e-06,
318-
1.0766385e-06,
309+
2.0058684e-08,
310+
8.2879878e-06,
311+
1.5323505e-05,
312+
8.5337388e-06,
313+
1.5668766e-05,
314+
3.7196922e-05,
315+
4.5348370e-05,
316+
6.9575308e-06,
317+
4.2298670e-06,
318+
1.0316832e-06,
319319
]
320320
)
321-
np.testing.assert_raises(
322-
AssertionError, np.testing.assert_array_almost_equal, result[0]["scores"][:10], expected_detection_scores, 6
323-
)
321+
np.testing.assert_allclose(result[0]["scores"][:10], expected_detection_scores, rtol=1e-5, atol=1e-8)
324322

325323
except ARTTestException as e:
326324
art_warning(e)

0 commit comments

Comments
 (0)