Skip to content

Commit 95047ab

Browse files
authored
Add wrapper for SYCL queue::memcpy(). (#70)
* Add wrapper for SYCL queue::memcpy(). * Fix comment. * Fix comment for DPPLQueue_memcpy * Rename DPPLQueue_Memcpy * Fix cdef void* declaring variables
1 parent 63ce96f commit 95047ab

File tree

9 files changed

+166
-4
lines changed

9 files changed

+166
-4
lines changed

backends/include/dppl_sycl_queue_interface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,16 @@ DPPLQueue_GetDevice (__dppl_keep const DPPLSyclQueueRef QRef);
6363
DPPL_API
6464
void DPPLQueue_Delete (__dppl_take DPPLSyclQueueRef QRef);
6565

66+
/*!
67+
* @brief C-API wrapper for sycl::queue::memcpy. It waits an event.
68+
*
69+
* @param QRef An opaque pointer to the sycl queue.
70+
* @param Dest An USM pointer to the destination memory.
71+
* @param Src An USM pointer to the source memory.
72+
* @param Count A number of bytes to copy.
73+
*/
74+
DPPL_API
75+
void DPPLQueue_Memcpy (__dppl_keep const DPPLSyclQueueRef QRef,
76+
void *Dest, const void *Src, size_t Count);
77+
6678
DPPL_C_EXTERN_C_END

backends/source/dppl_sycl_queue_interface.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,11 @@ void DPPLQueue_Delete (__dppl_take DPPLSyclQueueRef QRef)
6161
{
6262
delete unwrap(QRef);
6363
}
64+
65+
void DPPLQueue_Memcpy (__dppl_take const DPPLSyclQueueRef QRef,
66+
void *Dest, const void *Src, size_t Count)
67+
{
68+
auto Q = unwrap(QRef);
69+
auto event = Q->memcpy(Dest, Src, Count);
70+
event.wait();
71+
}

dpctl/_memory.pxd

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
##===--------------- _memory.pxd - dpctl module --------*- Cython -*-------===##
2+
##
3+
## Data Parallel Control (dpCtl)
4+
##
5+
## Copyright 2020 Intel Corporation
6+
##
7+
## Licensed under the Apache License, Version 2.0 (the "License");
8+
## you may not use this file except in compliance with the License.
9+
## You may obtain a copy of the License at
10+
##
11+
## http://www.apache.org/licenses/LICENSE-2.0
12+
##
13+
## Unless required by applicable law or agreed to in writing, software
14+
## distributed under the License is distributed on an "AS IS" BASIS,
15+
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
## See the License for the specific language governing permissions and
17+
## limitations under the License.
18+
##
19+
##===----------------------------------------------------------------------===##
20+
21+
# distutils: language = c++
22+
# cython: language_level=3
23+
24+
from .backend cimport DPPLSyclUSMRef
25+
from ._sycl_core cimport SyclQueue
26+
27+
28+
cdef class Memory:
29+
cdef DPPLSyclUSMRef memory_ptr
30+
cdef Py_ssize_t nbytes
31+
cdef SyclQueue queue
32+
33+
cdef _cinit(self, Py_ssize_t nbytes, ptr_type, SyclQueue queue)
34+
cdef _getbuffer(self, Py_buffer *buffer, int flags)
35+
36+
37+
cdef class MemoryUSMShared(Memory):
38+
pass
39+
40+
41+
cdef class MemoryUSMHost(Memory):
42+
pass
43+
44+
45+
cdef class MemoryUSMDevice(Memory):
46+
pass

dpctl/_memory.pyx

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ from cpython cimport Py_buffer
3636

3737

3838
cdef class Memory:
39-
cdef DPPLSyclUSMRef memory_ptr
40-
cdef Py_ssize_t nbytes
41-
cdef SyclQueue queue
4239

4340
cdef _cinit(self, Py_ssize_t nbytes, ptr_type, SyclQueue queue):
4441
cdef DPPLSyclUSMRef p

dpctl/_sycl_core.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,4 @@ cdef class SyclQueue:
6363
cpdef SyclContext get_sycl_context (self)
6464
cpdef SyclDevice get_sycl_device (self)
6565
cdef DPPLSyclQueueRef get_queue_ref (self)
66+
cpdef memcpy (self, dest, src, int count)

dpctl/backend.pxd

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ cdef extern from "dppl_sycl_queue_interface.h":
7878
except+
7979
cdef DPPLSyclDeviceRef DPPLQueue_GetDevice (const DPPLSyclQueueRef Q) \
8080
except +
81+
cdef void DPPLQueue_Memcpy (const DPPLSyclQueueRef Q,
82+
void *Dest, const void *Src, size_t Count) \
83+
except +
8184

8285

8386
cdef extern from "dppl_sycl_queue_manager.h":

dpctl/sycl_core.pyx

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@
2929
from __future__ import print_function
3030
from enum import Enum, auto
3131
import logging
32-
from dpctl.backend cimport *
32+
from .backend cimport *
33+
from ._memory cimport Memory
3334

3435

3536
_logger = logging.getLogger(__name__)
@@ -132,6 +133,22 @@ cdef class SyclQueue:
132133
cdef DPPLSyclQueueRef get_queue_ref (self):
133134
return self.queue_ptr
134135

136+
cpdef memcpy (self, dest, src, int count):
137+
cdef void *c_dest
138+
cdef void *c_src
139+
140+
if isinstance(dest, Memory):
141+
c_dest = <void*>(<Memory>dest).memory_ptr
142+
else:
143+
raise TypeError("Parameter dest should be Memory.")
144+
145+
if isinstance(src, Memory):
146+
c_src = <void*>(<Memory>src).memory_ptr
147+
else:
148+
raise TypeError("Parameter src should be Memory.")
149+
150+
DPPLQueue_Memcpy(self.queue_ptr, c_dest, c_src, count)
151+
135152

136153
cdef class _SyclQueueManager:
137154
def _set_as_current_queue (self, device_ty, device_id):

dpctl/tests/test_sycl_queue_memcpy.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
##===---------- test_sycl_queue_manager.py - dpctl -------*- Python -*----===##
2+
##
3+
## Data Parallel Control (dpCtl)
4+
##
5+
## Copyright 2020 Intel Corporation
6+
##
7+
## Licensed under the Apache License, Version 2.0 (the "License");
8+
## you may not use this file except in compliance with the License.
9+
## You may obtain a copy of the License at
10+
##
11+
## http://www.apache.org/licenses/LICENSE-2.0
12+
##
13+
## Unless required by applicable law or agreed to in writing, software
14+
## distributed under the License is distributed on an "AS IS" BASIS,
15+
## WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
## See the License for the specific language governing permissions and
17+
## limitations under the License.
18+
##
19+
##===----------------------------------------------------------------------===##
20+
###
21+
### \file
22+
### Defines unit test cases for the SyclQueue.memcpy in sycl_core.pyx.
23+
##===----------------------------------------------------------------------===##
24+
25+
import dpctl
26+
import unittest
27+
28+
29+
30+
class TestQueueMemcpy (unittest.TestCase):
31+
32+
def _create_memory (self):
33+
nbytes = 1024
34+
queue = dpctl.get_current_queue()
35+
mobj = dpctl._memory.MemoryUSMShared(nbytes, queue)
36+
return mobj
37+
38+
def test_memcpy_copy_usm_to_usm (self):
39+
mobj1 = self._create_memory()
40+
mobj2 = self._create_memory()
41+
q = dpctl.get_current_queue()
42+
43+
mv1 = memoryview(mobj1)
44+
mv2 = memoryview(mobj2)
45+
46+
mv1[:3] = b'123'
47+
48+
q.memcpy(mobj2, mobj1, 3)
49+
50+
self.assertEqual(mv2[:3], b'123')
51+
52+
def test_memcpy_type_error (self):
53+
mobj = self._create_memory()
54+
q = dpctl.get_current_queue()
55+
56+
with self.assertRaises(TypeError) as cm:
57+
q.memcpy(None, mobj, 3)
58+
59+
self.assertEqual(type(cm.exception), TypeError)
60+
self.assertEqual(str(cm.exception), "Parameter dest should be Memory.")
61+
62+
with self.assertRaises(TypeError) as cm:
63+
q.memcpy(mobj, None, 3)
64+
65+
self.assertEqual(type(cm.exception), TypeError)
66+
self.assertEqual(str(cm.exception), "Parameter src should be Memory.")
67+
68+
69+
if __name__ == '__main__':
70+
unittest.main()

dpctl/tests/test_sycl_usm.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ def test_memory_gpu_context (self):
7878
self.assertTrue(usm_type in ['unknown', 'shared'])
7979

8080

81+
def test_buffer_protocol (self):
82+
mobj = self._create_memory()
83+
mv1 = memoryview(mobj)
84+
mv2 = memoryview(mobj)
85+
self.assertEqual(mv1, mv2)
86+
87+
8188
class TestMemoryUSMBase:
8289
""" Base tests for MemoryUSM* """
8390

@@ -116,5 +123,6 @@ class TestMemoryUSMDevice(TestMemoryUSMBase, unittest.TestCase):
116123
MemoryUSMClass = MemoryUSMDevice
117124
usm_type = 'device'
118125

126+
119127
if __name__ == '__main__':
120128
unittest.main()

0 commit comments

Comments
 (0)