@@ -248,7 +248,9 @@ struct dpctl_capi
248
248
this -> Memory_Make_ = Memory_Make ;
249
249
250
250
// dpctl.program API
251
+ this -> SyclKernel_GetKernelRef_ = SyclKernel_GetKernelRef ;
251
252
this -> SyclKernel_Make_ = SyclKernel_Make ;
253
+ this -> SyclProgram_GetKernelBundleRef_ = SyclProgram_GetKernelBundleRef ;
252
254
this -> SyclProgram_Make_ = SyclProgram_Make ;
253
255
254
256
// dpctl.tensor.usm_ndarray API
@@ -403,7 +405,7 @@ template <> struct type_caster<sycl::queue>
403
405
bool load (handle src , bool )
404
406
{
405
407
PyObject * source = src .ptr ();
406
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
408
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
407
409
if (api .PySyclQueue_Check_ (source )) {
408
410
DPCTLSyclQueueRef QRef = api .SyclQueue_GetQueueRef_ (
409
411
reinterpret_cast < PySyclQueueObject * > (source ));
@@ -419,7 +421,7 @@ template <> struct type_caster<sycl::queue>
419
421
420
422
static handle cast (sycl ::queue src , return_value_policy , handle )
421
423
{
422
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
424
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
423
425
auto tmp =
424
426
api .SyclQueue_Make_ (reinterpret_cast < DPCTLSyclQueueRef > (& src ));
425
427
return handle (reinterpret_cast < PyObject * > (tmp ));
@@ -438,7 +440,7 @@ template <> struct type_caster<sycl::device>
438
440
bool load (handle src , bool )
439
441
{
440
442
PyObject * source = src .ptr ();
441
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
443
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
442
444
if (api .PySyclDevice_Check_ (source )) {
443
445
DPCTLSyclDeviceRef DRef = api .SyclDevice_GetDeviceRef_ (
444
446
reinterpret_cast < PySyclDeviceObject * > (source ));
@@ -454,7 +456,7 @@ template <> struct type_caster<sycl::device>
454
456
455
457
static handle cast (sycl ::device src , return_value_policy , handle )
456
458
{
457
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
459
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
458
460
auto tmp =
459
461
api .SyclDevice_Make_ (reinterpret_cast < DPCTLSyclDeviceRef > (& src ));
460
462
return handle (reinterpret_cast < PyObject * > (tmp ));
@@ -473,7 +475,7 @@ template <> struct type_caster<sycl::context>
473
475
bool load (handle src , bool )
474
476
{
475
477
PyObject * source = src .ptr ();
476
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
478
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
477
479
if (api .PySyclContext_Check_ (source )) {
478
480
DPCTLSyclContextRef CRef = api .SyclContext_GetContextRef_ (
479
481
reinterpret_cast < PySyclContextObject * > (source ));
@@ -489,7 +491,7 @@ template <> struct type_caster<sycl::context>
489
491
490
492
static handle cast (sycl ::context src , return_value_policy , handle )
491
493
{
492
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
494
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
493
495
auto tmp =
494
496
api .SyclContext_Make_ (reinterpret_cast < DPCTLSyclContextRef > (& src ));
495
497
return handle (reinterpret_cast < PyObject * > (tmp ));
@@ -508,7 +510,7 @@ template <> struct type_caster<sycl::event>
508
510
bool load (handle src , bool )
509
511
{
510
512
PyObject * source = src .ptr ();
511
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
513
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
512
514
if (api .PySyclEvent_Check_ (source )) {
513
515
DPCTLSyclEventRef ERef = api .SyclEvent_GetEventRef_ (
514
516
reinterpret_cast < PySyclEventObject * > (source ));
@@ -524,7 +526,7 @@ template <> struct type_caster<sycl::event>
524
526
525
527
static handle cast (sycl ::event src , return_value_policy , handle )
526
528
{
527
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
529
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
528
530
auto tmp =
529
531
api .SyclEvent_Make_ (reinterpret_cast < DPCTLSyclEventRef > (& src ));
530
532
return handle (reinterpret_cast < PyObject * > (tmp ));
@@ -543,7 +545,7 @@ template <> struct type_caster<sycl::kernel>
543
545
bool load (handle src , bool )
544
546
{
545
547
PyObject * source = src .ptr ();
546
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
548
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
547
549
if (api .PySyclKernel_Check_ (source )) {
548
550
DPCTLSyclKernelRef KRef = api .SyclKernel_GetKernelRef_ (
549
551
reinterpret_cast < PySyclKernelObject * > (source ));
@@ -559,7 +561,7 @@ template <> struct type_caster<sycl::kernel>
559
561
560
562
static handle cast (sycl ::kernel src , return_value_policy , handle )
561
563
{
562
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
564
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
563
565
auto tmp =
564
566
api .SyclKernel_Make_ (reinterpret_cast < DPCTLSyclKernelRef > (& src ),
565
567
"dpctl4pybind11_kernel" );
@@ -581,7 +583,7 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
581
583
bool load (handle src , bool )
582
584
{
583
585
PyObject * source = src .ptr ();
584
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
586
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
585
587
if (api .PySyclProgram_Check_ (source )) {
586
588
DPCTLSyclKernelBundleRef KBRef =
587
589
api .SyclProgram_GetKernelBundleRef_ (
@@ -603,7 +605,7 @@ struct type_caster<sycl::kernel_bundle<sycl::bundle_state::executable>>
603
605
return_value_policy ,
604
606
handle )
605
607
{
606
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
608
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
607
609
auto tmp = api .SyclProgram_Make_ (
608
610
reinterpret_cast < DPCTLSyclKernelBundleRef > (& src ));
609
611
return handle (reinterpret_cast < PyObject * > (tmp ));
@@ -650,7 +652,7 @@ class usm_memory : public py::object
650
652
sycl ::queue get_queue () const
651
653
{
652
654
Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
653
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
655
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
654
656
DPCTLSyclQueueRef QRef = api .Memory_GetQueueRef_ (mem_obj );
655
657
sycl ::queue * obj_q = reinterpret_cast < sycl ::queue * > (QRef );
656
658
return * obj_q ;
@@ -659,14 +661,14 @@ class usm_memory : public py::object
659
661
char * get_pointer () const
660
662
{
661
663
Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
662
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
664
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
663
665
DPCTLSyclUSMRef MRef = api .Memory_GetUsmPointer_ (mem_obj );
664
666
return reinterpret_cast < char * > (MRef );
665
667
}
666
668
667
669
size_t get_nbytes () const
668
670
{
669
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
671
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
670
672
Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
671
673
return api .Memory_GetNumBytes_ (mem_obj );
672
674
}
@@ -769,7 +771,7 @@ class usm_ndarray : public py::object
769
771
{
770
772
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
771
773
772
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
774
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
773
775
return api .UsmNDArray_GetData_ (raw_ar );
774
776
}
775
777
@@ -782,15 +784,15 @@ class usm_ndarray : public py::object
782
784
{
783
785
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
784
786
785
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
787
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
786
788
return api .UsmNDArray_GetNDim_ (raw_ar );
787
789
}
788
790
789
791
const py ::ssize_t * get_shape_raw () const
790
792
{
791
793
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
792
794
793
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
795
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
794
796
return api .UsmNDArray_GetShape_ (raw_ar );
795
797
}
796
798
@@ -804,15 +806,15 @@ class usm_ndarray : public py::object
804
806
{
805
807
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
806
808
807
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
809
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
808
810
return api .UsmNDArray_GetStrides_ (raw_ar );
809
811
}
810
812
811
813
py ::ssize_t get_size () const
812
814
{
813
815
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
814
816
815
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
817
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
816
818
int ndim = api .UsmNDArray_GetNDim_ (raw_ar );
817
819
const py ::ssize_t * shape = api .UsmNDArray_GetShape_ (raw_ar );
818
820
@@ -829,7 +831,7 @@ class usm_ndarray : public py::object
829
831
{
830
832
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
831
833
832
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
834
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
833
835
int nd = api .UsmNDArray_GetNDim_ (raw_ar );
834
836
const py ::ssize_t * shape = api .UsmNDArray_GetShape_ (raw_ar );
835
837
const py ::ssize_t * strides = api .UsmNDArray_GetStrides_ (raw_ar );
@@ -863,7 +865,7 @@ class usm_ndarray : public py::object
863
865
{
864
866
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
865
867
866
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
868
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
867
869
DPCTLSyclQueueRef QRef = api .UsmNDArray_GetQueueRef_ (raw_ar );
868
870
return * (reinterpret_cast < sycl ::queue * > (QRef ));
869
871
}
@@ -872,44 +874,44 @@ class usm_ndarray : public py::object
872
874
{
873
875
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
874
876
875
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
877
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
876
878
return api .UsmNDArray_GetTypenum_ (raw_ar );
877
879
}
878
880
879
881
int get_flags () const
880
882
{
881
883
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
882
884
883
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
885
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
884
886
return api .UsmNDArray_GetFlags_ (raw_ar );
885
887
}
886
888
887
889
int get_elemsize () const
888
890
{
889
891
PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
890
892
891
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
893
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
892
894
return api .UsmNDArray_GetElementSize_ (raw_ar );
893
895
}
894
896
895
897
bool is_c_contiguous () const
896
898
{
897
899
int flags = this -> get_flags ();
898
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
900
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
899
901
return static_cast < bool > (flags & api .USM_ARRAY_C_CONTIGUOUS_ );
900
902
}
901
903
902
904
bool is_f_contiguous () const
903
905
{
904
906
int flags = this -> get_flags ();
905
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
907
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
906
908
return static_cast < bool > (flags & api .USM_ARRAY_F_CONTIGUOUS_ );
907
909
}
908
910
909
911
bool is_writable () const
910
912
{
911
913
int flags = this -> get_flags ();
912
- auto & api = ::dpctl ::detail ::dpctl_capi ::get ();
914
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
913
915
return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
914
916
}
915
917
0 commit comments