Skip to content

Commit dc77858

Browse files
Deploy using shared pointers from keep_args_alive
1 parent a51973c commit dc77858

File tree

2 files changed

+195
-20
lines changed

2 files changed

+195
-20
lines changed

dpctl/apis/include/dpctl4pybind11.hpp

Lines changed: 190 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ class dpctl_capi
8989

9090
// memory
9191
DPCTLSyclUSMRef (*Memory_GetUsmPointer_)(Py_MemoryObject *);
92+
void *(*Memory_GetOpaquePointer_)(Py_MemoryObject *);
9293
DPCTLSyclContextRef (*Memory_GetContextRef_)(Py_MemoryObject *);
9394
DPCTLSyclQueueRef (*Memory_GetQueueRef_)(Py_MemoryObject *);
9495
size_t (*Memory_GetNumBytes_)(Py_MemoryObject *);
@@ -115,6 +116,7 @@ class dpctl_capi
115116
int (*UsmNDArray_GetFlags_)(PyUSMArrayObject *);
116117
DPCTLSyclQueueRef (*UsmNDArray_GetQueueRef_)(PyUSMArrayObject *);
117118
py::ssize_t (*UsmNDArray_GetOffset_)(PyUSMArrayObject *);
119+
PyObject *(*UsmNDArray_GetUSMData_)(PyUSMArrayObject *);
118120
void (*UsmNDArray_SetWritableFlag_)(PyUSMArrayObject *, int);
119121
PyObject *(*UsmNDArray_MakeSimpleFromMemory_)(int,
120122
const py::ssize_t *,
@@ -233,15 +235,16 @@ class dpctl_capi
233235
SyclContext_Make_(nullptr), SyclEvent_GetEventRef_(nullptr),
234236
SyclEvent_Make_(nullptr), SyclQueue_GetQueueRef_(nullptr),
235237
SyclQueue_Make_(nullptr), Memory_GetUsmPointer_(nullptr),
236-
Memory_GetContextRef_(nullptr), Memory_GetQueueRef_(nullptr),
237-
Memory_GetNumBytes_(nullptr), Memory_Make_(nullptr),
238-
SyclKernel_GetKernelRef_(nullptr), SyclKernel_Make_(nullptr),
239-
SyclProgram_GetKernelBundleRef_(nullptr), SyclProgram_Make_(nullptr),
240-
UsmNDArray_GetData_(nullptr), UsmNDArray_GetNDim_(nullptr),
241-
UsmNDArray_GetShape_(nullptr), UsmNDArray_GetStrides_(nullptr),
242-
UsmNDArray_GetTypenum_(nullptr), UsmNDArray_GetElementSize_(nullptr),
243-
UsmNDArray_GetFlags_(nullptr), UsmNDArray_GetQueueRef_(nullptr),
244-
UsmNDArray_GetOffset_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
238+
Memory_GetOpaquePointer_(nullptr), Memory_GetContextRef_(nullptr),
239+
Memory_GetQueueRef_(nullptr), Memory_GetNumBytes_(nullptr),
240+
Memory_Make_(nullptr), SyclKernel_GetKernelRef_(nullptr),
241+
SyclKernel_Make_(nullptr), SyclProgram_GetKernelBundleRef_(nullptr),
242+
SyclProgram_Make_(nullptr), UsmNDArray_GetData_(nullptr),
243+
UsmNDArray_GetNDim_(nullptr), UsmNDArray_GetShape_(nullptr),
244+
UsmNDArray_GetStrides_(nullptr), UsmNDArray_GetTypenum_(nullptr),
245+
UsmNDArray_GetElementSize_(nullptr), UsmNDArray_GetFlags_(nullptr),
246+
UsmNDArray_GetQueueRef_(nullptr), UsmNDArray_GetOffset_(nullptr),
247+
UsmNDArray_GetUSMData_(nullptr), UsmNDArray_SetWritableFlag_(nullptr),
245248
UsmNDArray_MakeSimpleFromMemory_(nullptr),
246249
UsmNDArray_MakeSimpleFromPtr_(nullptr),
247250
UsmNDArray_MakeFromPtr_(nullptr), USM_ARRAY_C_CONTIGUOUS_(0),
@@ -299,6 +302,7 @@ class dpctl_capi
299302

300303
// dpctl.memory API
301304
this->Memory_GetUsmPointer_ = Memory_GetUsmPointer;
305+
this->Memory_GetOpaquePointer_ = Memory_GetOpaquePointer;
302306
this->Memory_GetContextRef_ = Memory_GetContextRef;
303307
this->Memory_GetQueueRef_ = Memory_GetQueueRef;
304308
this->Memory_GetNumBytes_ = Memory_GetNumBytes;
@@ -320,6 +324,7 @@ class dpctl_capi
320324
this->UsmNDArray_GetFlags_ = UsmNDArray_GetFlags;
321325
this->UsmNDArray_GetQueueRef_ = UsmNDArray_GetQueueRef;
322326
this->UsmNDArray_GetOffset_ = UsmNDArray_GetOffset;
327+
this->UsmNDArray_GetUSMData_ = UsmNDArray_GetUSMData;
323328
this->UsmNDArray_SetWritableFlag_ = UsmNDArray_SetWritableFlag;
324329
this->UsmNDArray_MakeSimpleFromMemory_ =
325330
UsmNDArray_MakeSimpleFromMemory;
@@ -779,6 +784,33 @@ class usm_memory : public py::object
779784
return api.Memory_GetNumBytes_(mem_obj);
780785
}
781786

