Skip to content

Commit e50e5e5

Browse files
committed
fix test cases for pytorch object detectors
Signed-off-by: Farhan Ahmed <[email protected]>
1 parent 4465875 commit e50e5e5

File tree

2 files changed

+198
-133
lines changed

2 files changed

+198
-133
lines changed

tests/estimators/object_detection/test_pytorch_faster_rcnn.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -90,55 +90,79 @@ def test_loss_gradient(art_warning, get_pytorch_faster_rcnn):
9090

9191
expected_gradients1 = np.asarray(
9292
[
93-
-2.7225273e-05,
94-
-2.7225284e-05,
95-
-3.2535860e-05,
96-
-9.3287526e-06,
97-
-1.1088990e-05,
98-
-3.4527478e-05,
99-
5.7807661e-06,
100-
1.1616970e-05,
101-
2.9732121e-06,
102-
1.1190044e-05,
103-
-6.4673945e-06,
104-
-1.6562306e-05,
105-
-1.5946282e-05,
106-
-1.8079168e-06,
107-
-9.7664342e-06,
108-
6.2075532e-07,
109-
-8.9023115e-06,
110-
-1.5546989e-06,
111-
-7.2730008e-06,
112-
-7.5181362e-07,
93+
-7.2145270e-04,
94+
-3.9774503e-04,
95+
-5.5271841e-04,
96+
2.5251633e-04,
97+
-4.1167819e-05,
98+
1.2919735e-04,
99+
1.1148686e-04,
100+
4.9278833e-04,
101+
9.6094189e-04,
102+
1.1812975e-03,
103+
2.7167992e-04,
104+
9.7095188e-05,
105+
1.4456113e-04,
106+
-8.8345587e-06,
107+
4.7151549e-05,
108+
-1.3497710e-04,
109+
-2.3394797e-04,
110+
1.3777621e-04,
111+
3.2994794e-04,
112+
3.7001527e-04,
113+
-2.5945838e-04,
114+
-8.3444244e-04,
115+
-6.9832127e-04,
116+
-3.0403296e-04,
117+
-5.4019055e-04,
118+
-3.4545487e-04,
119+
-5.6993403e-04,
120+
-2.9818740e-04,
121+
-9.8479632e-04,
122+
-4.1015903e-04,
123+
-6.2145875e-04,
124+
-1.1365353e-03,
113125
]
114126
)
115-
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
127+
np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)
116128

117129
expected_gradients2 = np.asarray(
118130
[
119-
-2.7307957e-05,
120-
-1.9417710e-05,
121-
-2.0928457e-05,
122-
-2.1384752e-05,
123-
-2.5035972e-05,
124-
-3.6572790e-05,
125-
-8.2444545e-05,
126-
-7.3255811e-05,
127-
-4.5060227e-05,
128-
-1.9829258e-05,
129-
-2.2043951e-05,
130-
-3.6746951e-05,
131-
-4.2588043e-05,
132-
-3.1833035e-05,
133-
-1.5923406e-05,
134-
-3.5026955e-05,
135-
-4.4511849e-05,
136-
-3.3867167e-05,
137-
-1.8569792e-05,
138-
-3.5141209e-05,
131+
0.00015462,
132+
0.00028882,
133+
-0.00018248,
134+
-0.00114344,
135+
-0.00160104,
136+
-0.00190151,
137+
-0.00183488,
138+
-0.00191787,
139+
-0.00018382,
140+
0.00095297,
141+
0.00042502,
142+
0.00024631,
143+
0.0002915,
144+
0.00053676,
145+
0.00028635,
146+
0.00035274,
147+
-0.00023395,
148+
-0.00044685,
149+
-0.00016795,
150+
0.00059767,
151+
0.00060389,
152+
0.00010305,
153+
0.0011498,
154+
0.00135104,
155+
0.00095133,
156+
0.00081004,
157+
0.00061877,
158+
0.00089056,
159+
0.00056647,
160+
0.00070012,
161+
0.00016926,
162+
0.00026042,
139163
]
140164
)
141-
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
165+
np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)
142166

