Skip to content

Commit 35df171

Browse files
committed
C API for SyclProgram and SyclKernel classes
1 parent 78c443b commit 35df171

File tree

9 files changed

+194
-2
lines changed

9 files changed

+194
-2
lines changed

dpctl/_backend.pxd

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ cdef extern from "syclinterface/dpctl_sycl_event_interface.h":
263263
cdef extern from "syclinterface/dpctl_sycl_kernel_interface.h":
264264
cdef size_t DPCTLKernel_GetNumArgs(const DPCTLSyclKernelRef KRef)
265265
cdef void DPCTLKernel_Delete(DPCTLSyclKernelRef KRef)
266+
cdef DPCTLSyclKernelRef DPCTLKernel_Copy(const DPCTLSyclKernelRef KRef)
266267
cdef size_t DPCTLKernel_GetWorkGroupSize(const DPCTLSyclKernelRef KRef)
267268
cdef size_t DPCTLKernel_GetPreferredWorkGroupSizeMultiple(const DPCTLSyclKernelRef KRef)
268269
cdef size_t DPCTLKernel_GetPrivateMemSize(const DPCTLSyclKernelRef KRef)
@@ -341,6 +342,7 @@ cdef extern from "syclinterface/dpctl_sycl_kernel_bundle_interface.h":
341342
cdef bool DPCTLKernelBundle_HasKernel(DPCTLSyclKernelBundleRef KBRef,
342343
const char *KernelName)
343344
cdef void DPCTLKernelBundle_Delete(DPCTLSyclKernelBundleRef KBRef)
345+
cdef DPCTLSyclKernelBundleRef DPCTLKernelBundle_Copy(const DPCTLSyclKernelBundleRef KBRef)
344346

345347

346348
cdef extern from "syclinterface/dpctl_sycl_queue_interface.h":

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ struct dpctl_capi
7171
PyTypeObject *PyMemoryUSMSharedType_;
7272
PyTypeObject *PyMemoryUSMHostType_;
7373
PyTypeObject *PyUSMArrayType_;
74+
PyTypeObject *PySyclProgramType_;
75+
PyTypeObject *PySyclKernelType_;
7476

7577
DPCTLSyclDeviceRef (*SyclDevice_GetDeviceRef_)(PySyclDeviceObject *);
7678
PySyclDeviceObject *(*SyclDevice_Make_)(DPCTLSyclDeviceRef);
@@ -94,6 +96,13 @@ struct dpctl_capi
9496
DPCTLSyclQueueRef,
9597
PyObject *);
9698

99+
// program
100+
DPCTLSyclKernelRef (*SyclKernel_GetKernelRef_)(PySyclKernelObject *);
101+
PySyclKernelObject *(*SyclKernel_Make_)(DPCTLSyclKernelRef);
102+
103+
DPCTLSyclKernelBundleRef (*SyclProgram_GetKernelBundleRef_)(PySyclProgramObject *);
104+
PySyclProgramObject *(*SyclProgram_Make_)(DPCTLSyclKernelBundleRef);
105+
97106
// tensor
98107
char *(*UsmNDArray_GetData_)(PyUSMArrayObject *);
99108
int (*UsmNDArray_GetNDim_)(PyUSMArrayObject *);
@@ -131,6 +140,14 @@ struct dpctl_capi
131140
{
132141
return PyObject_TypeCheck(obj, PySyclQueueType_) != 0;
133142
}
143+
bool PySyclKernel_Check_(PyObject *obj) const
144+
{
145+
return PyObject_TypeCheck(obj, PySyclKernelType_) != 0;
146+
}
147+
bool PySyclProgram_Check_(PyObject *obj) const
148+
{
149+
return PyObject_TypeCheck(obj, PySyclProgramType_) != 0;
150+
}
134151

135152
~dpctl_capi(){};
136153

@@ -174,6 +191,8 @@ struct dpctl_capi
174191
std::shared_ptr<py::object> default_usm_memory;
175192
std::shared_ptr<py::object> default_usm_ndarray;
176193
std::shared_ptr<py::object> as_usm_memory;
194+
std::shared_ptr<py::object> default_sycl_kernel;
195+
std::shared_ptr<py::object> default_sycl_program;
177196

178197
dpctl_capi()
179198
: default_sycl_queue{}, default_usm_memory{}, default_usm_ndarray{},
@@ -201,6 +220,8 @@ struct dpctl_capi
201220
this->PyMemoryUSMSharedType_ = &PyMemoryUSMSharedType;
202221
this->PyMemoryUSMHostType_ = &PyMemoryUSMHostType;
203222
this->PyUSMArrayType_ = &PyUSMArrayType;
223+
this->PySyclProgramType_ = &PySyclProgramType;
224+
this->PySyclKernelType_ = &PySyclKernelType;
204225

