Skip to content

Commit 970e34a

Browse files
authored
DPCTL data container initial usage implementation (#841)
1 parent deef7ea commit 970e34a

File tree

2 files changed

+31
-5
lines changed

2 files changed

+31
-5
lines changed

dpnp/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,8 @@
4242
'''
4343
Explicitly use NumPy.ndarray as return type for creation functions
4444
'''
45+
46+
__DPNP_DPCTL_AVAILABLE__ = False
47+
'''
48+
Availability of the DPCtl package in the environment
49+
'''

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,20 @@ cimport cpython
4242
cimport cython
4343
cimport numpy
4444

45+
try:
46+
"""
47+
Detect DPCtl availability to use data container
48+
"""
49+
import dpctl.tensor as dpctl
50+
51+
config.__DPNP_DPCTL_AVAILABLE__ = True
52+
53+
except ImportError:
54+
"""
55+
No DPCtl data container available
56+
"""
57+
config.__DPNP_DPCTL_AVAILABLE__ = False
58+
4559

4660
"""
4761
Python import functions
@@ -361,6 +375,9 @@ cdef dpnp_descriptor create_output_descriptor(shape_type_c output_shape,
361375
""" Create NumPy ndarray """
362376
# TODO need to use "buffer=" parameter to use SYCL aware memory
363377
result = numpy.ndarray(output_shape, dtype=result_dtype)
378+
elif config.__DPNP_DPCTL_AVAILABLE__:
379+
""" Create DPCTL array """
380+
result = dpctl.usm_ndarray(output_shape, dtype=numpy.dtype(result_dtype).name)
364381
else:
365382
""" Create DPNP array """
366383
result = dparray(output_shape, dtype=result_dtype)
@@ -481,13 +498,17 @@ cdef class dpnp_descriptor:
481498
self.dpnp_descriptor_data_size = 0
482499
self.dpnp_descriptor_is_scalar = True
483500

484-
""" Accure main data storage """
485-
self.descriptor = getattr(obj, "__array_interface__", None)
501+
""" Accure DPCTL data container storage """
502+
self.descriptor = getattr(obj, "__sycl_usm_array_interface__", None)
486503
if self.descriptor is None:
487-
return
488504

489-
if self.descriptor["version"] != 3:
490-
return
505+
""" Accure main data storage """
506+
self.descriptor = getattr(obj, "__array_interface__", None)
507+
if self.descriptor is None:
508+
return
509+
510+
if self.descriptor["version"] != 3:
511+
return
491512

492513
self.origin_pyobj = obj
493514

0 commit comments

Comments
 (0)