Skip to content

Commit 375fa77

Browse files
Added missing initialization of function pointers for dpctl.program CAPI functions
dpctl_capi constructor must initialize DPCTLKernel_GetKernelRef_ and DPCTLKernelBundle_GetKernelBundleRef_ for casters to work correctly.
1 parent 753548b commit 375fa77

File tree

1 file changed

+30
-28
lines changed

1 file changed

+30
-28
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 30 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,9 @@ struct dpctl_capi
248248
this->Memory_Make_ = Memory_Make;
249249

250250
// dpctl.program API
251+
this->SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef;
251252
this->SyclKernel_Make_ = SyclKernel_Make;
253+
this->SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef;
252254
this->SyclProgram_Make_ = SyclProgram_Make;
253255

254256
// dpctl.tensor.usm_ndarray API
@@ -403,7 +405,7 @@ template <> struct type_caster<sycl::queue>
403405
bool load(handle src, bool)
404406
{
405407
PyObject *source = src.ptr();
406-
auto &api = ::dpctl::detail::dpctl_capi::get();
408+
auto const &api = ::dpctl::detail::dpctl_capi::get();
407409
if (api.PySyclQueue_Check_(source)) {
408410
DPCTLSyclQueueRef QRef = api.SyclQueue_GetQueueRef_(
409411
reinterpret_cast<PySyclQueueObject *>(source));
@@ -419,7 +421,7 @@ template <> struct type_caster<sycl::queue>
419421

420422
static handle cast(sycl::queue src, return_value_policy, handle)
421423
{
422-
auto &api = ::dpctl::detail::dpctl_capi::get();
424+
auto const &api = ::dpctl::detail::dpctl_capi::get();
423425
auto tmp =
424426
api.SyclQueue_Make_(reinterpret_cast<DPCTLSyclQueueRef>(&src));
425427
return handle(reinterpret_cast<PyObject *>(tmp));
@@ -438,7 +440,7 @@ template <> struct type_caster<sycl::device>
438440
bool load(handle src, bool)
439441
{
440442
PyObject *source = src.ptr();
441-
auto &api = ::dpctl::detail::dpctl_capi::get();
443+
auto const &api = ::dpctl::detail::dpctl_capi::get();
442444
if (api.PySyclDevice_Check_(source)) {
443445
DPCTLSyclDeviceRef DRef = api.SyclDevice_GetDeviceRef_(
444446
reinterpret_cast<PySyclDeviceObject *>(source));
@@ -454,7 +456,7 @@ template <> struct type_caster<sycl::device>
454456

455457
static handle cast(sycl::device src, return_value_policy, handle)
456458
{
457-
auto &api = ::dpctl::detail::dpctl_capi::get();
459+
auto const &api = ::dpctl::detail::dpctl_capi::get();
458460
auto tmp =
459461
api.SyclDevice_Make_(reinterpret_cast<DPCTLSyclDeviceRef>(&src));
460462
return handle(reinterpret_cast<PyObject *>(tmp));
@@ -473,7 +475,7 @@ template <> struct type_caster<sycl::context>
473475
bool load(handle src, bool)
474476
{
475477
PyObject *source = src.ptr();
476-
auto &api = ::dpctl::detail::dpctl_capi::get();
478+
auto const &api = ::dpctl::detail::dpctl_capi::get();
477479
if (api.PySyclContext_Check_(source)) {
478480
DPCTLSyclContextRef CRef = api.SyclContext_GetContextRef_(
479481
reinterpret_cast<PySyclContextObject *>(source));
@@ -489,7 +491,7 @@ template <> struct type_caster<sycl::context>
489491

490492
static handle cast(sycl::context src, return_value_policy, handle)
491493
{
492-
auto &api = ::dpctl::detail::dpctl_capi::get();
494+
auto const &api = ::dpctl::detail::dpctl_capi::get();
493495
auto tmp =
494496
api.SyclContext_Make_(reinterpret_cast<DPCTLSyclContextRef>(&src));
495497
return handle(reinterpret_cast<PyObject *>(tmp));
@@ -508,7 +510,7 @@ template <> struct type_caster<sycl::event>
508510
bool load(handle src, bool)
509511
{
510512
PyObject *source = src.ptr();
511-
auto &api = ::dpctl::detail::dpctl_capi::get();
513+
auto const &api = ::dpctl::detail::dpctl_capi::get();
512514
if (api.PySyclEvent_Check_(source)) {
513515
DPCTLSyclEventRef ERef = api.SyclEvent_GetEventRef_(
514516
reinterpret_cast<PySyclEventObject *>(source));
@@ -524,7 +526,7 @@ template <> struct type_caster<sycl::event>
524526

525527
static handle cast(sycl::event src, return_value_policy, handle)
526528
{
527-
auto &api = ::dpctl::detail::dpctl_capi::get();
529+
auto const &api = ::dpctl::detail::dpctl_capi::get();
528530
auto tmp =
529531
api.SyclEvent_Make_(reinterpret_cast<DPCTLSyclEventRef>(&src));
530532
return handle(reinterpret_cast<PyObject *>(tmp));
@@ -543,7 +545,7 @@ template <> struct type_caster<sycl::kernel>
543545
bool load(handle src, bool)
544546
{
545547
PyObject *source = src.ptr();
546-
auto &api = ::dpctl::detail::dpctl_capi::get();
548+
auto const &api = ::dpctl::detail::dpctl_capi::get();
547549
if (api.PySyclKernel_Check_(source)) {
548550
DPCTLSyclKernelRef KRef = api.SyclKernel_GetKernelRef_(
549551
reinterpret_cast<PySyclKernelObject *>(source));
@@ -559,7 +561,7 @@ template <> struct type_caster<sycl::kernel>
559561

560562
static handle cast(sycl::kernel src, return_value_policy, handle)
561563
{
562-
auto &api = ::dpctl::detail::dpctl_capi::get();
564+
auto const &api = ::dpctl::detail::dpctl_capi::get();
563565
auto tmp =
564566
api.SyclKernel_Make_(reinterpret_cast<DPCTLSyclKernelRef>(&src),
565567
"dpctl4pybind11_kernel");
@@ -581,7 +583,7 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
581583
bool load(handle src, bool)
582584
{
583585
PyObject *source = src.ptr();
584-
auto &api = ::dpctl::detail::dpctl_capi::get();
586+
auto const &api = ::dpctl::detail::dpctl_capi::get();
585587
if (api.PySyclProgram_Check_(source)) {
586588
DPCTLSyclKernelBundleRef KBRef =
587589
api.SyclProgram_GetKernelBundleRef_(
@@ -603,7 +605,7 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
603605
return_value_policy,
604606
handle)
605607
{
606-
auto &api = ::dpctl::detail::dpctl_capi::get();
608+
auto const &api = ::dpctl::detail::dpctl_capi::get();
607609
auto tmp = api.SyclProgram_Make_(
608610
reinterpret_cast<DPCTLSyclKernelBundleRef>(&src));
609611
return handle(reinterpret_cast<PyObject *>(tmp));
@@ -650,7 +652,7 @@ class usm_memory : public py::object
650652
sycl::queue get_queue() const
651653
{
652654
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
653-
auto &api = ::dpctl::detail::dpctl_capi::get();
655+
auto const &api = ::dpctl::detail::dpctl_capi::get();
654656
DPCTLSyclQueueRef QRef = api.Memory_GetQueueRef_(mem_obj);
655657
sycl::queue *obj_q = reinterpret_cast<sycl::queue *>(QRef);
656658
return *obj_q;
@@ -659,14 +661,14 @@ class usm_memory : public py::object
659661
char *get_pointer() const
660662
{
661663
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
662-
auto &api = ::dpctl::detail::dpctl_capi::get();
664+
auto const &api = ::dpctl::detail::dpctl_capi::get();
663665
DPCTLSyclUSMRef MRef = api.Memory_GetUsmPointer_(mem_obj);
664666
return reinterpret_cast<char *>(MRef);
665667
}
666668

667669
size_t get_nbytes() const
668670
{
669-
auto &api = ::dpctl::detail::dpctl_capi::get();
671+
auto const &api = ::dpctl::detail::dpctl_capi::get();
670672
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
671673
return api.Memory_GetNumBytes_(mem_obj);
672674
}
@@ -769,7 +771,7 @@ class usm_ndarray : public py::object
769771
{
770772
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
771773

772-
auto &api = ::dpctl::detail::dpctl_capi::get();
774+
auto const &api = ::dpctl::detail::dpctl_capi::get();
773775
return api.UsmNDArray_GetData_(raw_ar);
774776
}
775777

@@ -782,15 +784,15 @@ class usm_ndarray : public py::object
782784
{
783785
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
784786

785-
auto &api = ::dpctl::detail::dpctl_capi::get();
787+
auto const &api = ::dpctl::detail::dpctl_capi::get();
786788
return api.UsmNDArray_GetNDim_(raw_ar);
787789
}
788790

789791
const py::ssize_t *get_shape_raw() const
790792
{
791793
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
792794

793-
auto &api = ::dpctl::detail::dpctl_capi::get();
795+
auto const &api = ::dpctl::detail::dpctl_capi::get();
794796
return api.UsmNDArray_GetShape_(raw_ar);
795797
}
796798

@@ -804,15 +806,15 @@ class usm_ndarray : public py::object
804806
{
805807
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
806808

807-
auto &api = ::dpctl::detail::dpctl_capi::get();
809+
auto const &api = ::dpctl::detail::dpctl_capi::get();
808810
return api.UsmNDArray_GetStrides_(raw_ar);
809811
}
810812

811813
py::ssize_t get_size() const
812814
{
813815
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
814816

815-
auto &api = ::dpctl::detail::dpctl_capi::get();
817+
auto const &api = ::dpctl::detail::dpctl_capi::get();
816818
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
817819
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
818820

@@ -829,7 +831,7 @@ class usm_ndarray : public py::object
829831
{
830832
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
831833

832-
auto &api = ::dpctl::detail::dpctl_capi::get();
834+
auto const &api = ::dpctl::detail::dpctl_capi::get();
833835
int nd = api.UsmNDArray_GetNDim_(raw_ar);
834836
const py::ssize_t *shape = api.UsmNDArray_GetShape_(raw_ar);
835837
const py::ssize_t *strides = api.UsmNDArray_GetStrides_(raw_ar);
@@ -863,7 +865,7 @@ class usm_ndarray : public py::object
863865
{
864866
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
865867

866-
auto &api = ::dpctl::detail::dpctl_capi::get();
868+
auto const &api = ::dpctl::detail::dpctl_capi::get();
867869
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
868870
return *(reinterpret_cast<sycl::queue *>(QRef));
869871
}
@@ -872,44 +874,44 @@ class usm_ndarray : public py::object
872874
{
873875
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
874876

875-
auto &api = ::dpctl::detail::dpctl_capi::get();
877+
auto const &api = ::dpctl::detail::dpctl_capi::get();
876878
return api.UsmNDArray_GetTypenum_(raw_ar);
877879
}
878880

879881
int get_flags() const
880882
{
881883
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
882884

883-
auto &api = ::dpctl::detail::dpctl_capi::get();
885+
auto const &api = ::dpctl::detail::dpctl_capi::get();
884886
return api.UsmNDArray_GetFlags_(raw_ar);
885887
}
886888

887889
int get_elemsize() const
888890
{
889891
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
890892

891-
auto &api = ::dpctl::detail::dpctl_capi::get();
893+
auto const &api = ::dpctl::detail::dpctl_capi::get();
892894
return api.UsmNDArray_GetElementSize_(raw_ar);
893895
}
894896

895897
bool is_c_contiguous() const
896898
{
897899
int flags = this->get_flags();
898-
auto &api = ::dpctl::detail::dpctl_capi::get();
900+
auto const &api = ::dpctl::detail::dpctl_capi::get();
899901
return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
900902
}
901903

902904
bool is_f_contiguous() const
903905
{
904906
int flags = this->get_flags();
905-
auto &api = ::dpctl::detail::dpctl_capi::get();
907+
auto const &api = ::dpctl::detail::dpctl_capi::get();
906908
return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
907909
}
908910

909911
bool is_writable() const
910912
{
911913
int flags = this->get_flags();
912-
auto &api = ::dpctl::detail::dpctl_capi::get();
914+
auto const &api = ::dpctl::detail::dpctl_capi::get();
913915
return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
914916
}
915917

0 commit comments

Comments
 (0)