787+
bool is_managed_by_smart_ptr() const
788+
{
789+
auto const &api = ::dpctl::detail::dpctl_capi::get();
790+
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
791+
const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
792+
793+
return bool(opaque_ptr);
794+
}
795+
796+
std::shared_ptr<void> get_smart_ptr_owner() const
797+
{
798+
auto const &api = ::dpctl::detail::dpctl_capi::get();
799+
Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
800+
void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
801+
802+
if (opaque_ptr) {
803+
auto shptr_ptr =
804+
reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
805+
return *shptr_ptr;
806+
}
807+
else {
808+
throw std::runtime_error(
809+
"Memory object does not have smart pointer "
810+
"managing lifetime of USM allocation");
811+
}
812+
}
813+
782814
protected:
783815
static PyObject *as_usm_memory(PyObject *o)
784816
{
@@ -1065,6 +1097,63 @@ class usm_ndarray : public py::object
10651097
return static_cast<bool>(flags & api.USM_ARRAY_WRITABLE_);
10661098
}
10671099

1100+
py::object get_usm_data() const
1101+
{
1102+
PyUSMArrayObject *raw_ar = usm_array_ptr();
1103+
1104+
auto const &api = ::dpctl::detail::dpctl_capi::get();
1105+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1106+
1107+
return py::reinterpret_steal<py::object>(usm_data);
1108+
}
1109+
1110+
bool is_managed_by_smart_ptr() const
1111+
{
1112+
PyUSMArrayObject *raw_ar = usm_array_ptr();
1113+
1114+
auto const &api = ::dpctl::detail::dpctl_capi::get();
1115+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1116+
1117+
if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_))
1118+
return false;
1119+
1120+
Py_MemoryObject *mem_obj =
1121+
reinterpret_cast<Py_MemoryObject *>(usm_data);
1122+
const void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1123+
1124+
return bool(opaque_ptr);
1125+
}
1126+
1127+
std::shared_ptr<void> get_smart_ptr_owner() const
1128+
{
1129+
PyUSMArrayObject *raw_ar = usm_array_ptr();
1130+
1131+
auto const &api = ::dpctl::detail::dpctl_capi::get();
1132+
1133+
PyObject *usm_data = api.UsmNDArray_GetUSMData_(raw_ar);
1134+
1135+
if (!PyObject_TypeCheck(usm_data, api.Py_MemoryType_)) {
1136+
throw std::runtime_error(
1137+
"usm_ndarray object does not have Memory object "
1138+
"managing lifetime of USM allocation");
1139+
}
1140+
1141+
Py_MemoryObject *mem_obj =
1142+
reinterpret_cast<Py_MemoryObject *>(usm_data);
1143+
void *opaque_ptr = api.Memory_GetOpaquePointer_(mem_obj);
1144+
1145+
if (opaque_ptr) {
1146+
auto shptr_ptr =
1147+
reinterpret_cast<std::shared_ptr<void> *>(opaque_ptr);
1148+
return *shptr_ptr;
1149+
}
1150+
else {
1151+
throw std::runtime_error(
1152+
"Memory object underlying usm_ndarray does not have "
1153+
"smart pointer managing lifetime of USM allocation");
1154+
}
1155+
}
1156+
10681157
private:
10691158
PyUSMArrayObject *usm_array_ptr() const
10701159
{
@@ -1077,26 +1166,107 @@ class usm_ndarray : public py::object
10771166
namespace utils
10781167
{
10791168

1169+
namespace detail
1170+
{
1171+
1172+
struct ManagedMemory
1173+
{
1174+
1175+
static bool is_usm_managed_by_shared_ptr(const py::handle &h)
1176+
{
1177+
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1178+
auto usm_memory_inst = py::cast<dpctl::memory::usm_memory>(h);
1179+
return usm_memory_inst.is_managed_by_smart_ptr();
1180+
}
1181+
else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1182+
auto usm_array_inst = py::cast<dpctl::tensor::usm_ndarray>(h);
1183+
return usm_array_inst.is_managed_by_smart_ptr();
1184+
}
1185+
1186+
return false;
1187+
}
1188+
1189+
static std::shared_ptr<void> extract_shared_ptr(const py::handle &h)
1190+
{
1191+
if (py::isinstance<dpctl::memory::usm_memory>(h)) {
1192+
auto usm_memory_inst = py::cast<dpctl::memory::usm_memory>(h);
1193+
return usm_memory_inst.get_smart_ptr_owner();
1194+
}
1195+
else if (py::isinstance<dpctl::tensor::usm_ndarray>(h)) {
1196+
auto usm_array_inst = py::cast<dpctl::tensor::usm_ndarray>(h);
1197+
return usm_array_inst.get_smart_ptr_owner();
1198+
}
1199+
1200+
throw std::runtime_error(
1201+
"Attempted extraction of shared_ptr on an unrecognized type");
1202+
}
1203+
};
1204+
1205+
} // end of namespace detail
1206+
10801207
template <std::size_t num>
10811208
sycl::event keep_args_alive(sycl::queue &q,
10821209
const py::object (&py_objs)[num],
10831210
const std::vector<sycl::event> &depends = {})
10841211
{
1085-
sycl::event host_task_ev = q.submit([&](sycl::handler &cgh) {
1086-
cgh.depends_on(depends);
1087-
std::array<std::shared_ptr<py::handle>, num> shp_arr;
1088-
for (std::size_t i = 0; i < num; ++i) {
1089-
shp_arr[i] = std::make_shared<py::handle>(py_objs[i]);
1090-
shp_arr[i]->inc_ref();
1212+
std::size_t n_objects_held = 0;
1213+
std::array<std::shared_ptr<py::handle>, num> shp_arr{};
1214+
1215+
std::size_t n_usm_owners_held = 0;
1216+
std::array<std::shared_ptr<void>, num> shp_usm{};
1217+
1218+
for (std::size_t i = 0; i < num; ++i) {
1219+
auto py_obj_i = py_objs[i];
1220+
if (detail::ManagedMemory::is_usm_managed_by_shared_ptr(py_obj_i)) {
1221+
shp_usm[n_usm_owners_held] =
1222+
detail::ManagedMemory::extract_shared_ptr(py_obj_i);
1223+
++n_usm_owners_held;
10911224
}
1092-
cgh.host_task([shp_arr = std::move(shp_arr)]() {
1093-
py::gil_scoped_acquire acquire;
1225+
else {
1226+
shp_arr[n_objects_held] = std::make_shared<py::handle>(py_obj_i);
1227+
shp_arr[n_objects_held]->inc_ref();
1228+
++n_objects_held;
1229+
}
1230+
}
1231+
1232+
bool use_depends = true;
1233+
sycl::event host_task_ev;
1234+
1235+
if (n_usm_owners_held > 0) {
1236+
host_task_ev = q.submit([&](sycl::handler &cgh) {
1237+
if (use_depends) {
1238+
cgh.depends_on(depends);
1239+
use_depends = false;
1240+
}
1241+
else {
1242+
cgh.depends_on(host_task_ev);
1243+
}
1244+
cgh.host_task([shp_usm = std::move(shp_usm)]() {
1245+
// no body, but shared pointers are captured in
1246+
// the lamba, ensuring that USM allocation is
1247+
// kept alive
1248+
});
1249+
});
1250+
}
10941251

1095-
for (std::size_t i = 0; i < num; ++i) {
1096-
shp_arr[i]->dec_ref();
1252+
if (n_objects_held > 0) {
1253+
host_task_ev = q.submit([&](sycl::handler &cgh) {
1254+
if (use_depends) {
1255+
cgh.depends_on(depends);
1256+
use_depends = false;
10971257
}
1258+
else {
1259+
cgh.depends_on(host_task_ev);
1260+
}
1261+
cgh.host_task([n_objects_held, shp_arr = std::move(shp_arr)]() {
1262+
py::gil_scoped_acquire acquire;
1263+
1264+
for (std::size_t i = 0; i < n_objects_held; ++i) {
1265+
shp_arr[i]->dec_ref();
1266+
}
1267+
});
10981268
});
1099-
});
1269+
}
11001270

11011271
return host_task_ev;
11021272
}

dpctl/tensor/_usmarray.pyx

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,11 @@ cdef api Py_ssize_t UsmNDArray_GetOffset(usm_ndarray arr):
16121612
return arr.get_offset()
16131613

16141614

1615+
cdef api object UsmNDArray_GetUSMData(usm_ndarray arr):
1616+
"""Get USM data object underlying the array"""
1617+
return arr.get_base()
1618+
1619+
16151620
cdef api void UsmNDArray_SetWritableFlag(usm_ndarray arr, int flag):
16161621
"""Set/unset USM_ARRAY_WRITABLE in the given array `arr`."""
16171622
arr._set_writable_flag(flag)

0 commit comments

Comments
 (0)