205226
// SyclDevice API
206227
this->SyclDevice_GetDeviceRef_ = SyclDevice_GetDeviceRef;
@@ -225,6 +246,10 @@ struct dpctl_capi
225246
this->Memory_GetNumBytes_ = Memory_GetNumBytes;
226247
this->Memory_Make_ = Memory_Make;
227248

249+
// dpctl.program API
250+
this->SyclKernel_Make_ = SyclKernel_Make;
251+
this->SyclProgram_Make_ = SyclProgram_Make;
252+
228253
// dpctl.tensor.usm_ndarray API
229254
this->UsmNDArray_GetData_ = UsmNDArray_GetData;
230255
this->UsmNDArray_GetNDim_ = UsmNDArray_GetNDim;
@@ -506,6 +531,76 @@ template <> struct type_caster<sycl::event>
506531

507532
DPCTL_TYPE_CASTER(sycl::event, _("dpctl.SyclEvent"));
508533
};
534+
535+
/* This type caster associates ``sycl::kernel`` C++ class with
536+
* :class:`dpctl.program.SyclKernel` for the purposes of generation of
537+
* Python bindings by pybind11.
538+
*/
539+
template <> struct type_caster<sycl::kernel>
540+
{
541+
public:
542+
bool load(handle src, bool)
543+
{
544+
PyObject *source = src.ptr();
545+
auto &api = ::dpctl::detail::dpctl_capi::get();
546+
if (api.PySyclKernel_Check_(source)) {
547+
DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
548+
reinterpret_cast<PySyclKernelObject *>(source));
549+
value = std::make_unique<sycl::kernel>(
550+
*(reinterpret_cast<sycl::kernel *>(KRef)));
551+
return true;
552+
}
553+
else {
554+
throw py::type_error(
555+
"Input is of unexpected type, expected dpctl.program.SyclKernel");
556+
}
557+
}
558+
559+
static handle cast(sycl::kernel src, return_value_policy, handle)
560+
{
561+
auto &api = ::dpctl::detail::dpctl_capi::get();
562+
auto tmp =
563+
api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src));
564+
return handle(reinterpret_cast<PyObject *>(tmp));
565+
}
566+
567+
DPCTL_TYPE_CASTER(sycl::kernel, _("dpctl.program.SyclKernel"));
568+
};
569+
570+
/* This type caster associates ``sycl::kernel_bundle<sycl::bundle_state::executable>`` C++ class with
571+
* :class:`dpctl.program.SyclProgram` for the purposes of generation of
572+
* Python bindings by pybind11.
573+
*/
574+
template <> struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
575+
{
576+
public:
577+
bool load(handle src, bool)
578+
{
579+
PyObject *source = src.ptr();
580+
auto &api = ::dpctl::detail::dpctl_capi::get();
581+
if (api.PySyclProgram_Check_(source)) {
582+
DPCTLSyclKernelBundleRef KBRef = api.SyclProgram_GetKernelBundleRef_(
583+
reinterpret_cast<PySyclProgramObject *>(source));
584+
value = std::make_unique<sycl::kernel_bundle<sycl::bundle_state::executable>>(
585+
*(reinterpret_cast<sycl::kernel_bundle<sycl::bundle_state::executable> *>(KBRef)));
586+
return true;
587+
}
588+
else {
589+
throw py::type_error(
590+
"Input is of unexpected type, expected dpctl.SyclEvent");
591+
}
592+
}
593+
594+
static handle cast(sycl::kernel_bundle<sycl::bundle_state::executable> src, return_value_policy, handle)
595+
{
596+
auto &api = ::dpctl::detail::dpctl_capi::get();
597+
auto tmp =
598+
api.SyclProgram_Make_(reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
599+
return handle(reinterpret_cast<PyObject *>(tmp));
600+
}
601+
602+
DPCTL_TYPE_CASTER(sycl::kernel_bundle<sycl::bundle_state::executable>, _("dpctl.program.SyclProgram"));
603+
};
509604
} // namespace detail
510605
} // namespace pybind11
511606

dpctl/apis/include/dpctl_capi.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,10 @@
4040
#include "../memory/_memory_api.h"
4141
#include "../tensor/_usmarray.h"
4242
#include "../tensor/_usmarray_api.h"
43+
#include "../program/_program.h"
44+
#include "../program/_program_api.h"
45+
46+
4347
// clang-format on
4448

