Skip to content

Commit b42fb07

Browse files
committed
Implements dpctl.tensor.diff
1 parent f8cfaa7 commit b42fb07

File tree

2 files changed

+301
-2
lines changed

2 files changed

+301
-2
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
from dpctl.tensor._search_functions import where
9595
from dpctl.tensor._statistical_functions import mean, std, var
9696
from dpctl.tensor._usmarray import usm_ndarray
97-
from dpctl.tensor._utility_functions import all, any
97+
from dpctl.tensor._utility_functions import all, any, diff
9898

9999
from ._accumulation import cumulative_logsumexp, cumulative_prod, cumulative_sum
100100
from ._array_api import __array_api_version__, __array_namespace_info__
@@ -371,4 +371,5 @@
371371
"cumulative_logsumexp",
372372
"cumulative_prod",
373373
"cumulative_sum",
374+
"diff",
374375
]

dpctl/tensor/_utility_functions.py

Lines changed: 299 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,38 @@
1-
from numpy.core.numeric import normalize_axis_tuple
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2024 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
216

17+
import operator
18+
19+
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
20+
21+
import dpctl
322
import dpctl.tensor as dpt
423
import dpctl.tensor._tensor_impl as ti
524
import dpctl.tensor._tensor_reductions_impl as tri
625
import dpctl.utils as du
26+
from dpctl.tensor._clip import (
27+
_resolve_one_strong_one_weak_types,
28+
_resolve_one_strong_two_weak_types,
29+
)
30+
from dpctl.tensor._elementwise_common import (
31+
_get_dtype,
32+
_get_queue_usm_type,
33+
_get_shape,
34+
_validate_dtype,
35+
)
736

837

