Skip to content

Commit 7f6e0c0

Browse files
committed
Implements dpctl.tensor.diff
1 parent 9403f76 commit 7f6e0c0

File tree

2 files changed

+284
-2
lines changed

2 files changed

+284
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
from dpctl.tensor._search_functions import where
9595
from dpctl.tensor._statistical_functions import mean, std, var
9696
from dpctl.tensor._usmarray import usm_ndarray
97-
from dpctl.tensor._utility_functions import all, any
97+
from dpctl.tensor._utility_functions import all, any, diff
9898

9999
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
100100
from ._array_api import __array_api_version__, __array_namespace_info__
@@ -373,4 +373,5 @@
373373
"cumulative_prod",
374374
"cumulative_sum",
375375
"nextafter",
376+
"diff",
376377
]

dpctl/tensor/_utility_functions.py

Lines changed: 282 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,24 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import operator
18+
1719
import dpctl.tensor as dpt
1820
import dpctl.tensor._tensor_impl as ti
1921
import dpctl.tensor._tensor_reductions_impl as tri
2022
import dpctl.utils as du
23+
from dpctl.tensor._clip import (
24+
_resolve_one_strong_one_weak_types,
25+
_resolve_one_strong_two_weak_types,
26+
)
27+
from dpctl.tensor._elementwise_common import (
28+
_get_dtype,
29+
_get_queue_usm_type,
30+
_get_shape,
31+
_validate_dtype,
32+
)
2133

22-
from ._numpy_helper import normalize_axis_tuple
34+
from ._numpy_helper import normalize_axis_index, normalize_axis_tuple
2335

2436