143167
except ARTTestException as e:
144168
art_warning(e)

tests/estimators/object_detection/test_pytorch_object_detector.py

Lines changed: 132 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -142,55 +142,79 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):
142142

143143
expected_gradients1 = np.asarray(
144144
[
145-
-2.7225273e-05,
146-
-2.7225284e-05,
147-
-3.2535860e-05,
148-
-9.3287526e-06,
149-
-1.1088990e-05,
150-
-3.4527478e-05,
151-
5.7807661e-06,
152-
1.1616970e-05,
153-
2.9732121e-06,
154-
1.1190044e-05,
155-
-6.4673945e-06,
156-
-1.6562306e-05,
157-
-1.5946282e-05,
158-
-1.8079168e-06,
159-
-9.7664342e-06,
160-
6.2075532e-07,
161-
-8.9023115e-06,
162-
-1.5546989e-06,
163-
-7.2730008e-06,
164-
-7.5181362e-07,
145+
-7.2145270e-04,
146+
-3.9774503e-04,
147+
-5.5271841e-04,
148+
2.5251633e-04,
149+
-4.1167819e-05,
150+
1.2919735e-04,
151+
1.1148686e-04,
152+
4.9278833e-04,
153+
9.6094189e-04,
154+
1.1812975e-03,
155+
2.7167992e-04,
156+
9.7095188e-05,
157+
1.4456113e-04,
158+
-8.8345587e-06,
159+
4.7151549e-05,
160+
-1.3497710e-04,
161+
-2.3394797e-04,
162+
1.3777621e-04,
163+
3.2994794e-04,
164+
3.7001527e-04,
165+
-2.5945838e-04,
166+
-8.3444244e-04,
167+
-6.9832127e-04,
168+
-3.0403296e-04,
169+
-5.4019055e-04,
170+
-3.4545487e-04,
171+
-5.6993403e-04,
172+
-2.9818740e-04,
173+
-9.8479632e-04,
174+
-4.1015903e-04,
175+
-6.2145875e-04,
176+
-1.1365353e-03,
165177
]
166178
)
167-
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
179+
np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)
168180

169181
expected_gradients2 = np.asarray(
170182
[
171-
-2.7307957e-05,
172-
-1.9417710e-05,
173-
-2.0928457e-05,
174-
-2.1384752e-05,
175-
-2.5035972e-05,
176-
-3.6572790e-05,
177-
-8.2444545e-05,
178-
-7.3255811e-05,
179-
-4.5060227e-05,
180-
-1.9829258e-05,
181-
-2.2043951e-05,
182-
-3.6746951e-05,
183-
-4.2588043e-05,
184-
-3.1833035e-05,
185-
-1.5923406e-05,
186-
-3.5026955e-05,
187-
-4.4511849e-05,
188-
-3.3867167e-05,
189-
-1.8569792e-05,
190-
-3.5141209e-05,
183+
0.00015462,
184+
0.00028882,
185+
-0.00018248,
186+
-0.00114344,
187+
-0.00160104,
188+
-0.00190151,
189+
-0.00183488,
190+
-0.00191787,
191+
-0.00018382,
192+
0.00095297,
193+
0.00042502,
194+
0.00024631,
195+
0.0002915,
196+
0.00053676,
197+
0.00028635,
198+
0.00035274,
199+
-0.00023395,
200+
-0.00044685,
201+
-0.00016795,
202+
0.00059767,
203+
0.00060389,
204+
0.00010305,
205+
0.0011498,
206+
0.00135104,
207+
0.00095133,
208+
0.00081004,
209+
0.00061877,
210+
0.00089056,
211+
0.00056647,
212+
0.00070012,
213+
0.00016926,
214+
0.00026042,
191215
]
192216
)
193-
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
217+
np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)
194218

195219
except ARTTestException as e:
196220
art_warning(e)
@@ -205,64 +229,81 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
205229
grads = object_detector.loss_gradient(x_test, y_test)
206230
assert grads.shape == (1, 3, 416, 416)
207231

