@@ -48,22 +48,18 @@ def test_diff_basic(dt):
48
48
skip_if_dtype_not_supported (dt , q )
49
49
50
50
x = dpt .asarray ([9 , 12 , 7 , 17 , 10 , 18 , 15 , 9 , 8 , 8 ], dtype = dt , sycl_queue = q )
51
- res = dpt .diff (x )
52
51
op = dpt .not_equal if x .dtype is dpt .bool else dpt .subtract
53
- expected_res = op (x [1 :], x [:- 1 ])
54
- if dpt .dtype (dt ).kind in "fc" :
55
- assert dpt .allclose (res , expected_res )
56
- else :
57
- assert dpt .all (res == expected_res )
58
52
59
- res = dpt .diff (x , n = 5 )
60
- expected_res = x
61
- for _ in range (5 ):
62
- expected_res = op (expected_res [1 :], expected_res [:- 1 ])
63
- if dpt .dtype (dt ).kind in "fc" :
64
- assert dpt .allclose (res , expected_res )
65
- else :
66
- assert dpt .all (res == expected_res )
53
+ # test both n=2 and n>2 branches
54
+ for n in [1 , 2 , 5 ]:
55
+ res = dpt .diff (x , n = n )
56
+ expected_res = x
57
+ for _ in range (n ):
58
+ expected_res = op (expected_res [1 :], expected_res [:- 1 ])
59
+ if dpt .dtype (dt ).kind in "fc" :
60
+ assert dpt .allclose (res , expected_res )
61
+ else :
62
+ assert dpt .all (res == expected_res )
67
63
68
64
69
65
def test_diff_axis ():
@@ -73,17 +69,15 @@ def test_diff_axis():
73
69
dpt .asarray ([9 , 12 , 7 , 17 , 10 , 18 , 15 , 9 , 8 , 8 ], dtype = "i4" ), (3 , 4 , 1 )
74
70
)
75
71
x [:, ::2 , :] = 0
76
- res = dpt .diff (x , n = 1 , axis = 1 )
77
- expected_res = dpt .subtract (x [:, 1 :, :], x [:, :- 1 , :])
78
- assert dpt .all (res == expected_res )
79
-
80
- res = dpt .diff (x , n = 3 , axis = 1 )
81
- expected_res = x
82
- for _ in range (3 ):
83
- expected_res = dpt .subtract (
84
- expected_res [:, 1 :, :], expected_res [:, :- 1 , :]
85
- )
86
- assert dpt .all (res == expected_res )
72
+
73
+ for n in [1 , 2 , 3 ]:
74
+ res = dpt .diff (x , n = 3 , axis = 1 )
75
+ expected_res = x
76
+ for _ in range (3 ):
77
+ expected_res = dpt .subtract (
78
+ expected_res [:, 1 :, :], expected_res [:, :- 1 , :]
79
+ )
80
+ assert dpt .all (res == expected_res )
87
81
88
82
89
83
def test_diff_prepend_append_type_promotion ():
@@ -179,33 +173,28 @@ def test_diff_prepend_append_py_scalars(sh, axis):
179
173
sl2 = tuple (sl2 )
180
174
181
175
r = dpt .diff (arr , axis = axis , prepend = zero , append = zero )
182
- assert isinstance (r , dpt .usm_ndarray )
183
176
assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
184
177
assert r .shape [axis ] == arr .shape [axis ] + 2 - n
185
178
assert dpt .all (r [sl1 ] == 1 )
186
179
assert dpt .all (r [sl2 ] == - 1 )
187
180
188
181
r = dpt .diff (arr , axis = axis , prepend = zero )
189
- assert isinstance (r , dpt .usm_ndarray )
190
182
assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
191
183
assert r .shape [axis ] == arr .shape [axis ] + 1 - n
192
184
assert dpt .all (r [sl1 ] == 1 )
193
185
194
186
r = dpt .diff (arr , axis = axis , append = zero )
195
- assert isinstance (r , dpt .usm_ndarray )
196
187
assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
197
188
assert r .shape [axis ] == arr .shape [axis ] + 1 - n
198
189
assert dpt .all (r [sl2 ] == - 1 )
199
190
200
191
r = dpt .diff (arr , axis = axis , prepend = dpt .asarray (zero ), append = zero )
201
- assert isinstance (r , dpt .usm_ndarray )
202
192
assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
203
193
assert r .shape [axis ] == arr .shape [axis ] + 2 - n
204
194
assert dpt .all (r [sl1 ] == 1 )
205
195
assert dpt .all (r [sl2 ] == - 1 )
206
196
207
197
r = dpt .diff (arr , axis = axis , prepend = zero , append = dpt .asarray (zero ))
208
- assert isinstance (r , dpt .usm_ndarray )
209
198
assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
210
199
assert r .shape [axis ] == arr .shape [axis ] + 2 - n
211
200
assert dpt .all (r [sl1 ] == 1 )
@@ -218,54 +207,36 @@ def test_tensor_diff_append_prepend_arrays():
218
207
n = 1
219
208
axis = 0
220
209
221
- sz = 5
222
- arr = dpt .arange (sz , 2 * sz , dtype = "i4" )
223
- prepend = dpt .arange (sz , dtype = "i4" )
224
- append = dpt .arange (2 * sz , 3 * sz , dtype = "i4" )
225
- const_diff = 1
226
-
227
- r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
228
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
229
- assert (
230
- r .shape [axis ]
231
- == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
232
- )
233
- assert dpt .all (r == const_diff )
234
-
235
- r = dpt .diff (arr , axis = axis , prepend = prepend )
236
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
237
- assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
238
- assert dpt .all (r == const_diff )
239
-
240
- r = dpt .diff (arr , axis = axis , append = append )
241
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
242
- assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
243
- assert dpt .all (r == const_diff )
244
-
245
- sh = (3 , 4 , 5 )
246
- sz = prod (sh )
247
- arr = dpt .reshape (dpt .arange (sz , 2 * sz , dtype = "i4" ), sh )
248
- prepend = dpt .reshape (dpt .arange (sz , dtype = "i4" ), sh )
249
- append = dpt .reshape (dpt .arange (2 * sz , 3 * sz , dtype = "i4" ), sh )
250
- const_diff = prod (sh [axis + 1 :])
210
+ for sh in [(5 ,), (3 , 4 , 5 )]:
211
+ sz = prod (sh )
212
+ arr = dpt .reshape (dpt .arange (sz , 2 * sz , dtype = "i4" ), sh )
213
+ prepend = dpt .reshape (dpt .arange (sz , dtype = "i4" ), sh )
214
+ append = dpt .reshape (dpt .arange (2 * sz , 3 * sz , dtype = "i4" ), sh )
215
+ const_diff = sz / sh [axis ]
251
216
252
- r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
253
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
254
- assert (
255
- r .shape [axis ]
256
- == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
257
- )
258
- assert dpt .all (r == const_diff )
217
+ r = dpt .diff (arr , axis = axis , prepend = prepend , append = append )
218
+ assert all (
219
+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
220
+ )
221
+ assert (
222
+ r .shape [axis ]
223
+ == arr .shape [axis ] + prepend .shape [axis ] + append .shape [axis ] - n
224
+ )
225
+ assert dpt .all (r == const_diff )
259
226
260
- r = dpt .diff (arr , axis = axis , prepend = prepend )
261
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
262
- assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
263
- assert dpt .all (r == const_diff )
227
+ r = dpt .diff (arr , axis = axis , prepend = prepend )
228
+ assert all (
229
+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
230
+ )
231
+ assert r .shape [axis ] == arr .shape [axis ] + prepend .shape [axis ] - n
232
+ assert dpt .all (r == const_diff )
264
233
265
- r = dpt .diff (arr , axis = axis , append = append )
266
- assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
267
- assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
268
- assert dpt .all (r == const_diff )
234
+ r = dpt .diff (arr , axis = axis , append = append )
235
+ assert all (
236
+ r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
237
+ )
238
+ assert r .shape [axis ] == arr .shape [axis ] + append .shape [axis ] - n
239
+ assert dpt .all (r == const_diff )
269
240
270
241
271
242
def test_diff_wrong_append_prepend_shape ():
@@ -332,6 +303,26 @@ def test_diff_compute_follows_data():
332
303
append = ar3 ,
333
304
)
334
305
306
+ assert_raises_regex (
307
+ ExecutionPlacementError ,
308
+ "Execution placement can not be unambiguously inferred from input "
309
+ "arguments" ,
310
+ dpt .diff ,
311
+ ar1 ,
312
+ prepend = ar2 ,
313
+ append = 0 ,
314
+ )
315
+
316
+ assert_raises_regex (
317
+ ExecutionPlacementError ,
318
+ "Execution placement can not be unambiguously inferred from input "
319
+ "arguments" ,
320
+ dpt .diff ,
321
+ ar1 ,
322
+ prepend = 0 ,
323
+ append = ar2 ,
324
+ )
325
+
335
326
assert_raises_regex (
336
327
ExecutionPlacementError ,
337
328
"Execution placement can not be unambiguously inferred from input "
0 commit comments