4549
/*
@@ -59,5 +63,6 @@ static inline void import_dpctl(void)
5963
import_dpctl___sycl_queue();
6064
import_dpctl__memory___memory();
6165
import_dpctl__tensor___usmarray();
66+
import_dpctl__program___program();
6267
return;
6368
}

dpctl/program/_program.pxd

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ from .._sycl_device cimport SyclDevice
2828
from .._sycl_queue cimport SyclQueue
2929

3030

31-
cdef class SyclKernel:
31+
cdef api class SyclKernel [object PySyclKernelObject, type PySyclKernelType]:
3232
''' Wraps a sycl::kernel object created from an OpenCL interoperability
3333
kernel.
3434
'''
@@ -40,7 +40,7 @@ cdef class SyclKernel:
4040
cdef SyclKernel _create (DPCTLSyclKernelRef kref, str name)
4141

4242

43-
cdef class SyclProgram:
43+
cdef api class SyclProgram [object PySyclProgramObject, type PySyclProgramType]:
4444
''' Wraps a sycl::kernel_bundle<sycl::bundle_state::executable> object created by
4545
using SYCL interoperability layer for OpenCL and Level-Zero backends.
4646
@@ -59,3 +59,4 @@ cdef class SyclProgram:
5959
cpdef create_program_from_source (SyclQueue q, unicode source, unicode copts=*)
6060
cpdef create_program_from_spirv (SyclQueue q, const unsigned char[:] IL,
6161
unicode copts=*)
62+

dpctl/program/_program.pyx

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ from libc.stdint cimport uint32_t
3131
from dpctl._backend cimport ( # noqa: E211, E402;
3232
DPCTLCString_Delete,
3333
DPCTLKernel_Delete,
34+
DPCTLKernel_Copy,
3435
DPCTLKernel_GetCompileNumSubGroups,
3536
DPCTLKernel_GetCompileSubGroupSize,
3637
DPCTLKernel_GetMaxNumSubGroups,
@@ -41,6 +42,7 @@ from dpctl._backend cimport ( # noqa: E211, E402;
4142
DPCTLKernelBundle_CreateFromOCLSource,
4243
DPCTLKernelBundle_CreateFromSpirv,
4344
DPCTLKernelBundle_Delete,
45+
DPCTLKernelBundle_Copy,
4446
DPCTLKernelBundle_GetKernel,
4547
DPCTLKernelBundle_HasKernel,
4648
DPCTLSyclContextRef,
@@ -164,6 +166,19 @@ cdef class SyclKernel:
164166
cdef size_t n = DPCTLKernel_GetCompileSubGroupSize(self._kernel_ref)
165167
return n
166168

169+
cdef api DPCTLSyclKernelRef SyclKernel_GetKernelRef(SyclKernel ker):
170+
""" C-API function to access opaque kernel reference from
171+
Python object of type :class:`dpctl.program.SyclKernel`.
172+
"""
173+
return ker.get_kernel_ref()
174+
175+
cdef api SyclKernel SyclKernel_Make(DPCTLSyclKernelRef KRef):
176+
"""
177+
C-API function to create :class:`dpctl.program.SyclKernel`
178+
instance from opaque sycl kernel reference.
179+
"""
180+
cdef DPCTLSyclKernelRef copied_KRef = DPCTLKernel_Copy(KRef)
181+
return SyclKernel._create(copied_KRef, "foo")
167182

168183
cdef class SyclProgram:
169184
""" Wraps a ``sycl::kernel_bundle<sycl::bundle_state::executable>`` object
@@ -290,3 +305,17 @@ cpdef create_program_from_spirv(SyclQueue q, const unsigned char[:] IL,
290305
raise SyclProgramCompilationError()
291306

292307
return SyclProgram._create(KBref)
308+
309+
cdef api DPCTLSyclKernelBundleRef SyclProgram_GetKernelBundleRef(SyclProgram pro):
310+
""" C-API function to access opaque kernel bundle reference from
311+
Python object of type :class:`dpctl.program.SyclKernel`.
312+
"""
313+
return pro.get_program_ref()
314+
315+
cdef api SyclProgram SyclProgram_Make(DPCTLSyclKernelBundleRef KBRef):
316+
"""
317+
C-API function to create :class:`dpctl.program.SyclProgram`
318+
instance from opaque sycl kernel bundle reference.
319+
"""
320+
cdef DPCTLSyclKernelBundleRef copied_KBRef = DPCTLKernelBundle_Copy(KBRef)
321+
return SyclProgram._create(copied_KBRef)

libsyclinterface/include/dpctl_sycl_kernel_bundle_interface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,16 @@ bool DPCTLKernelBundle_HasKernel(__dpctl_keep DPCTLSyclKernelBundleRef KBRef,
117117
DPCTL_API
118118
void DPCTLKernelBundle_Delete(__dpctl_take DPCTLSyclKernelBundleRef KBRef);
119119

120+
/*!
121+
* @brief Returns a copy of the DPCTLSyclKernelBundleRef object.
122+
*
123+
* @param KBRef DPCTLSyclKernelBundleRef object to be copied.
124+
* @return A new DPCTLSyclKernelBundleRef created by copying the passed in
125+
* DPCTLSyclKernelBundleRef object.
126+
* @ingroup KernelBundleInterface
127+
*/
128+
DPCTL_API
129+
__dpctl_give DPCTLSyclKernelBundleRef
130+
DPCTLKernelBundle_Copy(__dpctl_keep const DPCTLSyclKernelBundleRef KBRef);
131+
120132
DPCTL_C_EXTERN_C_END

