@@ -115,10 +115,7 @@ ur_integrated_mem_handle_t::ur_integrated_mem_handle_t(
115115 if (!ownHostPtr) {
116116 return ;
117117 }
118- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
119- if (ret != UR_RESULT_SUCCESS) {
120- logger::error (" Failed to free host memory: {}" , ret);
121- }
118+ ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
122119 });
123120}
124121
@@ -209,7 +206,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
209206 device_access_mode_t accessMode)
210207 : ur_mem_handle_t_(hContext, size, accessMode),
211208 deviceAllocations (hContext->getPlatform ()->getNumDevices()),
212- activeAllocationDevice(nullptr ), hostAllocations() {
209+ activeAllocationDevice(nullptr ), mapToPtr(hostPtr), hostAllocations() {
213210 if (hostPtr) {
214211 auto initialDevice = hContext->getDevices ()[0 ];
215212 UR_CALL_THROWS (migrateBufferTo (initialDevice, hostPtr, size));
@@ -234,10 +231,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
234231 if (!ownZePtr) {
235232 return ;
236233 }
237- auto ret = hContext->getDefaultUSMPool ()->free (ptr);
238- if (ret != UR_RESULT_SUCCESS) {
239- logger::error (" Failed to free device memory: {}" , ret);
240- }
234+ ZE_CALL_NOCHECK (zeMemFree, (hContext->getZeHandle (), ptr));
241235 });
242236 }
243237}
@@ -246,12 +240,18 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
246240 if (!activeAllocationDevice || !writeBackPtr)
247241 return ;
248242
249- auto srcPtr = ur_cast<char *>(
250- deviceAllocations[activeAllocationDevice->Id .value ()].get ());
243+ auto srcPtr = getActiveDeviceAlloc ();
251244 synchronousZeCopy (hContext, activeAllocationDevice, writeBackPtr, srcPtr,
252245 getSize ());
253246}
254247
248+ void *ur_discrete_mem_handle_t ::getActiveDeviceAlloc(size_t offset) {
249+ assert (activeAllocationDevice);
250+ return ur_cast<char *>(
251+ deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
252+ offset;
253+ }
254+
255255void *ur_discrete_mem_handle_t ::getDevicePtr(
256256 ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
257257 size_t size, std::function<void (void *src, void *dst, size_t )> migrate) {
@@ -272,10 +272,8 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
272272 hDevice = activeAllocationDevice;
273273 }
274274
275- char *ptr;
276275 if (activeAllocationDevice == hDevice) {
277- ptr = ur_cast<char *>(deviceAllocations[hDevice->Id .value ()].get ());
278- return ptr + offset;
276+ return getActiveDeviceAlloc (offset);
279277 }
280278
281279 auto &p2pDevices = hContext->getP2PDevices (hDevice);
@@ -288,9 +286,7 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
288286 }
289287
290288 // TODO: see if it's better to migrate the memory to the specified device
291- return ur_cast<char *>(
292- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
293- offset;
289+ return getActiveDeviceAlloc (offset);
294290}
295291
296292void *ur_discrete_mem_handle_t ::mapHostPtr(
@@ -299,55 +295,63 @@ void *ur_discrete_mem_handle_t::mapHostPtr(
299295 TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::mapHostPtr" );
300296 // TODO: use async alloc?
301297
302- void *ptr;
303- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
304- hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
298+ void *ptr = mapToPtr;
299+ if (!ptr) {
300+ UR_CALL_THROWS (hContext->getDefaultUSMPool ()->allocate (
301+ hContext, nullptr , nullptr , UR_USM_TYPE_HOST, size, &ptr));
302+ }
305303
306- hostAllocations.emplace_back (ptr, size, offset, flags);
304+ usm_unique_ptr_t mappedPtr =
305+ usm_unique_ptr_t (ptr, [ownsAlloc = bool (mapToPtr), this ](void *p) {
306+ if (ownsAlloc) {
307+ auto ret = hContext->getDefaultUSMPool ()->free (p);
308+ if (ret != UR_RESULT_SUCCESS) {
309+ logger::error (" Failed to mapped memory: {}" , ret);
310+ }
311+ }
312+ });
313+
314+ hostAllocations.emplace_back (std::move (mappedPtr), size, offset, flags);
307315
308316 if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
309- auto srcPtr =
310- ur_cast<char *>(
311- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
312- offset;
313- migrate (srcPtr, hostAllocations.back ().ptr , size);
317+ auto srcPtr = getActiveDeviceAlloc (offset);
318+ migrate (srcPtr, hostAllocations.back ().ptr .get (), size);
314319 }
315320
316- return hostAllocations.back ().ptr ;
321+ return hostAllocations.back ().ptr . get () ;
317322}
318323
319324void ur_discrete_mem_handle_t::unmapHostPtr (
320325 void *pMappedPtr,
321326 std::function<void (void *src, void *dst, size_t )> migrate) {
322327 TRACK_SCOPE_LATENCY (" ur_discrete_mem_handle_t::unmapHostPtr" );
323328
324- for (auto &hostAllocation : hostAllocations) {
325- if (hostAllocation.ptr == pMappedPtr) {
326- void *devicePtr = nullptr ;
327- if (activeAllocationDevice) {
328- devicePtr =
329- ur_cast<char *>(
330- deviceAllocations[activeAllocationDevice->Id .value ()].get ()) +
331- hostAllocation.offset ;
332- } else if (!(hostAllocation.flags &
333- UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
334- devicePtr = ur_cast<char *>(getDevicePtr (
335- hContext->getDevices ()[0 ], device_access_mode_t ::read_only,
336- hostAllocation.offset , hostAllocation.size , migrate));
337- }
329+ auto hostAlloc =
330+ std::find_if (hostAllocations.begin (), hostAllocations.end (),
331+ [pMappedPtr](const host_allocation_desc_t &desc) {
332+ return desc.ptr .get () == pMappedPtr;
333+ });
338334
339- if (devicePtr ) {
340- migrate (hostAllocation. ptr , devicePtr, hostAllocation. size ) ;
341- }
335+ if (hostAlloc == hostAllocations. end () ) {
336+ throw UR_RESULT_ERROR_INVALID_ARGUMENT ;
337+ }
342338
343- // TODO: use async free here?
344- UR_CALL_THROWS (hContext->getDefaultUSMPool ()->free (hostAllocation.ptr ));
345- return ;
346- }
339+ bool shouldMigrateToDevice =
340+ !(hostAlloc->flags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION);
341+
342+ if (!activeAllocationDevice && shouldMigrateToDevice) {
343+ allocateOnDevice (hContext->getDevices ()[0 ], getSize ());
344+ }
345+
346+ // TODO: tests require that memory is migrated even for
347+ // UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
348+ // allocation. is this correct?
349+ if (activeAllocationDevice) {
350+ migrate (hostAlloc->ptr .get (), getActiveDeviceAlloc (hostAlloc->offset ),
351+ hostAlloc->size );
347352 }
348353
349- // No mapping found
350- throw UR_RESULT_ERROR_INVALID_ARGUMENT;
354+ hostAllocations.erase (hostAlloc);
351355}
352356
353357static bool useHostBuffer (ur_context_handle_t hContext) {
@@ -419,8 +423,6 @@ ur_result_t urMemBufferCreate(ur_context_handle_t hContext,
419423 auto accessMode = getDeviceAccessMode (flags);
420424
421425 if (useHostBuffer (hContext)) {
422- // TODO: assert that if hostPtr is set, either UR_MEM_FLAG_USE_HOST_POINTER
423- // or UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER is set?
424426 auto hostPtrAction =
425427 flags & UR_MEM_FLAG_USE_HOST_POINTER
426428 ? ur_integrated_mem_handle_t ::host_ptr_action_t ::import
0 commit comments