Skip to content

Commit 138a023

Browse files
committed
Merge branch 'master' into take-put-impl
2 parents d47fbf0 + 6c629a8 commit 138a023

File tree

6 files changed

+65
-26
lines changed

6 files changed

+65
-26
lines changed

.github/workflows/generate-docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
if: ${{ !github.event.pull_request || github.event.action != 'closed' }}
5050
shell: bash -l {0}
5151
run: |
52-
pip install numpy cython setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput
52+
pip install numpy cython setuptools scikit-build cmake sphinx sphinx_rtd_theme pydot graphviz sphinxcontrib-programoutput sphinxcontrib-googleanalytics
5353
- name: Checkout repo
5454
uses: actions/checkout@v3
5555
with:

docs/conf.in

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,12 @@ extensions = [
6161
"sphinx.ext.todo",
6262
"sphinx.ext.viewcode",
6363
"sphinxcontrib.programoutput",
64+
"sphinxcontrib.googleanalytics",
6465
]
6566

67+
googleanalytics_id = 'G-7TCKS5BHYE'
68+
googleanalytics_enabled = True
69+
6670
todo_include_todos = True
6771
use_doxyrest = "@DPCTL_ENABLE_DOXYREST@"
6872

docs/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ cython
33
setuptools
44
sphinx
55
sphinx_rtd_theme
6+
sphinxcontrib-googleanalytics
67
pydot
78
graphviz

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829829

