@@ -142,55 +142,79 @@ def test_loss_gradient(art_warning, get_pytorch_object_detector):

         expected_gradients1 = np.asarray(
             [
-                -2.7225273e-05,
-                -2.7225284e-05,
-                -3.2535860e-05,
-                -9.3287526e-06,
-                -1.1088990e-05,
-                -3.4527478e-05,
-                5.7807661e-06,
-                1.1616970e-05,
-                2.9732121e-06,
-                1.1190044e-05,
-                -6.4673945e-06,
-                -1.6562306e-05,
-                -1.5946282e-05,
-                -1.8079168e-06,
-                -9.7664342e-06,
-                6.2075532e-07,
-                -8.9023115e-06,
-                -1.5546989e-06,
-                -7.2730008e-06,
-                -7.5181362e-07,
+                -7.2145270e-04,
+                -3.9774503e-04,
+                -5.5271841e-04,
+                2.5251633e-04,
+                -4.1167819e-05,
+                1.2919735e-04,
+                1.1148686e-04,
+                4.9278833e-04,
+                9.6094189e-04,
+                1.1812975e-03,
+                2.7167992e-04,
+                9.7095188e-05,
+                1.4456113e-04,
+                -8.8345587e-06,
+                4.7151549e-05,
+                -1.3497710e-04,
+                -2.3394797e-04,
+                1.3777621e-04,
+                3.2994794e-04,
+                3.7001527e-04,
+                -2.5945838e-04,
+                -8.3444244e-04,
+                -6.9832127e-04,
+                -3.0403296e-04,
+                -5.4019055e-04,
+                -3.4545487e-04,
+                -5.6993403e-04,
+                -2.9818740e-04,
+                -9.8479632e-04,
+                -4.1015903e-04,
+                -6.2145875e-04,
+                -1.1365353e-03,
             ]
         )
-        np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
+        np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)

         expected_gradients2 = np.asarray(
             [
-                -2.7307957e-05,
-                -1.9417710e-05,
-                -2.0928457e-05,
-                -2.1384752e-05,
-                -2.5035972e-05,
-                -3.6572790e-05,
-                -8.2444545e-05,
-                -7.3255811e-05,
-                -4.5060227e-05,
-                -1.9829258e-05,
-                -2.2043951e-05,
-                -3.6746951e-05,
-                -4.2588043e-05,
-                -3.1833035e-05,
-                -1.5923406e-05,
-                -3.5026955e-05,
-                -4.4511849e-05,
-                -3.3867167e-05,
-                -1.8569792e-05,
-                -3.5141209e-05,
+                0.00015462,
+                0.00028882,
+                -0.00018248,
+                -0.00114344,
+                -0.00160104,
+                -0.00190151,
+                -0.00183488,
+                -0.00191787,
+                -0.00018382,
+                0.00095297,
+                0.00042502,
+                0.00024631,
+                0.0002915,
+                0.00053676,
+                0.00028635,
+                0.00035274,
+                -0.00023395,
+                -0.00044685,
+                -0.00016795,
+                0.00059767,
+                0.00060389,
+                0.00010305,
+                0.0011498,
+                0.00135104,
+                0.00095133,
+                0.00081004,
+                0.00061877,
+                0.00089056,
+                0.00056647,
+                0.00070012,
+                0.00016926,
+                0.00026042,
             ]
         )
-        np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
+        np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)

     except ARTTestException as e:
         art_warning(e)
@@ -205,64 +229,81 @@ def test_loss_gradient_mask(art_warning, get_pytorch_object_detector_mask):
         grads = object_detector.loss_gradient(x_test, y_test)
         assert grads.shape == (1, 3, 416, 416)

-        import pprint
-
-        print()
-        pprint.pprint(grads[0, 0, 0, :20])
-        print()
-        pprint.pprint(grads[0, 0, :20, 0])
-
         expected_gradients1 = np.asarray(
             [
-                -4.2168313e-06,
-                -4.4972450e-05,
-                -3.6137710e-05,
-                -1.2499937e-06,
-                1.2728384e-05,
-                -1.7352231e-05,
-                5.6671047e-06,
-                1.4085637e-05,
-                5.9047998e-06,
-                1.0826078e-05,
-                2.2078505e-06,
-                -1.3319310e-05,
-                -2.4521427e-05,
-                -1.8251436e-05,
-                -1.9938851e-05,
-                -3.6778667e-07,
-                1.1899039e-05,
-                1.9246204e-06,
-                -2.7922330e-05,
-                -3.2529952e-06,
+                -5.5341989e-05,
+                -5.4428884e-04,
+                5.4366910e-04,
+                7.6082360e-04,
+                -3.4690551e-05,
+                -3.8355158e-04,
+                9.4802541e-05,
+                -1.2973599e-03,
+                -8.5583847e-04,
+                -1.9041763e-03,
+                -2.0476838e-03,
+                1.3446594e-04,
+                9.6042868e-04,
+                8.8853808e-04,
+                4.1893515e-04,
+                1.2266783e-04,
+                6.0996308e-04,
+                4.6253894e-04,
+                -1.8787223e-03,
+                -1.9494371e-03,
+                -1.2018540e-03,
+                -7.0822285e-04,
+                3.9439899e-04,
+                -1.9463699e-03,
+                -1.9617968e-03,
+                -1.8740186e-04,
+                -4.7003134e-04,
+                -7.1175391e-04,
+                -2.6479245e-03,
+                -7.6713605e-04,
+                -9.1007189e-04,
+                -9.5907447e-04,
             ]
         )
-        np.testing.assert_array_almost_equal(grads[0, 0, 0, :20], expected_gradients1, decimal=2)
+        np.testing.assert_array_almost_equal(grads[0, 0, 208, 192:224], expected_gradients1, decimal=2)

         expected_gradients2 = np.asarray(
             [
-                -4.2168313e-06,
-                -9.3028730e-06,
-                1.5900954e-06,
-                -9.7032771e-06,
-                -7.9553565e-06,
-                -1.9485701e-06,
-                -1.3360468e-05,
-                -2.7804586e-05,
-                -4.2667002e-06,
-                -6.1407286e-06,
-                -6.6768125e-06,
-                -1.6444834e-06,
-                4.7967392e-06,
-                2.4288647e-06,
-                1.0280205e-05,
-                4.2001102e-06,
-                2.9494076e-05,
-                1.4654281e-05,
-                2.5580388e-05,
-                3.0241908e-05,
+                -0.00239724,
+                -0.00271061,
+                -0.0036578,
+                -0.00504796,
+                -0.0048536,
+                -0.00433594,
+                -0.00499022,
+                -0.00401875,
+                -0.00333852,
+                -0.00060027,
+                0.00098555,
+                0.00249704,
+                0.00135383,
+                0.00277813,
+                0.00033104,
+                0.00016026,
+                0.00060996,
+                0.00010528,
+                0.00096368,
+                0.00230222,
+                0.00169831,
+                0.00172231,
+                0.00270932,
+                0.00224663,
+                0.00077922,
+                0.00174257,
+                0.00041644,
+                -0.00126136,
+                -0.00112533,
+                -0.00110854,
+                -0.00126751,
+                -0.0014297,
             ]
         )
-        np.testing.assert_array_almost_equal(grads[0, 0, :20, 0], expected_gradients2, decimal=2)
+        np.testing.assert_array_almost_equal(grads[0, 0, 192:224, 208], expected_gradients2, decimal=2)

     except ARTTestException as e:
         art_warning(e)
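
For orientation, here is a minimal sketch (not part of the patch) of what the updated assertions index, assuming only what the diff itself shows: grads has shape (1, 3, 416, 416), so the new slices are 32-element strips through the center of the first channel rather than the old 20-element slices along the top row and left column. The dummy_grads array is a hypothetical stand-in for the real gradients.

import numpy as np

# Stand-in for the gradient tensor returned by loss_gradient: (batch, channels, height, width).
dummy_grads = np.zeros((1, 3, 416, 416), dtype=np.float32)

# New slices checked by the tests:
horizontal_strip = dummy_grads[0, 0, 208, 192:224]  # row 208, columns 192-223 -> shape (32,)
vertical_strip = dummy_grads[0, 0, 192:224, 208]    # rows 192-223, column 208 -> shape (32,)

# assert_array_almost_equal(actual, desired, decimal=2) passes when
# abs(desired - actual) < 1.5 * 10**-2 element-wise.
np.testing.assert_array_almost_equal(horizontal_strip, np.zeros(32, dtype=np.float32), decimal=2)
np.testing.assert_array_almost_equal(vertical_strip, np.zeros(32, dtype=np.float32), decimal=2)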