Skip to content

Commit 7528ce8

Browse files
Add arange function (#814)
* Added dpt.arange Added kernels to populate 1D contiguous array with linear sequence values based on starting_value and step as well as using starting and ending values interpolated by affine transformation. Modularized usm_type arg checking to dpctl.utils Extended support for NumPy scalars by dpt.asarray * added a test for dpctl.tensor.arange * added tests for dpctl.utils.validate_usm_type function * Test case when end is None
2 parents 96c6741 + 94380f6 commit 7528ce8

File tree

7 files changed

+503
-39
lines changed

7 files changed

+503
-39
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
"""
2323

2424
from dpctl.tensor._copy_utils import asnumpy, astype, copy, from_numpy, to_numpy
25-
from dpctl.tensor._ctors import asarray, empty
25+
from dpctl.tensor._ctors import arange, asarray, empty
2626
from dpctl.tensor._device import Device
2727
from dpctl.tensor._dlpack import from_dlpack
2828
from dpctl.tensor._manipulation_functions import (
@@ -40,6 +40,7 @@
4040
__all__ = [
4141
"Device",
4242
"usm_ndarray",
43+
"arange",
4344
"asarray",
4445
"astype",
4546
"copy",

dpctl/tensor/_ctors.py

Lines changed: 92 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import dpctl
2020
import dpctl.memory as dpm
2121
import dpctl.tensor as dpt
22+
import dpctl.tensor._tensor_impl as ti
2223
import dpctl.utils
2324
from dpctl.tensor._device import normalize_queue_device
2425

@@ -43,6 +44,11 @@ def _array_info_dispatch(obj):
4344
return _empty_tuple, complex, _host_set
4445
elif isinstance(obj, (list, tuple, range)):
4546
return _array_info_sequence(obj)
47+
elif any(
48+
isinstance(obj, s)
49+
for s in [np.integer, np.floating, np.complexfloating, np.bool_]
50+
):
51+
return _empty_tuple, obj.dtype, _host_set
4652
else:
4753
raise ValueError(type(obj))
4854

@@ -256,13 +262,13 @@ def asarray(
256262
is created. `device` can be `None`, a oneAPI filter selector string,
257263
an instance of :class:`dpctl.SyclDevice` corresponding to a
258264
non-partitioned SYCL device, an instance of
259-
:class:`dpctl.SyclQueue`, or a `Device` object returnedby
265+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
260266
`dpctl.tensor.usm_array.device`. Default: `None`.
261267
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
262268
allocation for the output array. For `usm_type=None` the allocation
263269
type is inferred from the input if `obj` has USM allocation, or
264270
`"device"` is used instead. Default: `None`.
265-
sycl_queue: (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
271+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
266272
for output array allocation and copying. `sycl_queue` and `device`
267273
are exclusive keywords, i.e. use one or another. If both are
268274
specified, a `TypeError` is raised unless both imply the same
@@ -290,17 +296,7 @@ def asarray(
290296
else:
291297
order = order[0].upper()
292298
# 4. Check that usm_type is None, or a valid value
293-
if usm_type is not None:
294-
if isinstance(usm_type, str):
295-
if usm_type not in ["device", "shared", "host"]:
296-
raise ValueError(
297-
f"Unrecognized value of usm_type={usm_type}, "
298-
"expected 'device', 'shared', 'host', or None."
299-
)
300-
else:
301-
raise TypeError(
302-
f"Expected usm_type to be a str or None, got {type(usm_type)}"
303-
)
299+
dpctl.utils.validate_usm_type(usm_type, allow_none=True)
304300
# 5. Normalize device/sycl_queue [keep it None if was None]
305301
if device is not None or sycl_queue is not None:
306302
sycl_queue = normalize_queue_device(
@@ -410,7 +406,7 @@ def empty(
410406
`dpctl.tensor.usm_array.device`. Default: `None`.
411407
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
412408
allocation for the output array. Default: `"device"`.
413-
sycl_queue: (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
409+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
414410
for output array allocation and copying. `sycl_queue` and `device`
415411
are exclusive keywords, i.e. use one or another. If both are
416412
specified, a `TypeError` is raised unless both imply the same
@@ -425,16 +421,7 @@ def empty(
425421
)
426422
else:
427423
order = order[0].upper()
428-
if isinstance(usm_type, str):
429-
if usm_type not in ["device", "shared", "host"]:
430-
raise ValueError(
431-
f"Unrecognized value of usm_type={usm_type}, "
432-
"expected 'device', 'shared', or 'host'."
433-
)
434-
else:
435-
raise TypeError(
436-
f"Expected usm_type to be of type str, got {type(usm_type)}"
437-
)
424+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
438425
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
439426
res = dpt.usm_ndarray(
440427
sh,
@@ -444,3 +431,84 @@ def empty(
444431
buffer_ctor_kwargs={"queue": sycl_queue},
445432
)
446433
return res
434+
435+
436+
def _coerce_and_infer_dt(*args, dt):
437+
nd, seq_dt, d = _array_info_sequence(args)
438+
if d != _host_set or nd != (len(args),):
439+
raise ValueError("start, stop and step must be Python scalars")
440+
if dt is None:
441+
dt = seq_dt
442+
dt = np.dtype(dt)
443+
if np.issubdtype(dt, np.integer):
444+
return tuple(int(v) for v in args), dt
445+
elif np.issubdtype(dt, np.floating):
446+
return tuple(float(v) for v in args), dt
447+
elif np.issubdtype(dt, np.complexfloating):
448+
return tuple(complex(v) for v in args), dt
449+
else:
450+
raise ValueError(f"Data type {dt} is not supported")
451+
452+
453+
def arange(
454+
start,
455+
/,
456+
stop=None,
457+
step=1,
458+
*,
459+
dtype=None,
460+
device=None,
461+
usm_type="device",
462+
sycl_queue=None,
463+
):
464+
""" arange(start, /, stop=None, step=1, *, dtype=None, \
465+
device=None, usm_type="device", sycl_queue=None) -> usm_ndarray
466+
467+
Args:
468+
start:
469+
device (optional): array API concept of device where the output array
470+
is created. `device` can be `None`, a oneAPI filter selector string,
471+
an instance of :class:`dpctl.SyclDevice` corresponding to a
472+
non-partitioned SYCL device, an instance of
473+
:class:`dpctl.SyclQueue`, or a `Device` object returned by
474+
`dpctl.tensor.usm_array.device`. Default: `None`.
475+
usm_type ("device"|"shared"|"host", optional): The type of SYCL USM
476+
allocation for the output array. Default: `'device'`.
477+
sycl_queue (:class:`dpctl.SyclQueue`, optional): The SYCL queue to use
478+
for output array allocation and copying. `sycl_queue` and `device`
479+
are exclusive keywords, i.e. use one or another. If both are
480+
specified, a `TypeError` is raised unless both imply the same
481+
underlying SYCL queue to be used. If both a `None`, the
482+
`dpctl.SyclQueue()` is used for allocation and copying.
483+
Default: `None`.
484+
"""
485+
if stop is None:
486+
stop = start
487+
start = 0
488+
(
489+
start,
490+
stop,
491+
step,
492+
), dt = _coerce_and_infer_dt(start, stop, step, dt=dtype)
493+
try:
494+
tmp = 1 + (stop - start - 1) / step
495+
if type(tmp) is complex and tmp.imag == 0.0:
496+
tmp = tmp.real
497+
sh = int(tmp)
498+
if sh < 0:
499+
sh = 0
500+
except TypeError:
501+
sh = 0
502+
dpctl.utils.validate_usm_type(usm_type, allow_none=False)
503+
sycl_queue = normalize_queue_device(sycl_queue=sycl_queue, device=device)
504+
res = dpt.usm_ndarray(
505+
(sh,),
506+
dtype=dt,
507+
buffer=usm_type,
508+
order="C",
509+
buffer_ctor_kwargs={"queue": sycl_queue},
510+
)
511+
_step = (start + step) - start
512+
hev, _ = ti._linspace_step(start, _step, res, sycl_queue)
513+
hev.wait()
514+
return res

0 commit comments

Comments
 (0)