|
27 | 27 |
|
28 | 28 | #include "dpctl_capi.h"
|
29 | 29 | #include <complex>
|
| 30 | +#include <exception> |
30 | 31 | #include <memory>
|
31 | 32 | #include <pybind11/pybind11.h>
|
32 | 33 | #include <sycl/sycl.hpp>
|
@@ -748,6 +749,54 @@ class usm_memory : public py::object
|
748 | 749 | throw py::error_already_set();
|
749 | 750 | }
|
750 | 751 |
|
| 752 | + /*! @brief Create usm_memory object from shared pointer that manages |
| 753 | + * lifetime of the USM allocation. |
| 754 | + */ |
| 755 | + usm_memory(void *usm_ptr, |
| 756 | + size_t nbytes, |
| 757 | + const sycl::queue &q, |
| 758 | + std::shared_ptr<void> shptr) |
| 759 | + { |
| 760 | + auto const &api = ::dpctl::detail::dpctl_capi::get(); |
| 761 | + DPCTLSyclUSMRef usm_ref = reinterpret_cast<DPCTLSyclUSMRef>(usm_ptr); |
| 762 | + auto q_uptr = std::make_unique<sycl::queue>(q); |
| 763 | + DPCTLSyclQueueRef QRef = |
| 764 | + reinterpret_cast<DPCTLSyclQueueRef>(q_uptr.get()); |
| 765 | + |
| 766 | + auto vacuous_destructor = []() {}; |
| 767 | + py::object mock_owner = py::capsule(vacuous_destructor); |
| 768 | + |
| 769 | + // create memory object owned by mock_owner, it is a new reference |
| 770 | + PyObject *_memory = |
| 771 | + api.Memory_Make_(usm_ref, nbytes, QRef, mock_owner.ptr()); |
| 772 | + auto ref_count_decrementer = [](PyObject *o) noexcept { Py_DECREF(o); }; |
| 773 | + |
| 774 | + using py_uptrT = |
| 775 | + std::unique_ptr<PyObject, decltype(ref_count_decrementer)>; |
| 776 | + auto memory_uptr = py_uptrT(_memory, ref_count_decrementer); |
| 777 | + |
| 778 | + if (!_memory) { |
| 779 | + throw py::error_already_set(); |
| 780 | + } |
| 781 | + |
| 782 | + std::shared_ptr<void> *opaque_ptr = nullptr; |
| 783 | + opaque_ptr = new std::shared_ptr<void>(shptr); |
| 784 | + |
| 785 | + Py_MemoryObject *memobj = reinterpret_cast<Py_MemoryObject *>(_memory); |
| 786 | + // replace mock_owner capsule as the owner |
| 787 | + memobj->refobj = Py_None; |
| 788 | + // set opaque ptr field, usm_memory now knowns that USM is managed |
| 789 | + // by smart pointer |
| 790 | + memobj->_opaque_ptr = reinterpret_cast<void *>(opaque_ptr); |
| 791 | + |
| 792 | + // _memory will delete created copies of sycl::queue, and |
| 793 | + // std::shared_ptr and the deleter of the shared_ptr<void> is |
| 794 | + // supposed to free the USM allocation |
| 795 | + m_ptr = _memory; |
| 796 | + q_uptr.release(); |
| 797 | + memory_uptr.release(); |
| 798 | + } |
| 799 | + |
751 | 800 | sycl::queue get_queue() const
|
752 | 801 | {
|
753 | 802 | Py_MemoryObject *mem_obj = reinterpret_cast<Py_MemoryObject *>(m_ptr);
|
|
0 commit comments