@@ -28,10 +28,8 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
2828 }
2929
3030 if (!hostPtrImported) {
31- // TODO: use UMF
32- ZeStruct<ze_host_mem_alloc_desc_t > hostDesc;
33- ZE2UR_CALL_THROWS (zeMemAllocHost, (hContext->getZeHandle (), &hostDesc, size,
34- 0 , &this ->ptr ));
31+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
32+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &this ->ptr ));
3533
3634 if (hostPtr) {
3735 std::memcpy (this ->ptr , hostPtr, size);
@@ -40,9 +38,11 @@ ur_host_mem_handle_t::ur_host_mem_handle_t(ur_context_handle_t hContext,
4038}
4139
// Destructor of ur_host_mem_handle_t, shown here as a diff hunk: lines marked
// '-' are the removed zeMemFree-based code, lines marked '+' replace them.
// A non-null ptr is handed back to the context's default USM pool. Nothing is
// returned from here, so a non-success result from free() is only logged.
4240ur_host_mem_handle_t ::~ur_host_mem_handle_t () {
43- // TODO: use UMF API here
4441 if (ptr) {
45- ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
42+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
43+ if (ret != UR_RESULT_SUCCESS) {
44+ logger::error (" Failed to free host memory: {}" , ret);
45+ }
4646 }
4747}
4848
@@ -51,55 +51,80 @@ void *ur_host_mem_handle_t::getPtr(ur_device_handle_t hDevice) {
5151 return ptr;
5252}
5353
54+ ur_result_t ur_device_mem_handle_t::migrateBufferTo (ur_device_handle_t hDevice,
55+ void *src, size_t size) {
56+ auto Id = hDevice->Id .value ();
57+
58+ if (!deviceAllocations[Id]) {
59+ UR_CALL (hContext->getDefaultUSMPool ()->allocate (hContext, hDevice, nullptr ,
60+ UR_USM_TYPE_DEVICE, size,
61+ &deviceAllocations[Id]));
62+ }
63+
64+ auto commandList = hContext->commandListCache .getImmediateCommandList (
65+ hDevice->ZeDevice , true ,
66+ hDevice
67+ ->QueueGroup [ur_device_handle_t_::queue_group_info_t ::type::Compute]
68+ .ZeOrdinal ,
69+ ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
70+ std::nullopt );
71+
72+ ZE2UR_CALL (zeCommandListAppendMemoryCopy,
73+ (commandList.get (), deviceAllocations[Id], src, size, nullptr , 0 ,
74+ nullptr ));
75+
76+ activeAllocationDevice = hDevice;
77+
78+ return UR_RESULT_SUCCESS;
79+ }
80+
// Constructor of ur_device_mem_handle_t, shown here as a diff hunk: lines
// marked '-' are removed, lines marked '+' replace them.
// deviceAllocations gets one slot per device reported by the context's
// platform, and activeAllocationDevice starts out as nullptr.
// If hostPtr is non-null, its contents are migrated to the first device
// returned by hContext->getDevices(); a failure is raised by UR_CALL_THROWS.
// NOTE(review): if that throws, this object's destructor does not run, so any
// allocation made inside migrateBufferTo would not be freed -- confirm.
5481ur_device_mem_handle_t ::ur_device_mem_handle_t (ur_context_handle_t hContext,
5582 void *hostPtr, size_t size)
5683 : ur_mem_handle_t_(hContext, size),
57- deviceAllocations (hContext->getPlatform ()->getNumDevices()) {
58- // Legacy adapter allocated the memory directly on a device (first on the
59- // contxt) and if the buffer is used on another device, memory is migrated
60- // (depending on an env var setting).
61- //
62- // TODO: port this behavior or figure out if it makes sense to keep the memory
63- // in a host buffer (e.g. for smaller sizes).
84+ deviceAllocations (hContext->getPlatform ()->getNumDevices()),
85+ activeAllocationDevice(nullptr ) {
6486 if (hostPtr) {
65- buffer. assign ( reinterpret_cast < char *>(hostPtr),
66- reinterpret_cast < char *>( hostPtr) + size);
87+ auto initialDevice = hContext-> getDevices ()[ 0 ];
88+ UR_CALL_THROWS ( migrateBufferTo (initialDevice, hostPtr, size) );
6789 }
6890}
6991
// Destructor of ur_device_mem_handle_t, shown here as a diff hunk: lines
// marked '-' are removed, lines marked '+' replace them.
// Every non-null entry of deviceAllocations is returned to the context's
// default USM pool; a non-success result from free() is logged and the loop
// continues with the remaining entries.
7092ur_device_mem_handle_t ::~ur_device_mem_handle_t () {
71- // TODO: use UMF API here
7293 for (auto &ptr : deviceAllocations) {
7394 if (ptr) {
74- ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
95+ auto ret = hContext->getDefaultUSMPool ()->free (ptr);
96+ if (ret != UR_RESULT_SUCCESS) {
97+ logger::error (" Failed to free device memory: {}" , ret);
98+ }
7599 }
76100 }
77101}
78102
// Returns the pointer to use for this buffer on hDevice. Shown here as a diff
// hunk: lines marked '-' are the removed implementation, lines marked '+' the
// new one. The whole body runs with this->Mutex held.
// Behaviour of the '+' lines:
//  - no active device yet: allocate getSize() bytes of device USM for hDevice
//    from the context's default USM pool and make hDevice the active device;
//  - hDevice is the active device: return its allocation;
//  - otherwise: return the active device's allocation if that device appears
//    in hContext->getP2PDevices(hDevice), else throw
//    UR_RESULT_ERROR_UNSUPPORTED_FEATURE.
// Allocation failures are raised by UR_CALL_THROWS.
79103void *ur_device_mem_handle_t ::getPtr(ur_device_handle_t hDevice) {
80104 std::lock_guard lock (this ->Mutex );
81105
82- auto &ptr = deviceAllocations[hDevice->Id .value ()];
83- if (!ptr) {
84- ZeStruct<ze_device_mem_alloc_desc_t > deviceDesc;
85- ZE2UR_CALL_THROWS (zeMemAllocDevice, (hContext->getZeHandle (), &deviceDesc,
86- size, 0 , hDevice->ZeDevice , &ptr));
87-
88- if (!buffer.empty ()) {
89- auto commandList = hContext->commandListCache .getImmediateCommandList (
90- hDevice->ZeDevice , true ,
91- hDevice
92- ->QueueGroup
93- [ur_device_handle_t_::queue_group_info_t ::type::Compute]
94- .ZeOrdinal ,
95- ZE_COMMAND_QUEUE_MODE_SYNCHRONOUS, ZE_COMMAND_QUEUE_PRIORITY_NORMAL,
96- std::nullopt );
97- ZE2UR_CALL_THROWS (
98- zeCommandListAppendMemoryCopy,
99- (commandList.get (), ptr, buffer.data (), size, nullptr , 0 , nullptr ));
100- }
// First use with no active device: allocate lazily on the requesting device.
106+ if (!activeAllocationDevice) {
107+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
108+ hContext, hDevice, nullptr , UR_USM_TYPE_DEVICE, getSize (),
109+ &deviceAllocations[hDevice->Id .value ()]));
110+ activeAllocationDevice = hDevice;
101111 }
102- return ptr;
112+
113+ if (activeAllocationDevice == hDevice) {
114+ return deviceAllocations[hDevice->Id .value ()];
115+ }
116+
// Requesting device differs from the active one: look the active device up
// in the P2P device list that the context reports for hDevice.
117+ auto &p2pDevices = hContext->getP2PDevices (hDevice);
118+ auto p2pAccessible = std::find (p2pDevices.begin (), p2pDevices.end (),
119+ activeAllocationDevice) != p2pDevices.end ();
120+
121+ if (!p2pAccessible) {
122+ // TODO: migrate buffer through the host
123+ throw UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
124+ }
125+
126+ // TODO: see if it's better to migrate the memory to the specified device
127+ return deviceAllocations[activeAllocationDevice->Id .value ()];
103128}
104129
105130namespace ur ::level_zero {
@@ -166,6 +191,28 @@ ur_result_t urMemBufferCreateWithNativeHandle(
166191 return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
167192}
168193
194+ ur_result_t urMemGetInfo (ur_mem_handle_t hMemory, ur_mem_info_t propName,
195+ size_t propSize, void *pPropValue,
196+ size_t *pPropSizeRet) {
197+ std::shared_lock<ur_shared_mutex> Lock (hMemory->Mutex );
198+ UrReturnHelper returnValue (propSize, pPropValue, pPropSizeRet);
199+
200+ switch (propName) {
201+ case UR_MEM_INFO_CONTEXT: {
202+ return returnValue (hMemory->getContext ());
203+ }
204+ case UR_MEM_INFO_SIZE: {
205+ // Get size of the allocation
206+ return returnValue (size_t {hMemory->getSize ()});
207+ }
208+ default : {
209+ return UR_RESULT_ERROR_INVALID_ENUMERATION;
210+ }
211+ }
212+
213+ return UR_RESULT_SUCCESS;
214+ }
215+
169216ur_result_t urMemRetain (ur_mem_handle_t hMem) {
170217 hMem->RefCount .increment ();
171218 return UR_RESULT_SUCCESS;
0 commit comments