2537
def _boolean_reduction(x, axis, keepdims, func):
@@ -144,3 +156,272 @@ def any(x, /, *, axis=None, keepdims=False):
144156
containing the results of the logical OR reduction.
145157
"""
146158
return _boolean_reduction(x, axis, keepdims, tri._any)
159+
160+
161+
def _validate_diff_shape(sh1, sh2, axis):
162+
if not sh2:
163+
# scalars will always be accepted
164+
return True
165+
else:
166+
sh1_ndim = len(sh1)
167+
if sh1_ndim == len(sh2) and all(
168+
sh1[i] == sh2[i] for i in range(sh1_ndim) if i != axis
169+
):
170+
return True
171+
else:
172+
return False
173+
174+
175+
def _concat_diff_input(arr, axis, prepend, append):
176+
if prepend is not None and append is not None:
177+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
178+
q2, prepend_usm_type = _get_queue_usm_type(prepend)
179+
q3, append_usm_type = _get_queue_usm_type(append)
180+
if q2 is None and q3 is None:
181+
exec_q = q1
182+
coerced_usm_type = x_usm_type
183+
elif q3 is None:
184+
exec_q = du.get_execution_queue((q1, q2))
185+
if exec_q is None:
186+
raise du.ExecutionPlacementError(
187+
"Execution placement can not be unambiguously inferred "
188+
"from input arguments."
189+
)
190+
coerced_usm_type = du.get_coerced_usm_type(
191+
(
192+
x_usm_type,
193+
prepend_usm_type,
194+
)
195+
)
196+
elif q2 is None:
197+
exec_q = du.get_execution_queue((q1, q3))
198+
if exec_q is None:
199+
raise du.ExecutionPlacementError(
200+
"Execution placement can not be unambiguously inferred "
201+
"from input arguments."
202+
)
203+
coerced_usm_type = du.get_coerced_usm_type(
204+
(
205+
x_usm_type,
206+
append_usm_type,
207+
)
208+
)
209+
else:
210+
exec_q = du.get_execution_queue((q1, q2, q3))
211+
if exec_q is None:
212+
raise du.ExecutionPlacementError(
213+
"Execution placement can not be unambiguously inferred "
214+
"from input arguments."
215+
)
216+
coerced_usm_type = du.get_coerced_usm_type(
217+
(
218+
x_usm_type,
219+
prepend_usm_type,
220+
append_usm_type,
221+
)
222+
)
223+
du.validate_usm_type(coerced_usm_type, allow_none=False)
224+
arr_shape = arr.shape
225+
prepend_shape = _get_shape(prepend)
226+
append_shape = _get_shape(append)
227+
if not all(
228+
isinstance(s, (tuple, list))
229+
for s in (
230+
prepend_shape,
231+
append_shape,
232+
)
233+
):
234+
raise TypeError(
235+
"Shape of arguments can not be inferred. "
236+
"Arguments are expected to be "
237+
"lists, tuples, or both"
238+
)
239+
valid_prepend_shape = _validate_diff_shape(
240+
arr_shape, prepend_shape, axis
241+
)
242+
if not valid_prepend_shape:
243+
raise ValueError(
244+
f"`diff` argument `prepend` with shape {prepend_shape} is "
245+
f"invalid for first input with shape {arr_shape}"
246+
)
247+
valid_append_shape = _validate_diff_shape(arr_shape, append_shape, axis)
248+
if not valid_append_shape:
249+
raise ValueError(
250+
f"`diff` argument `append` with shape {append_shape} is invalid"
251+
f" for first input with shape {arr_shape}"
252+
)
253+
sycl_dev = exec_q.sycl_device
254+
arr_dtype = arr.dtype
255+
prepend_dtype = _get_dtype(prepend, sycl_dev)
256+
append_dtype = _get_dtype(append, sycl_dev)
257+
if not all(_validate_dtype(o) for o in (prepend_dtype, append_dtype)):
258+
raise ValueError("Operands have unsupported data types")
259+
prepend_dtype, append_dtype = _resolve_one_strong_two_weak_types(
260+
arr_dtype, prepend_dtype, append_dtype, sycl_dev
261+
)
262+
if isinstance(prepend, dpt.usm_ndarray):
263+
a_prepend = prepend
264+
else:
265+
a_prepend = dpt.asarray(
266+
prepend,
267+
dtype=prepend_dtype,
268+
usm_type=coerced_usm_type,
269+
sycl_queue=exec_q,
270+
)
271+
if isinstance(append, dpt.usm_ndarray):
272+
a_append = append
273+
else:
274+
a_append = dpt.asarray(
275+
prepend,
276+
dtype=append_dtype,
277+
usm_type=coerced_usm_type,
278+
sycl_queue=exec_q,
279+
)
280+
if not prepend_shape:
281+
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
282+
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
283+
if not append_shape:
284+
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
285+
a_append = dpt.broadcast_to(a_append, arr_shape)
286+
return dpt.concat((a_prepend, arr, a_append), axis=axis)
287+
elif prepend is not None:
288+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
289+
q2, prepend_usm_type = _get_queue_usm_type(prepend)
290+
if q2 is None:
291+
exec_q = q1
292+
coerced_usm_type = x_usm_type
293+
else:
294+
exec_q = du.get_execution_queue((q1, q2))
295+
if exec_q is None:
296+
raise du.ExecutionPlacementError(
297+
"Execution placement can not be unambiguously inferred "
298+
"from input arguments."
299+
)
300+
coerced_usm_type = du.get_coerced_usm_type(
301+
(
302+
x_usm_type,
303+
prepend_usm_type,
304+
)
305+
)
306+
du.validate_usm_type(coerced_usm_type, allow_none=False)
307+
arr_shape = arr.shape
308+
prepend_shape = _get_shape(prepend)
309+
if not isinstance(prepend_shape, (tuple, list)):
310+
raise TypeError(
311+
"Shape of argument can not be inferred. "
312+
"Argument is expected to be a "
313+
"list or tuple"
314+
)
315+
valid_prepend_shape = _validate_diff_shape(
316+
arr_shape, prepend_shape, axis
317+
)
318+
if not valid_prepend_shape:
319+
raise ValueError(
320+
f"`diff` argument `prepend` with shape {prepend_shape} is "
321+
f"invalid for first input with shape {arr_shape}"
322+
)
323+
sycl_dev = exec_q.sycl_device
324+
arr_dtype = arr.dtype
325+
prepend_dtype = _get_dtype(prepend, sycl_dev)
326+
if not _validate_dtype(prepend_dtype):
327+
raise ValueError("Operand has unsupported data type")
328+
prepend_dtype = _resolve_one_strong_one_weak_types(
329+
arr_dtype, prepend_dtype, sycl_dev
330+
)
331+
if isinstance(prepend, dpt.usm_ndarray):
332+
a_prepend = prepend
333+
else:
334+
a_prepend = dpt.asarray(
335+
prepend,
336+
dtype=prepend_dtype,
337+
usm_type=coerced_usm_type,
338+
sycl_queue=exec_q,
339+
)
340+
if not prepend_shape:
341+
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
342+
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
343+
return dpt.concat((a_prepend, arr), axis=axis)
344+
elif append is not None:
345+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
346+
q2, append_usm_type = _get_queue_usm_type(append)
347+
if q2 is None:
348+
exec_q = q1
349+
coerced_usm_type = x_usm_type
350+
else:
351+
exec_q = du.get_execution_queue((q1, q2))
352+
if exec_q is None:
353+
raise du.ExecutionPlacementError(
354+
"Execution placement can not be unambiguously inferred "
355+
"from input arguments."
356+
)
357+
coerced_usm_type = du.get_coerced_usm_type(
358+
(
359+
x_usm_type,
360+
append_usm_type,
361+
)
362+
)
363+
du.validate_usm_type(coerced_usm_type, allow_none=False)
364+
arr_shape = arr.shape
365+
append_shape = _get_shape(append)
366+
if not isinstance(append_shape, (tuple, list)):
367+
raise TypeError(
368+
"Shape of argument can not be inferred. "
369+
"Argument is expected to be a "
370+
"list or tuple"
371+
)
372+
valid_append_shape = _validate_diff_shape(arr_shape, append_shape, axis)
373+
if not valid_append_shape:
374+
raise ValueError(
375+
f"`diff` argument `append` with shape {append_shape} is invalid"
376+
f" for first input with shape {arr_shape}"
377+
)
378+
sycl_dev = exec_q.sycl_device
379+
arr_dtype = arr.dtype
380+
append_dtype = _get_dtype(append, sycl_dev)
381+
if not _validate_dtype(append_dtype):
382+
raise ValueError("Operand has unsupported data type")
383+
append_dtype = _resolve_one_strong_one_weak_types(
384+
arr_dtype, append_dtype, sycl_dev
385+
)
386+
if isinstance(append, dpt.usm_ndarray):
387+
a_append = append
388+
else:
389+
a_append = dpt.asarray(
390+
append,
391+
dtype=append_dtype,
392+
usm_type=coerced_usm_type,
393+
sycl_queue=exec_q,
394+
)
395+
if not append_shape:
396+
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
397+
a_append = dpt.broadcast_to(a_append, arr_shape)
398+
return dpt.concat((arr, a_append), axis=axis)
399+
else:
400+
arr1 = arr
401+
return arr1
402+
403+
404+
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
405+
406+
if not isinstance(x, dpt.usm_ndarray):
407+
raise TypeError(
408+
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x)}"
409+
)
410+
x_nd = x.ndim
411+
axis = normalize_axis_index(operator.index(axis), x_nd)
412+
n = operator.index(n)
413+
414+
arr = _concat_diff_input(x, axis, prepend, append)
415+
416+
# form slices and recurse
417+
sl0 = tuple(
418+
slice(None) if i != axis else slice(1, None) for i in range(x_nd)
419+
)
420+
sl1 = tuple(
421+
slice(None) if i != axis else slice(None, -1) for i in range(x_nd)
422+
)
423+
424+
for _ in range(n):
425+
arr = arr[sl0] - arr[sl1]
426+
427+
return arr

0 commit comments

Comments
 (0)