@@ -89,6 +89,7 @@ class dpctl_capi
89
89
90
90
// memory
91
91
DPCTLSyclUSMRef (* Memory_GetUsmPointer_ )(Py_MemoryObject * );
92
+ void * (* Memory_GetOpaquePointer_ )(Py_MemoryObject * );
92
93
DPCTLSyclContextRef (* Memory_GetContextRef_ )(Py_MemoryObject * );
93
94
DPCTLSyclQueueRef (* Memory_GetQueueRef_ )(Py_MemoryObject * );
94
95
size_t (* Memory_GetNumBytes_ )(Py_MemoryObject * );
@@ -115,6 +116,7 @@ class dpctl_capi
115
116
int (* UsmNDArray_GetFlags_ )(PyUSMArrayObject * );
116
117
DPCTLSyclQueueRef (* UsmNDArray_GetQueueRef_ )(PyUSMArrayObject * );
117
118
py ::ssize_t (* UsmNDArray_GetOffset_ )(PyUSMArrayObject * );
119
+ PyObject * (* UsmNDArray_GetUSMData_ )(PyUSMArrayObject * );
118
120
void (* UsmNDArray_SetWritableFlag_ )(PyUSMArrayObject * , int );
119
121
PyObject * (* UsmNDArray_MakeSimpleFromMemory_ )(int ,
120
122
const py ::ssize_t * ,
@@ -233,15 +235,16 @@ class dpctl_capi
233
235
SyclContext_Make_ (nullptr ), SyclEvent_GetEventRef_ (nullptr ),
234
236
SyclEvent_Make_ (nullptr ), SyclQueue_GetQueueRef_ (nullptr ),
235
237
SyclQueue_Make_ (nullptr ), Memory_GetUsmPointer_ (nullptr ),
236
- Memory_GetContextRef_ (nullptr ), Memory_GetQueueRef_ (nullptr ),
237
- Memory_GetNumBytes_ (nullptr ), Memory_Make_ (nullptr ),
238
- SyclKernel_GetKernelRef_ (nullptr ), SyclKernel_Make_ (nullptr ),
239
- SyclProgram_GetKernelBundleRef_ (nullptr ), SyclProgram_Make_ (nullptr ),
240
- UsmNDArray_GetData_ (nullptr ), UsmNDArray_GetNDim_ (nullptr ),
241
- UsmNDArray_GetShape_ (nullptr ), UsmNDArray_GetStrides_ (nullptr ),
242
- UsmNDArray_GetTypenum_ (nullptr ), UsmNDArray_GetElementSize_ (nullptr ),
243
- UsmNDArray_GetFlags_ (nullptr ), UsmNDArray_GetQueueRef_ (nullptr ),
244
- UsmNDArray_GetOffset_ (nullptr ), UsmNDArray_SetWritableFlag_ (nullptr ),
238
+ Memory_GetOpaquePointer_ (nullptr ), Memory_GetContextRef_ (nullptr ),
239
+ Memory_GetQueueRef_ (nullptr ), Memory_GetNumBytes_ (nullptr ),
240
+ Memory_Make_ (nullptr ), SyclKernel_GetKernelRef_ (nullptr ),
241
+ SyclKernel_Make_ (nullptr ), SyclProgram_GetKernelBundleRef_ (nullptr ),
242
+ SyclProgram_Make_ (nullptr ), UsmNDArray_GetData_ (nullptr ),
243
+ UsmNDArray_GetNDim_ (nullptr ), UsmNDArray_GetShape_ (nullptr ),
244
+ UsmNDArray_GetStrides_ (nullptr ), UsmNDArray_GetTypenum_ (nullptr ),
245
+ UsmNDArray_GetElementSize_ (nullptr ), UsmNDArray_GetFlags_ (nullptr ),
246
+ UsmNDArray_GetQueueRef_ (nullptr ), UsmNDArray_GetOffset_ (nullptr ),
247
+ UsmNDArray_GetUSMData_ (nullptr ), UsmNDArray_SetWritableFlag_ (nullptr ),
245
248
UsmNDArray_MakeSimpleFromMemory_ (nullptr ),
246
249
UsmNDArray_MakeSimpleFromPtr_ (nullptr ),
247
250
UsmNDArray_MakeFromPtr_ (nullptr ), USM_ARRAY_C_CONTIGUOUS_ (0 ),
@@ -299,6 +302,7 @@ class dpctl_capi
299
302
300
303
// dpctl.memory API
301
304
this -> Memory_GetUsmPointer_ = Memory_GetUsmPointer ;
305
+ this -> Memory_GetOpaquePointer_ = Memory_GetOpaquePointer ;
302
306
this -> Memory_GetContextRef_ = Memory_GetContextRef ;
303
307
this -> Memory_GetQueueRef_ = Memory_GetQueueRef ;
304
308
this -> Memory_GetNumBytes_ = Memory_GetNumBytes ;
@@ -320,6 +324,7 @@ class dpctl_capi
320
324
this -> UsmNDArray_GetFlags_ = UsmNDArray_GetFlags ;
321
325
this -> UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef ;
322
326
this -> UsmNDArray_GetOffset_ = UsmNDArray_GetOffset ;
327
+ this -> UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData ;
323
328
this -> UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag ;
324
329
this -> UsmNDArray_MakeSimpleFromMemory_ =
325
330
UsmNDArray_MakeSimpleFromMemory ;
@@ -779,6 +784,33 @@ class usm_memory : public py::object
779
784
return api .Memory_GetNumBytes_ (mem_obj );
780
785
}
781
786
787
+ bool is_managed_by_smart_ptr () const
788
+ {
789
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
790
+ Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
791
+ const void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
792
+
793
+ return bool (opaque_ptr );
794
+ }
795
+
796
+ std ::shared_ptr < void > get_smart_ptr_owner () const
797
+ {
798
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
799
+ Py_MemoryObject * mem_obj = reinterpret_cast < Py_MemoryObject * > (m_ptr );
800
+ void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
801
+
802
+ if (opaque_ptr ) {
803
+ auto shptr_ptr =
804
+ reinterpret_cast < std ::shared_ptr < void > * > (opaque_ptr );
805
+ return * shptr_ptr ;
806
+ }
807
+ else {
808
+ throw std ::runtime_error (
809
+ "Memory object does not have smart pointer "
810
+ "managing lifetime of USM allocation" );
811
+ }
812
+ }
813
+
782
814
protected :
783
815
static PyObject * as_usm_memory (PyObject * o )
784
816
{
@@ -1065,6 +1097,63 @@ class usm_ndarray : public py::object
1065
1097
return static_cast < bool > (flags & api .USM_ARRAY_WRITABLE_ );
1066
1098
}
1067
1099
1100
+ py ::object get_usm_data () const
1101
+ {
1102
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1103
+
1104
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1105
+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1106
+
1107
+ return py ::reinterpret_steal < py ::object > (usm_data );
1108
+ }
1109
+
1110
+ bool is_managed_by_smart_ptr () const
1111
+ {
1112
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1113
+
1114
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1115
+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1116
+
1117
+ if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ ))
1118
+ return false ;
1119
+
1120
+ Py_MemoryObject * mem_obj =
1121
+ reinterpret_cast < Py_MemoryObject * > (usm_data );
1122
+ const void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1123
+
1124
+ return bool (opaque_ptr );
1125
+ }
1126
+
1127
+ std ::shared_ptr < void > get_smart_ptr_owner () const
1128
+ {
1129
+ PyUSMArrayObject * raw_ar = usm_array_ptr ();
1130
+
1131
+ auto const & api = ::dpctl ::detail ::dpctl_capi ::get ();
1132
+
1133
+ PyObject * usm_data = api .UsmNDArray_GetUSMData_ (raw_ar );
1134
+
1135
+ if (!PyObject_TypeCheck (usm_data , api .Py_MemoryType_ )) {
1136
+ throw std ::runtime_error (
1137
+ "usm_ndarray object does not have Memory object "
1138
+ "managing lifetime of USM allocation" );
1139
+ }
1140
+
1141
+ Py_MemoryObject * mem_obj =
1142
+ reinterpret_cast < Py_MemoryObject * > (usm_data );
1143
+ void * opaque_ptr = api .Memory_GetOpaquePointer_ (mem_obj );
1144
+
1145
+ if (opaque_ptr ) {
1146
+ auto shptr_ptr =
1147
+ reinterpret_cast < std ::shared_ptr < void > * > (opaque_ptr );
1148
+ return * shptr_ptr ;
1149
+ }
1150
+ else {
1151
+ throw std ::runtime_error (
1152
+ "Memory object underlying usm_ndarray does not have "
1153
+ "smart pointer managing lifetime of USM allocation" );
1154
+ }
1155
+ }
1156
+
1068
1157
private :
1069
1158
PyUSMArrayObject * usm_array_ptr () const
1070
1159
{
@@ -1077,26 +1166,107 @@ class usm_ndarray : public py::object
1077
1166
namespace utils
1078
1167
{
1079
1168
1169
+ namespace detail
1170
+ {
1171
+
1172
+ struct ManagedMemory
1173
+ {
1174
+
1175
+ static bool is_usm_managed_by_shared_ptr (const py ::handle & h )
1176
+ {
1177
+ if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1178
+ auto usm_memory_inst = py ::cast < dpctl ::memory ::usm_memory > (h );
1179
+ return usm_memory_inst .is_managed_by_smart_ptr ();
1180
+ }
1181
+ else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1182
+ auto usm_array_inst = py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1183
+ return usm_array_inst .is_managed_by_smart_ptr ();
1184
+ }
1185
+
1186
+ return false;
1187
+ }
1188
+
1189
+ static std ::shared_ptr < void > extract_shared_ptr (const py ::handle & h )
1190
+ {
1191
+ if (py ::isinstance < dpctl ::memory ::usm_memory > (h )) {
1192
+ auto usm_memory_inst = py ::cast < dpctl ::memory ::usm_memory > (h );
1193
+ return usm_memory_inst .get_smart_ptr_owner ();
1194
+ }
1195
+ else if (py ::isinstance < dpctl ::tensor ::usm_ndarray > (h )) {
1196
+ auto usm_array_inst = py ::cast < dpctl ::tensor ::usm_ndarray > (h );
1197
+ return usm_array_inst .get_smart_ptr_owner ();
1198
+ }
1199
+
1200
+ throw std ::runtime_error (
1201
+ "Attempted extraction of shared_ptr on an unrecognized type" );
1202
+ }
1203
+ };
1204
+
1205
+ } // end of namespace detail
1206
+
1080
1207
template < std ::size_t num >
1081
1208
sycl ::event keep_args_alive (sycl ::queue & q ,
1082
1209
const py ::object (& py_objs )[num ],
1083
1210
const std ::vector < sycl ::event > & depends = {})
1084
1211
{
1085
- sycl ::event host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1086
- cgh .depends_on (depends );
1087
- std ::array < std ::shared_ptr < py ::handle > , num > shp_arr ;
1088
- for (std ::size_t i = 0 ; i < num ; ++ i ) {
1089
- shp_arr [i ] = std ::make_shared < py ::handle > (py_objs [i ]);
1090
- shp_arr [i ]-> inc_ref ();
1212
+ std ::size_t n_objects_held = 0 ;
1213
+ std ::array < std ::shared_ptr < py ::handle > , num > shp_arr {};
1214
+
1215
+ std ::size_t n_usm_owners_held = 0 ;
1216
+ std ::array < std ::shared_ptr < void > , num > shp_usm {};
1217
+
1218
+ for (std ::size_t i = 0 ; i < num ; ++ i ) {
1219
+ auto py_obj_i = py_objs [i ];
1220
+ if (detail ::ManagedMemory ::is_usm_managed_by_shared_ptr (py_obj_i )) {
1221
+ shp_usm [n_usm_owners_held ] =
1222
+ detail ::ManagedMemory ::extract_shared_ptr (py_obj_i );
1223
+ ++ n_usm_owners_held ;
1091
1224
}
1092
- cgh .host_task ([shp_arr = std ::move (shp_arr )]() {
1093
- py ::gil_scoped_acquire acquire ;
1225
+ else {
1226
+ shp_arr [n_objects_held ] = std ::make_shared < py ::handle > (py_obj_i );
1227
+ shp_arr [n_objects_held ]-> inc_ref ();
1228
+ ++ n_objects_held ;
1229
+ }
1230
+ }
1231
+
1232
+ bool use_depends = true;
1233
+ sycl ::event host_task_ev ;
1234
+
1235
+ if (n_usm_owners_held > 0 ) {
1236
+ host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1237
+ if (use_depends ) {
1238
+ cgh .depends_on (depends );
1239
+ use_depends = false;
1240
+ }
1241
+ else {
1242
+ cgh .depends_on (host_task_ev );
1243
+ }
1244
+ cgh .host_task ([shp_usm = std ::move (shp_usm )]() {
1245
+ // no body, but shared pointers are captured in
1246
+ // the lamba, ensuring that USM allocation is
1247
+ // kept alive
1248
+ });
1249
+ });
1250
+ }
1094
1251
1095
- for (std ::size_t i = 0 ; i < num ; ++ i ) {
1096
- shp_arr [i ]-> dec_ref ();
1252
+ if (n_objects_held > 0 ) {
1253
+ host_task_ev = q .submit ([& ](sycl ::handler & cgh ) {
1254
+ if (use_depends ) {
1255
+ cgh .depends_on (depends );
1256
+ use_depends = false;
1097
1257
}
1258
+ else {
1259
+ cgh .depends_on (host_task_ev );
1260
+ }
1261
+ cgh .host_task ([n_objects_held , shp_arr = std ::move (shp_arr )]() {
1262
+ py ::gil_scoped_acquire acquire ;
1263
+
1264
+ for (std ::size_t i = 0 ; i < n_objects_held ; ++ i ) {
1265
+ shp_arr [i ]-> dec_ref ();
1266
+ }
1267
+ });
1098
1268
});
1099
- });
1269
+ }
1100
1270
1101
1271
return host_task_ev ;
1102
1272
}
0 commit comments