938
def _boolean_reduction(x, axis, keepdims, func):
@@ -128,3 +157,272 @@ def any(x, /, *, axis=None, keepdims=False):
128157
containing the results of the logical OR reduction.
129158
"""
130159
return _boolean_reduction(x, axis, keepdims, tri._any)
160+
161+
162+
def _validate_diff_shape(sh1, sh2, axis):
163+
if not sh2:
164+
# scalars will always be accepted
165+
return True
166+
else:
167+
sh1_ndim = len(sh1)
168+
if sh1_ndim == len(sh2) and all(
169+
sh1[i] == sh2[i] for i in range(sh1_ndim) if i != axis
170+
):
171+
return True
172+
else:
173+
return False
174+
175+
176+
def _concat_diff_input(arr, axis, prepend, append):
177+
if prepend is not None and append is not None:
178+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
179+
q2, prepend_usm_type = _get_queue_usm_type(prepend)
180+
q3, append_usm_type = _get_queue_usm_type(append)
181+
if q2 is None and q3 is None:
182+
exec_q = q1
183+
coerced_usm_type = x_usm_type
184+
elif q3 is None:
185+
exec_q = du.get_execution_queue((q1, q2))
186+
if exec_q is None:
187+
raise du.ExecutionPlacementError(
188+
"Execution placement can not be unambiguously inferred "
189+
"from input arguments."
190+
)
191+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
192+
(
193+
x_usm_type,
194+
prepend_usm_type,
195+
)
196+
)
197+
elif q2 is None:
198+
exec_q = du.get_execution_queue((q1, q3))
199+
if exec_q is None:
200+
raise du.ExecutionPlacementError(
201+
"Execution placement can not be unambiguously inferred "
202+
"from input arguments."
203+
)
204+
coerced_usm_type = du.get_coerced_usm_type(
205+
(
206+
x_usm_type,
207+
append_usm_type,
208+
)
209+
)
210+
else:
211+
exec_q = du.get_execution_queue((q1, q2, q3))
212+
if exec_q is None:
213+
raise du.ExecutionPlacementError(
214+
"Execution placement can not be unambiguously inferred "
215+
"from input arguments."
216+
)
217+
coerced_usm_type = du.get_coerced_usm_type(
218+
(
219+
x_usm_type,
220+
prepend_usm_type,
221+
append_usm_type,
222+
)
223+
)
224+
du.validate_usm_type(coerced_usm_type, allow_none=False)
225+
arr_shape = arr.shape
226+
prepend_shape = _get_shape(prepend)
227+
append_shape = _get_shape(append)
228+
if not all(
229+
isinstance(s, (tuple, list))
230+
for s in (
231+
prepend_shape,
232+
append_shape,
233+
)
234+
):
235+
raise TypeError(
236+
"Shape of arguments can not be inferred. "
237+
"Arguments are expected to be "
238+
"lists, tuples, or both"
239+
)
240+
valid_prepend_shape = _validate_diff_shape(
241+
arr_shape, prepend_shape, axis
242+
)
243+
if not valid_prepend_shape:
244+
raise ValueError(
245+
f"`diff` argument `prepend` with shape {prepend_shape} is "
246+
f"invalid for first input with shape {arr_shape}"
247+
)
248+
valid_append_shape = _validate_diff_shape(arr_shape, append_shape, axis)
249+
if not valid_append_shape:
250+
raise ValueError(
251+
f"`diff` argument `append` with shape {append_shape} is invalid"
252+
f" for first input with shape {arr_shape}"
253+
)
254+
sycl_dev = exec_q.sycl_device
255+
arr_dtype = arr.dtype
256+
prepend_dtype = _get_dtype(prepend, sycl_dev)
257+
append_dtype = _get_dtype(append, sycl_dev)
258+
if not all(_validate_dtype(o) for o in (prepend_dtype, append_dtype)):
259+
raise ValueError("Operands have unsupported data types")
260+
prepend_dtype, append_dtype = _resolve_one_strong_two_weak_types(
261+
arr_dtype, prepend_dtype, append_dtype, sycl_dev
262+
)
263+
if isinstance(prepend, dpt.usm_ndarray):
264+
a_prepend = prepend
265+
else:
266+
a_prepend = dpt.asarray(
267+
prepend,
268+
dtype=prepend_dtype,
269+
usm_type=coerced_usm_type,
270+
sycl_queue=exec_q,
271+
)
272+
if isinstance(append, dpt.usm_ndarray):
273+
a_append = append
274+
else:
275+
a_append = dpt.asarray(
276+
prepend,
277+
dtype=append_dtype,
278+
usm_type=coerced_usm_type,
279+
sycl_queue=exec_q,
280+
)
281+
if not prepend_shape:
282+
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
283+
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
284+
if not append_shape:
285+
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
286+
a_append = dpt.broadcast_to(a_append, arr_shape)
287+
return dpt.concat((a_prepend, arr, a_append), axis=axis)
288+
elif prepend is not None:
289+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
290+
q2, prepend_usm_type = _get_queue_usm_type(prepend)
291+
if q2 is None:
292+
exec_q = q1
293+
coerced_usm_type = x_usm_type
294+
else:
295+
exec_q = du.get_execution_queue((q1, q2))
296+
if exec_q is None:
297+
raise du.ExecutionPlacementError(
298+
"Execution placement can not be unambiguously inferred "
299+
"from input arguments."
300+
)
301+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
302+
(
303+
x_usm_type,
304+
prepend_usm_type,
305+
)
306+
)
307+
du.validate_usm_type(coerced_usm_type, allow_none=False)
308+
arr_shape = arr.shape
309+
prepend_shape = _get_shape(prepend)
310+
if not isinstance(prepend_shape, (tuple, list)):
311+
raise TypeError(
312+
"Shape of argument can not be inferred. "
313+
"Argument is expected to be a "
314+
"list or tuple"
315+
)
316+
valid_prepend_shape = _validate_diff_shape(
317+
arr_shape, prepend_shape, axis
318+
)
319+
if not valid_prepend_shape:
320+
raise ValueError(
321+
f"`diff` argument `prepend` with shape {prepend_shape} is "
322+
f"invalid for first input with shape {arr_shape}"
323+
)
324+
sycl_dev = exec_q.sycl_device
325+
arr_dtype = arr.dtype
326+
prepend_dtype = _get_dtype(prepend, sycl_dev)
327+
if not _validate_dtype(prepend_dtype):
328+
raise ValueError("Operand has unsupported data type")
329+
prepend_dtype = _resolve_one_strong_one_weak_types(
330+
arr_dtype, prepend_dtype, sycl_dev
331+
)
332+
if isinstance(prepend, dpt.usm_ndarray):
333+
a_prepend = prepend
334+
else:
335+
a_prepend = dpt.asarray(
336+
prepend,
337+
dtype=prepend_dtype,
338+
usm_type=coerced_usm_type,
339+
sycl_queue=exec_q,
340+
)
341+
if not prepend_shape:
342+
prepend_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
343+
a_prepend = dpt.broadcast_to(a_prepend, arr_shape)
344+
return dpt.concat((a_prepend, arr), axis=axis)
345+
elif append is not None:
346+
q1, x_usm_type = arr.sycl_queue, arr.usm_type
347+
q2, append_usm_type = _get_queue_usm_type(append)
348+
if q2 is None:
349+
exec_q = q1
350+
coerced_usm_type = x_usm_type
351+
else:
352+
exec_q = du.get_execution_queue((q1, q2))
353+
if exec_q is None:
354+
raise du.ExecutionPlacementError(
355+
"Execution placement can not be unambiguously inferred "
356+
"from input arguments."
357+
)
358+
coerced_usm_type = dpctl.utils.get_coerced_usm_type(
359+
(
360+
x_usm_type,
361+
append_usm_type,
362+
)
363+
)
364+
du.validate_usm_type(coerced_usm_type, allow_none=False)
365+
arr_shape = arr.shape
366+
append_shape = _get_shape(append)
367+
if not isinstance(append_shape, (tuple, list)):
368+
raise TypeError(
369+
"Shape of argument can not be inferred. "
370+
"Argument is expected to be a "
371+
"list or tuple"
372+
)
373+
valid_append_shape = _validate_diff_shape(arr_shape, append_shape, axis)
374+
if not valid_append_shape:
375+
raise ValueError(
376+
f"`diff` argument `append` with shape {append_shape} is invalid"
377+
f" for first input with shape {arr_shape}"
378+
)
379+
sycl_dev = exec_q.sycl_device
380+
arr_dtype = arr.dtype
381+
append_dtype = _get_dtype(append, sycl_dev)
382+
if not _validate_dtype(append_dtype):
383+
raise ValueError("Operand has unsupported data type")
384+
append_dtype = _resolve_one_strong_one_weak_types(
385+
arr_dtype, append_dtype, sycl_dev
386+
)
387+
if isinstance(append, dpt.usm_ndarray):
388+
a_append = append
389+
else:
390+
a_append = dpt.asarray(
391+
append,
392+
dtype=append_dtype,
393+
usm_type=coerced_usm_type,
394+
sycl_queue=exec_q,
395+
)
396+
if not append_shape:
397+
append_shape = arr_shape[:axis] + (1,) + arr_shape[axis + 1 :]
398+
a_append = dpt.broadcast_to(a_append, arr_shape)
399+
return dpt.concat((arr, a_append), axis=axis)
400+
else:
401+
arr1 = arr
402+
return arr1
403+
404+
405+
def diff(x, /, *, axis=-1, n=1, prepend=None, append=None):
406+
407+
if not isinstance(x, dpt.usm_ndarray):
408+
raise TypeError(
409+
"Expecting dpctl.tensor.usm_ndarray type, " f"got {type(x)}"
410+
)
411+
x_nd = x.ndim
412+
axis = normalize_axis_index(operator.index(axis), x_nd)
413+
n = operator.index(n)
414+
415+
arr = _concat_diff_input(x, axis, prepend, append)
416+
417+
# form slices and recurse
418+
sl0 = tuple(
419+
slice(None) if i != axis else slice(1, None) for i in range(x_nd)
420+
)
421+
sl1 = tuple(
422+
slice(None) if i != axis else slice(None, -1) for i in range(x_nd)
423+
)
424+
425+
for _ in range(n):
426+
arr = arr[sl0] - arr[sl1]
427+
428+
return arr

0 commit comments

Comments
 (0)