830830
char *get_data() const
831831
{
832-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
832+
PyUSMArrayObject *raw_ar = usm_array_ptr();
833833

834834
auto const &api = ::dpctl::detail::dpctl_capi::get();
835835
return api.UsmNDArray_GetData_(raw_ar);
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842842

843843
int get_ndim() const
844844
{
845-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
845+
PyUSMArrayObject *raw_ar = usm_array_ptr();
846846

847847
auto const &api = ::dpctl::detail::dpctl_capi::get();
848848
return api.UsmNDArray_GetNDim_(raw_ar);
849849
}
850850

851851
const py::ssize_t *get_shape_raw() const
852852
{
853-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
853+
PyUSMArrayObject *raw_ar = usm_array_ptr();
854854

855855
auto const &api = ::dpctl::detail::dpctl_capi::get();
856856
return api.UsmNDArray_GetShape_(raw_ar);
857857
}
858858

859+
std::vector<py::ssize_t> get_shape_vector() const
860+
{
861+
auto raw_sh = get_shape_raw();
862+
auto nd = get_ndim();
863+
864+
std::vector<py::ssize_t> shape_vector(raw_sh, raw_sh + nd);
865+
return shape_vector;
866+
}
867+
859868
py::ssize_t get_shape(int i) const
860869
{
861870
auto shape_ptr = get_shape_raw();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864873

865874
const py::ssize_t *get_strides_raw() const
866875
{
867-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
876+
PyUSMArrayObject *raw_ar = usm_array_ptr();
868877

869878
auto const &api = ::dpctl::detail::dpctl_capi::get();
870879
return api.UsmNDArray_GetStrides_(raw_ar);
871880
}
872881

882+
std::vector<py::ssize_t> get_strides_vector() const
883+
{
884+
auto raw_st = get_strides_raw();
885+
auto nd = get_ndim();
886+
887+
if (raw_st == nullptr) {
888+
auto is_c_contig = is_c_contiguous();
889+
auto is_f_contig = is_f_contiguous();
890+
auto raw_sh = get_shape_raw();
891+
if (is_c_contig) {
892+
const auto &contig_strides = c_contiguous_strides(nd, raw_sh);
893+
return contig_strides;
894+
}
895+
else if (is_f_contig) {
896+
const auto &contig_strides = f_contiguous_strides(nd, raw_sh);
897+
return contig_strides;
898+
}
899+
else {
900+
throw std::runtime_error("Invalid array encountered when "
901+
"building strides");
902+
}
903+
}
904+
else {
905+
std::vector<py::ssize_t> st_vec(raw_st, raw_st + nd);
906+
return st_vec;
907+
}
908+
}
909+
873910
py::ssize_t get_size() const
874911
{
875-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
912+
PyUSMArrayObject *raw_ar = usm_array_ptr();
876913

877914
auto const &api = ::dpctl::detail::dpctl_capi::get();
878915
int ndim = api.UsmNDArray_GetNDim_(raw_ar);
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889926

890927
std::pair<py::ssize_t, py::ssize_t> get_minmax_offsets() const
891928
{
892-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
929+
PyUSMArrayObject *raw_ar = usm_array_ptr();
893930

894931
auto const &api = ::dpctl::detail::dpctl_capi::get();
895932
int nd = api.UsmNDArray_GetNDim_(raw_ar);
@@ -906,8 +943,6 @@ class usm_ndarray : public py::object
906943
}
907944
}
908945
else {
909-
offset_min = api.UsmNDArray_GetOffset_(raw_ar);
910-
offset_max = offset_min;
911946
for (int i = 0; i < nd; ++i) {
912947
py::ssize_t delta = strides[i] * (shape[i] - 1);
913948
if (strides[i] > 0) {
@@ -923,7 +958,7 @@ class usm_ndarray : public py::object
923958

924959
sycl::queue get_queue() const
925960
{
926-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
961+
PyUSMArrayObject *raw_ar = usm_array_ptr();
927962

928963
auto const &api = ::dpctl::detail::dpctl_capi::get();
929964
DPCTLSyclQueueRef QRef = api.UsmNDArray_GetQueueRef_(raw_ar);
@@ -932,45 +967,45 @@ class usm_ndarray : public py::object
932967

933968
int get_typenum() const
934969
{
935-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
970+
PyUSMArrayObject *raw_ar = usm_array_ptr();
936971

937972
auto const &api = ::dpctl::detail::dpctl_capi::get();
938973
return api.UsmNDArray_GetTypenum_(raw_ar);
939974
}
940975

941976
int get_flags() const
942977
{
943-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
978+
PyUSMArrayObject *raw_ar = usm_array_ptr();
944979

945980
auto const &api = ::dpctl::detail::dpctl_capi::get();
946981
return api.UsmNDArray_GetFlags_(raw_ar);
947982
}
948983

949984
int get_elemsize() const
950985
{
951-
PyUSMArrayObject *raw_ar = this->usm_array_ptr();
986+
PyUSMArrayObject *raw_ar = usm_array_ptr();
952987

953988
auto const &api = ::dpctl::detail::dpctl_capi::get();
954989
return api.UsmNDArray_GetElementSize_(raw_ar);
955990
}
956991

957992
bool is_c_contiguous() const
958993
{
959-
int flags = this->get_flags();
994+
int flags = get_flags();
960995
auto const &api = ::dpctl::detail::dpctl_capi::get();
961996
return static_cast<bool>(flags & api.USM_ARRAY_C_CONTIGUOUS_);
962997
}
963998

964999
bool is_f_contiguous() const
9651000
{
966-
int flags = this->get_flags();
1001+
int flags = get_flags();
9671002
auto const &api = ::dpctl::detail::dpctl_capi::get();
9681003
return static_cast<bool>(flags & api.USM_ARRAY_F_CONTIGUOUS_);
9691004
}
9701005

9711006
bool is_writable() const
9721007
{
973-
int flags = this->get_flags();
1008+
int flags = get_flags();
9741009
auto const &api = ::dpctl::detail::dpctl_capi::get();
9751010
return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
9761011
}

dpctl/apis/include/dpctl_sycl_interface.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,4 +40,6 @@
4040
#include "syclinterface/dpctl_sycl_device_manager.h"
4141
#include "syclinterface/dpctl_sycl_platform_manager.h"
4242
#include "syclinterface/dpctl_sycl_queue_manager.h"
43+
#include "syclinterface/dpctl_sycl_kernel_bundle_interface.h"
44+
#include "syclinterface/dpctl_sycl_kernel_interface.h"
4345
// clang-format on

libsyclinterface/source/dpctl_sycl_queue_interface.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ static_assert(__SYCL_COMPILER_VERSION >= __SYCL_COMPILER_VERSION_REQUIRED,
4444

4545
using namespace dpctl::syclinterface;
4646

47+
typedef struct complex
48+
{
49+
uint64_t real;
50+
uint64_t imag;
51+
} complexNumber;
52+
4753
/*!
4854
* @brief Set the kernel arg object
4955
*
@@ -735,15 +741,6 @@ DPCTLQueue_Fill64(__dpctl_keep const DPCTLSyclQueueRef QRef,
735741
}
736742
}
737743

738-
namespace
739-
{
740-
typedef struct complex
741-
{
742-
uint64_t real;
743-
uint64_t imag;
744-
} coplexNumber;
745-
} // namespace
746-
747744
__dpctl_give DPCTLSyclEventRef
748745
DPCTLQueue_Fill128(__dpctl_keep const DPCTLSyclQueueRef QRef,
749746
void *USMRef,
@@ -754,7 +751,7 @@ DPCTLQueue_Fill128(__dpctl_keep const DPCTLSyclQueueRef QRef,
754751
if (Q && USMRef) {
755752
sycl::event ev;
756753
try {
757-
coplexNumber Val;
754+
complexNumber Val;
758755
Val.real = Value[0];
759756
Val.imag = Value[1];
760757
ev = Q->fill(USMRef, Val, Count);

0 commit comments

Comments
 (0)