Skip to content

Commit e68c16e

Browse files
Merge pull request #440 from IntelPython/feature/data-api-device
Added Device class representing Data-API notion of device
2 parents 049d1a5 + 4acf56c commit e68c16e

File tree

2 files changed

+148
-2
lines changed

2 files changed

+148
-2
lines changed

dpctl/tensor/_device.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import dpctl
17+
18+
19+
class Device:
20+
"""
21+
Class representing Data-API concept of device.
22+
23+
This is a wrapper around :class:`dpctl.SyclQueue` with custom
24+
formatting. The class does not have public constructor,
25+
but a class method to construct it from device= keyword
26+
in Array-API functions.
27+
28+
Instance can be queried for ``sycl_queue``, ``sycl_context``,
29+
or ``sycl_device``.
30+
"""
31+
32+
def __new__(cls, *args, **kwargs):
33+
raise TypeError("No public constructor")
34+
35+
@classmethod
36+
def create_device(cls, dev):
37+
"""
38+
Device.create_device(device)
39+
40+
Creates instance of Device from argument.
41+
42+
Args:
43+
device: None, :class:`.Device`, :class:`dpctl.SyclQueue`, or
44+
a :class:`dpctl.SyclDevice` corresponding to a root
45+
SYCL device.
46+
Raises:
47+
ValueError: if an instance of :class:`dpctl.SycDevice` corresponding
48+
to a sub-device was specified as the argument
49+
SyclQueueCreationError: if :class:`dpctl.SyclQueue` could not be
50+
created from the argument
51+
"""
52+
obj = super().__new__(cls)
53+
if isinstance(dev, Device):
54+
obj.sycl_queue_ = dev.sycl_queue
55+
elif isinstance(dev, dpctl.SyclQueue):
56+
obj.sycl_queue_ = dev
57+
elif isinstance(dev, dpctl.SyclDevice):
58+
par = dev.parent_device
59+
if par is None:
60+
obj.sycl_queue_ = dpctl.SyclQueue(dev)
61+
else:
62+
raise ValueError(
63+
"Using non-root device {} to specify offloading "
64+
"target is ambiguous. Please use dpctl.SyclQueue "
65+
"targeting this device".format(dev)
66+
)
67+
else:
68+
obj.sycl_queue_ = dpctl.SyclQueue(dev)
69+
return obj
70+
71+
@property
72+
def sycl_queue(self):
73+
"""
74+
:class:`dpctl.SyclQueue` used to offload to this :class:`.Device`.
75+
"""
76+
return self.sycl_queue_
77+
78+
@property
79+
def sycl_context(self):
80+
"""
81+
:class:`dpctl.SyclContext` associated with this :class:`.Device`.
82+
"""
83+
return self.sycl_queue_.sycl_context
84+
85+
@property
86+
def sycl_device(self):
87+
"""
88+
:class:`dpctl.SyclDevice` targed by this :class:`.Device`.
89+
"""
90+
return self.sycl_queue_.sycl_device
91+
92+
def __repr__(self):
93+
try:
94+
sd = self.sycl_device
95+
except AttributeError:
96+
raise ValueError(
97+
"Instance of {} is not initialized".format(self.__class__)
98+
)
99+
try:
100+
fs = sd.filter_string
101+
return "Device({})".format(fs)
102+
except TypeError:
103+
# This is a sub-device
104+
return repr(self.sycl_queue)

dpctl/tensor/_usmarray.pyx

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import numpy as np
2222
import dpctl
2323
import dpctl.memory as dpmem
2424

25+
from ._device import Device
26+
2527
from cpython.mem cimport PyMem_Free
2628
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
2729

@@ -181,8 +183,8 @@ cdef class usm_ndarray:
181183
raise ValueError(
182184
"buffer='{}' is not understood. "
183185
"Recognized values are 'device', 'shared', 'host', "
184-
"or an object with __sycl_usm_array_interface__ "
185-
"property".format(buffer))
186+
"an instance of `MemoryUSM*` object, or a usm_ndarray"
187+
"".format(buffer))
186188
elif isinstance(buffer, usm_ndarray):
187189
_buffer = buffer.usm_data
188190
else:
@@ -428,6 +430,13 @@ cdef class usm_ndarray:
428430
q = self.sycl_queue
429431
return q.sycl_device
430432

433+
@property
434+
def device(self):
435+
"""
436+
Returns data-API object representing residence of the array data.
437+
"""
438+
return Device.create_device(self.sycl_queue)
439+
431440
@property
432441
def sycl_context(self):
433442
"""
@@ -475,6 +484,39 @@ cdef class usm_ndarray:
475484
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
476485
return res
477486

487+
def to_device(self, target_device):
488+
"""
489+
Transfer array to target device
490+
"""
491+
d = Device.create_device(target_device)
492+
if (d.sycl_device == self.sycl_device):
493+
return self
494+
elif (d.sycl_context == self.sycl_context):
495+
res = usm_ndarray(
496+
self.shape,
497+
self.dtype,
498+
buffer=self.usm_data,
499+
strides=self.strides,
500+
offset=self.get_offset()
501+
)
502+
res.flags_ = self.flags
503+
return res
504+
else:
505+
nbytes = self.usm_data.nbytes
506+
new_buffer = type(self.usm_data)(
507+
nbytes, queue=d.sycl_queue
508+
)
509+
new_buffer.copy_from_device(self.usm_data)
510+
res = usm_ndarray(
511+
self.shape,
512+
self.dtype,
513+
buffer=new_buffer,
514+
strides=self.strides,
515+
offset=self.get_offset()
516+
)
517+
res.flags_ = self.flags
518+
return res
519+
478520

479521
cdef usm_ndarray _real_view(usm_ndarray ary):
480522
"""

0 commit comments

Comments
 (0)