Skip to content

Commit 63109d7

Browse files
Added Memory C-API functions to dpctl_capi.h
Renamed functions to avoid name clashes for get_queue_ref/get_context_ref defined for different signatures. (C-API requires names to be different). Adjsuted tests to reflect changes in C-API names
1 parent 7618634 commit 63109d7

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

dpctl/apis/include/dpctl_capi.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
#include "../_sycl_event_api.h"
3737
#include "../_sycl_queue.h"
3838
#include "../_sycl_queue_api.h"
39+
#include "../memory/_memory.h"
40+
#include "../memory/_memory_api.h"
3941
// clang-format on
4042

4143
/*
@@ -50,6 +52,6 @@ void import_dpctl(void)
5052
import_dpctl___sycl_context();
5153
import_dpctl___sycl_event();
5254
import_dpctl___sycl_queue();
53-
55+
import_dpctl__memory___memory();
5456
return;
5557
}

dpctl/memory/_memory.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -751,20 +751,20 @@ def as_usm_memory(obj):
751751
)
752752

753753

754-
cdef api DPCTLSyclUSMRef get_usm_pointer(_Memory obj):
754+
cdef api DPCTLSyclUSMRef Memory_get_usm_pointer(_Memory obj):
755755
"Pointer of USM allocation"
756756
return obj.memory_ptr
757757

758-
cdef api DPCTLSyclContextRef get_context_ref(_Memory obj):
758+
cdef api DPCTLSyclContextRef Memory_get_context_ref(_Memory obj):
759759
"Context reference to which USM allocation is bound"
760760
return obj.queue._context.get_context_ref()
761761

762-
cdef api DPCTLSyclQueueRef get_queue_ref(_Memory obj):
762+
cdef api DPCTLSyclQueueRef Memory_get_queue_ref(_Memory obj):
763763
"""Queue associated with this allocation, used
764764
for copying, population, etc."""
765765
return obj.queue.get_queue_ref()
766766

767-
cdef api size_t get_nbytes(_Memory obj):
767+
cdef api size_t Memory_get_nbytes(_Memory obj):
768768
"Size of the allocation in bytes."
769769
return <size_t>obj.nbytes
770770

dpctl/tests/test_sycl_usm.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,10 @@ def test_cpython_api(memory_ctor):
530530
mobj = memory_ctor(1024)
531531
mod = sys.modules[mobj.__class__.__module__]
532532
# get capsules storing function pointers
533-
mem_ptr_fn_cap = mod.__pyx_capi__["get_usm_pointer"]
534-
mem_ctx_ref_fn_cap = mod.__pyx_capi__["get_context_ref"]
535-
mem_nby_fn_cap = mod.__pyx_capi__["get_nbytes"]
533+
mem_ptr_fn_cap = mod.__pyx_capi__["Memory_get_usm_pointer"]
534+
mem_q_ref_fn_cap = mod.__pyx_capi__["Memory_get_queue_ref"]
535+
mem_ctx_ref_fn_cap = mod.__pyx_capi__["Memory_get_context_ref"]
536+
mem_nby_fn_cap = mod.__pyx_capi__["Memory_get_nbytes"]
536537
# construct Python callable to invoke "get_usm_pointer"
537538
cap_ptr_fn = ctypes.pythonapi.PyCapsule_GetPointer
538539
cap_ptr_fn.restype = ctypes.c_void_p
@@ -543,12 +544,16 @@ def test_cpython_api(memory_ctor):
543544
mem_ctx_ref_fn_ptr = cap_ptr_fn(
544545
mem_ctx_ref_fn_cap, b"DPCTLSyclContextRef (struct Py_MemoryObject *)"
545546
)
547+
mem_q_ref_fn_ptr = cap_ptr_fn(
548+
mem_q_ref_fn_cap, b"DPCTLSyclQueueRef (struct Py_MemoryObject *)"
549+
)
546550
mem_nby_fn_ptr = cap_ptr_fn(
547551
mem_nby_fn_cap, b"size_t (struct Py_MemoryObject *)"
548552
)
549553
callable_maker = ctypes.PYFUNCTYPE(ctypes.c_void_p, ctypes.py_object)
550554
get_ptr_fn = callable_maker(mem_ptr_fn_ptr)
551555
get_ctx_ref_fn = callable_maker(mem_ctx_ref_fn_ptr)
556+
get_q_ref_fn = callable_maker(mem_q_ref_fn_ptr)
552557
get_nby_fn = callable_maker(mem_nby_fn_ptr)
553558

554559
capi_ptr = get_ptr_fn(mobj)
@@ -557,6 +562,9 @@ def test_cpython_api(memory_ctor):
557562
capi_ctx_ref = get_ctx_ref_fn(mobj)
558563
direct_ctx_ref = mobj._context.addressof_ref()
559564
assert capi_ctx_ref == direct_ctx_ref
565+
capi_q_ref = get_q_ref_fn(mobj)
566+
direct_q_ref = mobj.sycl_queue.addressof_ref()
567+
assert capi_q_ref == direct_q_ref
560568
capi_nbytes = get_nby_fn(mobj)
561569
direct_nbytes = mobj.nbytes
562570
assert capi_nbytes == direct_nbytes

0 commit comments

Comments
 (0)