@@ -829,7 +829,7 @@ class usm_ndarray : public py::object
829
829
830
830
char * get_data () const
831
831
{
832
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
832
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
833
833
834
834
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
835
835
return api .UsmNDArray_GetData_ (raw_ar );
@@ -842,20 +842,29 @@ class usm_ndarray : public py::object
842
842
843
843
int get_ndim () const
844
844
{
845
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
845
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
846
846
847
847
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
848
848
return api .UsmNDArray_GetNDim_ (raw_ar );
849
849
}
850
850
851
851
const py ::ssize_t * get_shape_raw () const
852
852
{
853
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
853
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
854
854
855
855
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
856
856
return api .UsmNDArray_GetShape_ (raw_ar );
857
857
}
858
858
859
+ std ::vector < py ::ssize_t > get_shape_vector () const
860
+ {
861
+ auto raw_sh = get_shape_raw ();
862
+ auto nd = get_ndim ();
863
+
864
+ std ::vector < py ::ssize_t > shape_vector (raw_sh , raw_sh + nd );
865
+ return shape_vector ;
866
+ }
867
+
859
868
py ::ssize_t get_shape (int i ) const
860
869
{
861
870
auto shape_ptr = get_shape_raw ();
@@ -864,15 +873,43 @@ class usm_ndarray : public py::object
864
873
865
874
const py ::ssize_t * get_strides_raw () const
866
875
{
867
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
876
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
868
877
869
878
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
870
879
return api .UsmNDArray_GetStrides_ (raw_ar );
871
880
}
872
881
882
+ std ::vector < py ::ssize_t > get_strides_vector () const
883
+ {
884
+ auto raw_st = get_strides_raw ();
885
+ auto nd = get_ndim ();
886
+
887
+ if (raw_st == nullptr ) {
888
+ auto is_c_contig = is_c_contiguous ();
889
+ auto is_f_contig = is_f_contiguous ();
890
+ auto raw_sh = get_shape_raw ();
891
+ if (is_c_contig ) {
892
+ const auto & contig_strides = c_contiguous_strides (nd , raw_sh );
893
+ return contig_strides ;
894
+ }
895
+ else if (is_f_contig ) {
896
+ const auto & contig_strides = f_contiguous_strides (nd , raw_sh );
897
+ return contig_strides ;
898
+ }
899
+ else {
900
+ throw std ::runtime_error ("Invalid array encountered when "
901
+ "building strides" );
902
+ }
903
+ }
904
+ else {
905
+ std ::vector < py ::ssize_t > st_vec (raw_st , raw_st + nd );
906
+ return st_vec ;
907
+ }
908
+ }
909
+
873
910
py ::ssize_t get_size () const
874
911
{
875
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
912
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
876
913
877
914
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
878
915
int ndim = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -889,7 +926,7 @@ class usm_ndarray : public py::object
889
926
890
927
std ::pair < py ::ssize_t , py ::ssize_t > get_minmax_offsets () const
891
928
{
892
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
929
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
893
930
894
931
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
895
932
int nd = api .UsmNDArray_GetNDim_ (raw_ar );
@@ -906,8 +943,6 @@ class usm_ndarray : public py::object
906
943
}
907
944
}
908
945
else {
909
- offset_min = api .UsmNDArray_GetOffset_ (raw_ar );
910
- offset_max = offset_min ;
911
946
for (int i = 0 ; i < nd ; ++ i ) {
912
947
py ::ssize_t delta = strides [i ] * (shape [i ] - 1 );
913
948
if (strides [i ] > 0 ) {
@@ -923,7 +958,7 @@ class usm_ndarray : public py::object
923
958
924
959
sycl ::queue get_queue () const
925
960
{
926
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
961
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
927
962
928
963
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
929
964
DPCTLSyclQueueRef QRef = api .UsmNDArray_GetQueueRef_ (raw_ar );
@@ -932,45 +967,45 @@ class usm_ndarray : public py::object
932
967
933
968
int get_typenum () const
934
969
{
935
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
970
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
936
971
937
972
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
938
973
return api .UsmNDArray_GetTypenum_ (raw_ar );
939
974
}
940
975
941
976
int get_flags () const
942
977
{
943
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
978
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
944
979
945
980
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
946
981
return api .UsmNDArray_GetFlags_ (raw_ar );
947
982
}
948
983
949
984
int get_elemsize () const
950
985
{
951
- PyUSMArrayObject * raw_ar = this -> usm_array_ptr ();
986
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
952
987
953
988
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
954
989
return api .UsmNDArray_GetElementSize_ (raw_ar );
955
990
}
956
991
957
992
bool is_c_contiguous () const
958
993
{
959
- int flags = this -> get_flags ();
994
+ int flags = get_flags ();
960
995
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
961
996
return static_cast < bool > (flags & api .USM_ARRAY_C_CONTIGUOUS_ );
962
997
}
963
998
964
999
bool is_f_contiguous () const
965
1000
{
966
- int flags = this -> get_flags ();
1001
+ int flags = get_flags ();
967
1002
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
968
1003
return static_cast < bool > (flags & api .USM_ARRAY_F_CONTIGUOUS_ );
969
1004
}
970
1005
971
1006
bool is_writable () const
972
1007
{
973
- int flags = this -> get_flags ();
1008
+ int flags = get_flags ();
974
1009
auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
975
1010
return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
976
1011
}
0 commit comments