Skip to content

Commit f2c5f18

Browse files
Added Device class representing Data-API notion of device
Data-API notion of device is higher level abstraction. The class does allow for public constructor, has class method to create instance instead. It typesets as Device(filter_string) for parent devices and Device(queue) for sub-devices. It can constructed from - filter selector string - SyclDevice corresponding to a root device (attempt at passing sub-device raises and error) - SyclQueue - Another instance of Device class It implements sycl_queue, sycl_device, sycl_context attributes. usm_ndarray adds .device property, and .to_device(device) method. ``` In [5]: X = dpt.usm_ndarray((4, 5), dtype='d') In [6]: X.device Out[6]: Device(level_zero:gpu:0) In [7]: Y = X.to_device('cpu') In [8]: Y.device Out[8]: Device(opencl:cpu:0) ```
1 parent 049d1a5 commit f2c5f18

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

dpctl/tensor/_device.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Data Parallel Control (dpctl)
2+
#
3+
# Copyright 2020-2021 Intel Corporation
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
import dpctl
17+
18+
19+
class Device:
20+
"""
21+
Class representing Data-API concept of Device
22+
23+
This is a wrapper around `dpctl.SyclQueue` with custom
24+
formatting. The class does not have public constructor,
25+
but a class method to construct it from device= keyword
26+
in Array-API functions.
27+
28+
Instance can be queries for `sycl_queue`, `sycl_context`,
29+
or `sycl_device`
30+
"""
31+
32+
def __new__(cls, *args, **kwargs):
33+
raise TypeError("No public constructor")
34+
35+
@classmethod
36+
def create_device(cls, dev):
37+
obj = super().__new__(cls)
38+
if isinstance(dev, Device):
39+
obj.sycl_queue_ = dev.sycl_queue
40+
elif isinstance(dev, dpctl.SyclQueue):
41+
obj.sycl_queue_ = dev
42+
elif isinstance(dev, dpctl.SyclDevice):
43+
par = dev.parent_device
44+
if par is None:
45+
obj.sycl_queue_ = dpctl.SyclQueue(dev)
46+
else:
47+
raise ValueError(
48+
"Using non-root device {} to specify offloading "
49+
"target is ambiguous. Please use dpctl.SyclQueue "
50+
"targeting this device".format(dev)
51+
)
52+
else:
53+
obj.sycl_queue_ = dpctl.SyclQueue(dev)
54+
return obj
55+
56+
@property
57+
def sycl_queue(self):
58+
return self.sycl_queue_
59+
60+
@property
61+
def sycl_context(self):
62+
return self.sycl_queue_.sycl_context
63+
64+
@property
65+
def sycl_device(self):
66+
return self.sycl_queue_.sycl_device
67+
68+
def __repr__(self):
69+
try:
70+
sd = self.sycl_device
71+
except AttributeError:
72+
raise ValueError(
73+
"Instance of {} is not initialized".format(self.__class__)
74+
)
75+
try:
76+
fs = sd.filter_string
77+
return "Device({})".format(fs)
78+
except TypeError:
79+
# This is a sub-device
80+
return repr(self.sycl_queue)

dpctl/tensor/_usmarray.pyx

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ import numpy as np
2222
import dpctl
2323
import dpctl.memory as dpmem
2424

25+
from ._device import Device
26+
2527
from cpython.mem cimport PyMem_Free
2628
from cpython.tuple cimport PyTuple_New, PyTuple_SetItem
2729

@@ -181,8 +183,8 @@ cdef class usm_ndarray:
181183
raise ValueError(
182184
"buffer='{}' is not understood. "
183185
"Recognized values are 'device', 'shared', 'host', "
184-
"or an object with __sycl_usm_array_interface__ "
185-
"property".format(buffer))
186+
"an instance of `MemoryUSM*` object, or a usm_ndarray"
187+
"".format(buffer))
186188
elif isinstance(buffer, usm_ndarray):
187189
_buffer = buffer.usm_data
188190
else:
@@ -428,6 +430,13 @@ cdef class usm_ndarray:
428430
q = self.sycl_queue
429431
return q.sycl_device
430432

433+
@property
434+
def device(self):
435+
"""
436+
Returns data-API object representing residence of the array data.
437+
"""
438+
return Device.create_device(self.sycl_queue)
439+
431440
@property
432441
def sycl_context(self):
433442
"""
@@ -475,6 +484,39 @@ cdef class usm_ndarray:
475484
res.flags_ |= (self.flags_ & USM_ARRAY_WRITEABLE)
476485
return res
477486

487+
def to_device(self, target_device):
488+
"""
489+
Transfer array to target device
490+
"""
491+
d = Device.create_device(target_device)
492+
if (d.sycl_device == self.sycl_device):
493+
return self
494+
elif (d.sycl_context == self.sycl_context):
495+
res = usm_ndarray(
496+
self.shape,
497+
self.dtype,
498+
buffer=self.usm_data,
499+
strides=self.strides,
500+
offset=self.get_offset()
501+
)
502+
res.flags_ = self.flags
503+
return res
504+
else:
505+
nbytes = self.usm_data.nbytes
506+
new_buffer = type(self.usm_data)(
507+
nbytes, queue=d.sycl_queue
508+
)
509+
new_buffer.copy_from_device(self.usm_data)
510+
res = usm_ndarray(
511+
self.shape,
512+
self.dtype,
513+
buffer=new_buffer,
514+
strides=self.strides,
515+
offset=self.get_offset()
516+
)
517+
res.flags_ = self.flags
518+
return res
519+
478520

479521
cdef usm_ndarray _real_view(usm_ndarray ary):
480522
"""

0 commit comments

Comments
 (0)