Skip to content

Commit 96c6741

Browse files
Hooked up _Memory.memset() method (#815)
obj.memset() filled object with zero bytes obj.memset(val) filles object with given value (expected to fit in `unsigned short` type)
1 parent f1239b9 commit 96c6741

File tree

4 files changed

+53
-0
lines changed

4 files changed

+53
-0
lines changed

dpctl/_backend.pxd

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,11 @@ cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":
363363
void *Dest,
364364
const void *Src,
365365
size_t Count)
366+
cdef DPCTLSyclEventRef DPCTLQueue_Memset(
367+
const DPCTLSyclQueueRef Q,
368+
void *Dest,
369+
int Val,
370+
size_t Count)
366371
cdef DPCTLSyclEventRef DPCTLQueue_Prefetch(
367372
const DPCTLSyclQueueRef Q,
368373
const void *Src,

dpctl/memory/_memory.pxd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ cdef public api class _Memory [object Py_MemoryObject, type Py_MemoryType]:
4747
cpdef copy_to_host(self, object obj=*)
4848
cpdef copy_from_host(self, object obj)
4949
cpdef copy_from_device(self, object obj)
50+
cpdef memset(self, unsigned short val=*)
5051

5152
cpdef bytes tobytes(self)
5253

dpctl/memory/_memory.pyx

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ from dpctl._backend cimport ( # noqa: E211
4747
DPCTLQueue_Delete,
4848
DPCTLQueue_GetContext,
4949
DPCTLQueue_Memcpy,
50+
DPCTLQueue_Memset,
5051
DPCTLSyclContextRef,
5152
DPCTLSyclDeviceRef,
5253
DPCTLSyclEventRef,
@@ -478,6 +479,27 @@ cdef class _Memory:
478479
else:
479480
raise TypeError
480481

482+
cpdef memset(self, unsigned short val = 0):
483+
"""
484+
Populates this USM allocation with given value.
485+
"""
486+
cdef DPCTLSyclEventRef ERef = NULL
487+
488+
ERef = DPCTLQueue_Memset(
489+
self.queue.get_queue_ref(),
490+
<void *>self.memory_ptr, # destination
491+
<int> val,
492+
self.nbytes)
493+
494+
if ERef is not NULL:
495+
DPCTLEvent_Wait(ERef)
496+
DPCTLEvent_Delete(ERef)
497+
return
498+
else:
499+
raise RuntimeError(
500+
"Call to memset resulted in an error"
501+
)
502+
481503
cpdef bytes tobytes(self):
482504
"""
483505
Constructs bytes object populated with copy of USM memory.

dpctl/tests/test_sycl_usm.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,3 +610,28 @@ def test_memory_copy_between_contexts():
610610
copy_buf = bytearray(256)
611611
m1.copy_to_host(copy_buf)
612612
assert host_buf == copy_buf
613+
614+
615+
def test_memset():
616+
try:
617+
q = dpctl.SyclQueue()
618+
except dpctl.SyclQueueCreationError:
619+
pytest.skip("Default queue could not be created")
620+
621+
n = 4086
622+
m_de = MemoryUSMDevice(n, queue=q)
623+
m_sh = MemoryUSMShared(n, queue=q)
624+
m_ho = MemoryUSMHost(n, queue=q)
625+
626+
host_buf = bytearray(n)
627+
m_de.memset()
628+
m_de.copy_to_host(host_buf)
629+
assert host_buf == b"\x00" * n
630+
631+
m_sh.memset(ord("j"))
632+
m_sh.copy_to_host(host_buf)
633+
assert host_buf == b"j" * n
634+
635+
m_ho.memset(ord("7"))
636+
m_ho.copy_to_host(host_buf)
637+
assert host_buf == b"7" * n

0 commit comments

Comments
 (0)