@@ -142,55 +142,79 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):
142142
143143 expected_gradients1 = np .asarray (
144144 [
145- - 2.7225273e-05 ,
146- - 2.7225284e-05 ,
147- - 3.2535860e-05 ,
148- - 9.3287526e-06 ,
149- - 1.1088990e-05 ,
150- - 3.4527478e-05 ,
151- 5.7807661e-06 ,
152- 1.1616970e-05 ,
153- 2.9732121e-06 ,
154- 1.1190044e-05 ,
155- - 6.4673945e-06 ,
156- - 1.6562306e-05 ,
157- - 1.5946282e-05 ,
158- - 1.8079168e-06 ,
159- - 9.7664342e-06 ,
160- 6.2075532e-07 ,
161- - 8.9023115e-06 ,
162- - 1.5546989e-06 ,
163- - 7.2730008e-06 ,
164- - 7.5181362e-07 ,
145+ - 7.2145270e-04 ,
146+ - 3.9774503e-04 ,
147+ - 5.5271841e-04 ,
148+ 2.5251633e-04 ,
149+ - 4.1167819e-05 ,
150+ 1.2919735e-04 ,
151+ 1.1148686e-04 ,
152+ 4.9278833e-04 ,
153+ 9.6094189e-04 ,
154+ 1.1812975e-03 ,
155+ 2.7167992e-04 ,
156+ 9.7095188e-05 ,
157+ 1.4456113e-04 ,
158+ - 8.8345587e-06 ,
159+ 4.7151549e-05 ,
160+ - 1.3497710e-04 ,
161+ - 2.3394797e-04 ,
162+ 1.3777621e-04 ,
163+ 3.2994794e-04 ,
164+ 3.7001527e-04 ,
165+ - 2.5945838e-04 ,
166+ - 8.3444244e-04 ,
167+ - 6.9832127e-04 ,
168+ - 3.0403296e-04 ,
169+ - 5.4019055e-04 ,
170+ - 3.4545487e-04 ,
171+ - 5.6993403e-04 ,
172+ - 2.9818740e-04 ,
173+ - 9.8479632e-04 ,
174+ - 4.1015903e-04 ,
175+ - 6.2145875e-04 ,
176+ - 1.1365353e-03 ,
165177 ]
166178 )
167- np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , : 20 ], expected_gradients1 , decimal = 2 )
179+ np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 192 : 224 ], expected_gradients1 , decimal = 2 )
168180
169181 expected_gradients2 = np .asarray (
170182 [
171- - 2.7307957e-05 ,
172- - 1.9417710e-05 ,
173- - 2.0928457e-05 ,
174- - 2.1384752e-05 ,
175- - 2.5035972e-05 ,
176- - 3.6572790e-05 ,
177- - 8.2444545e-05 ,
178- - 7.3255811e-05 ,
179- - 4.5060227e-05 ,
180- - 1.9829258e-05 ,
181- - 2.2043951e-05 ,
182- - 3.6746951e-05 ,
183- - 4.2588043e-05 ,
184- - 3.1833035e-05 ,
185- - 1.5923406e-05 ,
186- - 3.5026955e-05 ,
187- - 4.4511849e-05 ,
188- - 3.3867167e-05 ,
189- - 1.8569792e-05 ,
190- - 3.5141209e-05 ,
183+ 0.00015462 ,
184+ 0.00028882 ,
185+ - 0.00018248 ,
186+ - 0.00114344 ,
187+ - 0.00160104 ,
188+ - 0.00190151 ,
189+ - 0.00183488 ,
190+ - 0.00191787 ,
191+ - 0.00018382 ,
192+ 0.00095297 ,
193+ 0.00042502 ,
194+ 0.00024631 ,
195+ 0.0002915 ,
196+ 0.00053676 ,
197+ 0.00028635 ,
198+ 0.00035274 ,
199+ - 0.00023395 ,
200+ - 0.00044685 ,
201+ - 0.00016795 ,
202+ 0.00059767 ,
203+ 0.00060389 ,
204+ 0.00010305 ,
205+ 0.0011498 ,
206+ 0.00135104 ,
207+ 0.00095133 ,
208+ 0.00081004 ,
209+ 0.00061877 ,
210+ 0.00089056 ,
211+ 0.00056647 ,
212+ 0.00070012 ,
213+ 0.00016926 ,
214+ 0.00026042 ,
191215 ]
192216 )
193- np .testing .assert_array_almost_equal (grads [0 , 0 , : 20 , 0 ], expected_gradients2 , decimal = 2 )
217+ np .testing .assert_array_almost_equal (grads [0 , 0 , 192 : 224 , 208 ], expected_gradients2 , decimal = 2 )
194218
195219 except ARTTestException as e :
196220 art_warning (e )
@@ -205,64 +229,81 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
205229 grads = object_detector .loss_gradient (x_test , y_test )
206230 assert grads .shape == (1 , 3 , 416 , 416 )
207231
208- import pprint
209-
210- print ()
211- pprint .pprint (grads [0 , 0 , 0 , :20 ])
212- print ()
213- pprint .pprint (grads [0 , 0 , :20 , 0 ])
214-
215232 expected_gradients1 = np .asarray (
216233 [
217- - 4.2168313e-06 ,
218- - 4.4972450e-05 ,
219- - 3.6137710e-05 ,
220- - 1.2499937e-06 ,
221- 1.2728384e-05 ,
222- - 1.7352231e-05 ,
223- 5.6671047e-06 ,
224- 1.4085637e-05 ,
225- 5.9047998e-06 ,
226- 1.0826078e-05 ,
227- 2.2078505e-06 ,
228- - 1.3319310e-05 ,
229- - 2.4521427e-05 ,
230- - 1.8251436e-05 ,
231- - 1.9938851e-05 ,
232- - 3.6778667e-07 ,
233- 1.1899039e-05 ,
234- 1.9246204e-06 ,
235- - 2.7922330e-05 ,
236- - 3.2529952e-06 ,
234+ - 5.5341989e-05 ,
235+ - 5.4428884e-04 ,
236+ 5.4366910e-04 ,
237+ 7.6082360e-04 ,
238+ - 3.4690551e-05 ,
239+ - 3.8355158e-04 ,
240+ 9.4802541e-05 ,
241+ - 1.2973599e-03 ,
242+ - 8.5583847e-04 ,
243+ - 1.9041763e-03 ,
244+ - 2.0476838e-03 ,
245+ 1.3446594e-04 ,
246+ 9.6042868e-04 ,
247+ 8.8853808e-04 ,
248+ 4.1893515e-04 ,
249+ 1.2266783e-04 ,
250+ 6.0996308e-04 ,
251+ 4.6253894e-04 ,
252+ - 1.8787223e-03 ,
253+ - 1.9494371e-03 ,
254+ - 1.2018540e-03 ,
255+ - 7.0822285e-04 ,
256+ 3.9439899e-04 ,
257+ - 1.9463699e-03 ,
258+ - 1.9617968e-03 ,
259+ - 1.8740186e-04 ,
260+ - 4.7003134e-04 ,
261+ - 7.1175391e-04 ,
262+ - 2.6479245e-03 ,
263+ - 7.6713605e-04 ,
264+ - 9.1007189e-04 ,
265+ - 9.5907447e-04 ,
237266 ]
238267 )
239- np .testing .assert_array_almost_equal (grads [0 , 0 , 0 , : 20 ], expected_gradients1 , decimal = 2 )
268+ np .testing .assert_array_almost_equal (grads [0 , 0 , 208 , 192 : 224 ], expected_gradients1 , decimal = 2 )
240269
241270 expected_gradients2 = np .asarray (
242271 [
243- - 4.2168313e-06 ,
244- - 9.3028730e-06 ,
245- 1.5900954e-06 ,
246- - 9.7032771e-06 ,
247- - 7.9553565e-06 ,
248- - 1.9485701e-06 ,
249- - 1.3360468e-05 ,
250- - 2.7804586e-05 ,
251- - 4.2667002e-06 ,
252- - 6.1407286e-06 ,
253- - 6.6768125e-06 ,
254- - 1.6444834e-06 ,
255- 4.7967392e-06 ,
256- 2.4288647e-06 ,
257- 1.0280205e-05 ,
258- 4.2001102e-06 ,
259- 2.9494076e-05 ,
260- 1.4654281e-05 ,
261- 2.5580388e-05 ,
262- 3.0241908e-05 ,
272+ - 0.00239724 ,
273+ - 0.00271061 ,
274+ - 0.0036578 ,
275+ - 0.00504796 ,
276+ - 0.0048536 ,
277+ - 0.00433594 ,
278+ - 0.00499022 ,
279+ - 0.00401875 ,
280+ - 0.00333852 ,
281+ - 0.00060027 ,
282+ 0.00098555 ,
283+ 0.00249704 ,
284+ 0.00135383 ,
285+ 0.00277813 ,
286+ 0.00033104 ,
287+ 0.00016026 ,
288+ 0.00060996 ,
289+ 0.00010528 ,
290+ 0.00096368 ,
291+ 0.00230222 ,
292+ 0.00169831 ,
293+ 0.00172231 ,
294+ 0.00270932 ,
295+ 0.00224663 ,
296+ 0.00077922 ,
297+ 0.00174257 ,
298+ 0.00041644 ,
299+ - 0.00126136 ,
300+ - 0.00112533 ,
301+ - 0.00110854 ,
302+ - 0.00126751 ,
303+ - 0.0014297 ,
263304 ]
264305 )
265- np .testing .assert_array_almost_equal (grads [0 , 0 , : 20 , 0 ], expected_gradients2 , decimal = 2 )
306+ np .testing .assert_array_almost_equal (grads [0 , 0 , 192 : 224 , 208 ], expected_gradients2 , decimal = 2 )
266307
267308 except ARTTestException as e :
268309 art_warning (e )
0 commit comments