libsyclinterface/include/dpctl_sycl_kernel_interface.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,18 @@ size_t DPCTLKernel_GetNumArgs(__dpctl_keep const DPCTLSyclKernelRef KRef);
6262
DPCTL_API
6363
void DPCTLKernel_Delete(__dpctl_take DPCTLSyclKernelRef KRef);
6464

65+
/*!
66+
* @brief Returns a copy of the DPCTLSyclKernelRef object.
67+
*
68+
* @param KRef DPCTLSyclKernelRef object to be copied.
69+
* @return A new DPCTLSyclKernelRef created by copying the passed in
70+
* DPCTLSyclKernelRef object.
71+
* @ingroup KernelInterface
72+
*/
73+
DPCTL_API
74+
__dpctl_give DPCTLSyclKernelRef
75+
DPCTLKernel_Copy(__dpctl_keep const DPCTLSyclKernelRef KRef);
76+
6577
/*!
6678
* !brief Wrapper around
6779
* `kernel::get_info<info::kernel_device_specific::work_group_size>()`.

libsyclinterface/source/dpctl_sycl_kernel_bundle_interface.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,3 +740,21 @@ void DPCTLKernelBundle_Delete(__dpctl_take DPCTLSyclKernelBundleRef KBRef)
740740
{
741741
delete unwrap(KBRef);
742742
}
743+
744+
__dpctl_give DPCTLSyclKernelBundleRef
745+
DPCTLKernelBundle_Copy(__dpctl_keep const DPCTLSyclKernelBundleRef KBRef)
746+
{
747+
auto Bundle = unwrap(KBRef);
748+
if (!Bundle) {
749+
error_handler("Cannot copy DPCTLSyclKernelBundleRef as input is a nullptr",
750+
__FILE__, __func__, __LINE__);
751+
return nullptr;
752+
}
753+
try {
754+
auto CopiedBundle = new kernel_bundle<bundle_state::executable>(*Bundle);
755+
return wrap(CopiedBundle);
756+
} catch (std::exception const &e) {
757+
error_handler(e, __FILE__, __func__, __LINE__);
758+
return nullptr;
759+
}
760+
}

libsyclinterface/source/dpctl_sycl_kernel_interface.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,24 @@ void DPCTLKernel_Delete(__dpctl_take DPCTLSyclKernelRef KRef)
5959
delete unwrap(KRef);
6060
}
6161

62+
__dpctl_give DPCTLSyclKernelRef
63+
DPCTLKernel_Copy(__dpctl_keep const DPCTLSyclKernelRef KRef)
64+
{
65+
auto Kernel = unwrap(KRef);
66+
if (!Kernel) {
67+
error_handler("Cannot copy DPCTLSyclKernelRef as input is a nullptr",
68+
__FILE__, __func__, __LINE__);
69+
return nullptr;
70+
}
71+
try {
72+
auto CopiedKernel = new kernel(*Kernel);
73+
return wrap(CopiedKernel);
74+
} catch (std::exception const &e) {
75+
error_handler(e, __FILE__, __func__, __LINE__);
76+
return nullptr;
77+
}
78+
}
79+
6280
size_t DPCTLKernel_GetWorkGroupSize(__dpctl_keep const DPCTLSyclKernelRef KRef)
6381
{
6482
if (!KRef) {

0 commit comments

Comments
 (0)