208-
import pprint
209-
210-
print()
211-
pprint.pprint(grads[0, 0, 0, :20])
212-
print()
213-
pprint.pprint(grads[0, 0, :20, 0])
214-
215232
expected_gradients1 = np.asarray(
216233
[
217-
-4.2168313e-06,
218-
-4.4972450e-05,
219-
-3.6137710e-05,
220-
-1.2499937e-06,
221-
1.2728384e-05,
222-
-1.7352231e-05,
223-
5.6671047e-06,
224-
1.4085637e-05,
225-
5.9047998e-06,
226-
1.0826078e-05,
227-
2.2078505e-06,
228-
-1.3319310e-05,
229-
-2.4521427e-05,
230-
-1.8251436e-05,
231-
-1.9938851e-05,
232-
-3.6778667e-07,
233-
1.1899039e-05,
234-
1.9246204e-06,
235-
-2.7922330e-05,
236-
-3.2529952e-06,
234+
-5.5341989e-05,
235+
-5.4428884e-04,
236+
5.4366910e-04,
237+
7.6082360e-04,
238+
-3.4690551e-05,
239+
-3.8355158e-04,
240+
9.4802541e-05,
241+
-1.2973599e-03,
242+
-8.5583847e-04,
243+
-1.9041763e-03,
244+
-2.0476838e-03,
245+
1.3446594e-04,
246+
9.6042868e-04,
247+
8.8853808e-04,
248+
4.1893515e-04,
249+
1.2266783e-04,
250+
6.0996308e-04,
251+
4.6253894e-04,
252+
-1.8787223e-03,
253+
-1.9494371e-03,
254+
-1.2018540e-03,
255+
-7.0822285e-04,
256+
3.9439899e-04,
257+
-1.9463699e-03,
258+
-1.9617968e-03,
259+
-1.8740186e-04,
260+
-4.7003134e-04,
261+
-7.1175391e-04,
262+
-2.6479245e-03,
263+
-7.6713605e-04,
264+
-9.1007189e-04,
265+
-9.5907447e-04,
237266
]
238267
)
239-
np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
268+
np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)
240269

241270
expected_gradients2 = np.asarray(
242271
[
243-
-4.2168313e-06,
244-
-9.3028730e-06,
245-
1.5900954e-06,
246-
-9.7032771e-06,
247-
-7.9553565e-06,
248-
-1.9485701e-06,
249-
-1.3360468e-05,
250-
-2.7804586e-05,
251-
-4.2667002e-06,
252-
-6.1407286e-06,
253-
-6.6768125e-06,
254-
-1.6444834e-06,
255-
4.7967392e-06,
256-
2.4288647e-06,
257-
1.0280205e-05,
258-
4.2001102e-06,
259-
2.9494076e-05,
260-
1.4654281e-05,
261-
2.5580388e-05,
262-
3.0241908e-05,
272+
-0.00239724,
273+
-0.00271061,
274+
-0.0036578,
275+
-0.00504796,
276+
-0.0048536,
277+
-0.00433594,
278+
-0.00499022,
279+
-0.00401875,
280+
-0.00333852,
281+
-0.00060027,
282+
0.00098555,
283+
0.00249704,
284+
0.00135383,
285+
0.00277813,
286+
0.00033104,
287+
0.00016026,
288+
0.00060996,
289+
0.00010528,
290+
0.00096368,
291+
0.00230222,
292+
0.00169831,
293+
0.00172231,
294+
0.00270932,
295+
0.00224663,
296+
0.00077922,
297+
0.00174257,
298+
0.00041644,
299+
-0.00126136,
300+
-0.00112533,
301+
-0.00110854,
302+
-0.00126751,
303+
-0.0014297,
263304
]
264305
)
265-
np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
306+
np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)
266307

267308
except ARTTestException as e:
268309
art_warning(e)

0 commit comments

Comments
 (0)