@@ -60,6 +60,9 @@ cdef object _as_zero_dim_ndarray(object usm_ary):
60
60
view.shape = tuple ()
61
61
return view
62
62
63
+ cdef int _copy_writable(int lhs_flags, int rhs_flags):
64
+ " Copy the WRITABLE flag to lhs_flags from rhs_flags"
65
+ return (lhs_flags & ~ USM_ARRAY_WRITABLE) | (rhs_flags & USM_ARRAY_WRITABLE)
63
66
64
67
cdef class usm_ndarray:
65
68
""" usm_ndarray(shape, dtype=None, strides=None, buffer="device", \
@@ -546,7 +549,7 @@ cdef class usm_ndarray:
546
549
PyMem_Free(self .shape_)
547
550
if (self .strides_):
548
551
PyMem_Free(self .strides_)
549
- self .flags_ = contig_flag
552
+ self .flags_ = ( contig_flag | ( self .flags_ & USM_ARRAY_WRITABLE))
550
553
self .nd_ = new_nd
551
554
self .shape_ = shape_ptr
552
555
self .strides_ = strides_ptr
@@ -725,13 +728,13 @@ cdef class usm_ndarray:
725
728
buffer = self .base_,
726
729
offset = _meta[2 ]
727
730
)
728
- res.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
729
731
res.array_namespace_ = self .array_namespace_
730
732
731
733
adv_ind = _meta[3 ]
732
734
adv_ind_start_p = _meta[4 ]
733
735
734
736
if adv_ind_start_p < 0 :
737
+ res.flags_ = _copy_writable(res.flags_, self .flags_)
735
738
return res
736
739
737
740
from ._copy_utils import _extract_impl, _nonzero_impl, _take_multi_index
@@ -749,6 +752,7 @@ cdef class usm_ndarray:
749
752
if not matching:
750
753
raise IndexError (" boolean index did not match indexed array in dimensions" )
751
754
res = _extract_impl(res, key_, axis = adv_ind_start_p)
755
+ res.flags_ = _copy_writable(res.flags_, self .flags_)
752
756
return res
753
757
754
758
if any (ind.dtype == dpt_bool for ind in adv_ind):
@@ -758,10 +762,13 @@ cdef class usm_ndarray:
758
762
adv_ind_int.extend(_nonzero_impl(ind))
759
763
else :
760
764
adv_ind_int.append(ind)
761
- return _take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
762
-
763
- return _take_multi_index(res, adv_ind, adv_ind_start_p)
765
+ res = _take_multi_index(res, tuple (adv_ind_int), adv_ind_start_p)
766
+ res.flags_ = _copy_writable(res.flags_, self .flags_)
767
+ return res
764
768
769
+ res = _take_multi_index(res, adv_ind, adv_ind_start_p)
770
+ res.flags_ = _copy_writable(res.flags_, self .flags_)
771
+ return res
765
772
766
773
def to_device (self , target , stream = None ):
767
774
""" to_device(target_device)
@@ -1040,8 +1047,7 @@ cdef class usm_ndarray:
1040
1047
buffer = self .base_,
1041
1048
offset = _meta[2 ],
1042
1049
)
1043
- # set flags and namespace
1044
- Xv.flags_ |= (self .flags_ & USM_ARRAY_WRITABLE)
1050
+ # set namespace
1045
1051
Xv.array_namespace_ = self .array_namespace_
1046
1052
1047
1053
from ._copy_utils import (
@@ -1225,7 +1231,7 @@ cdef usm_ndarray _real_view(usm_ndarray ary):
1225
1231
offset = offset_elems,
1226
1232
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1227
1233
)
1228
- r.flags_ |= (ary .flags_ & USM_ARRAY_WRITABLE )
1234
+ r.flags_ = _copy_writable(r .flags_, ary.flags_ )
1229
1235
r.array_namespace_ = ary.array_namespace_
1230
1236
return r
1231
1237
@@ -1257,7 +1263,7 @@ cdef usm_ndarray _imag_view(usm_ndarray ary):
1257
1263
offset = offset_elems,
1258
1264
order = (' C' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' F' )
1259
1265
)
1260
- r.flags_ |= (ary .flags_ & USM_ARRAY_WRITABLE )
1266
+ r.flags_ = _copy_writable(r .flags_, ary.flags_ )
1261
1267
r.array_namespace_ = ary.array_namespace_
1262
1268
return r
1263
1269
@@ -1277,7 +1283,7 @@ cdef usm_ndarray _transpose(usm_ndarray ary):
1277
1283
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1278
1284
offset = ary.get_offset()
1279
1285
)
1280
- r.flags_ |= (ary .flags_ & USM_ARRAY_WRITABLE )
1286
+ r.flags_ = _copy_writable(r .flags_, ary.flags_ )
1281
1287
return r
1282
1288
1283
1289
@@ -1294,7 +1300,7 @@ cdef usm_ndarray _m_transpose(usm_ndarray ary):
1294
1300
order = (' F' if (ary.flags_ & USM_ARRAY_C_CONTIGUOUS) else ' C' ),
1295
1301
offset = ary.get_offset()
1296
1302
)
1297
- r.flags_ |= (ary .flags_ & USM_ARRAY_WRITABLE )
1303
+ r.flags_ = _copy_writable(r .flags_, ary.flags_ )
1298
1304
return r
1299
1305
1300
1306
0 commit comments