@@ -209,7 +209,7 @@ ur_discrete_mem_handle_t::ur_discrete_mem_handle_t(
     device_access_mode_t accessMode)
     : ur_mem_handle_t_(hContext, size, accessMode),
       deviceAllocations(hContext->getPlatform()->getNumDevices()),
-      activeAllocationDevice(nullptr), hostAllocations() {
+      activeAllocationDevice(nullptr), mapToPtr(hostPtr), hostAllocations() {
   if (hostPtr) {
     auto initialDevice = hContext->getDevices()[0];
     UR_CALL_THROWS(migrateBufferTo(initialDevice, hostPtr, size));
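
Note on the hunk above: the constructor now caches the caller-supplied host pointer in the new mapToPtr member (its declaration lives outside this diff), so mapHostPtr further down can hand that same pointer back instead of allocating a separate host staging buffer. A minimal, hypothetical sketch of that reuse pattern, with illustrative names only:

```cpp
#include <cstddef>
#include <cstdlib>

// Hypothetical helper, for illustration only: reuse the cached host pointer
// when the buffer was created from one, otherwise fall back to allocating a
// staging buffer (std::malloc stands in for the adapter's USM host pool).
void *mapHostRegion(void *mapToPtr, std::size_t size) {
  return mapToPtr ? mapToPtr : std::malloc(size);
}
```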
@@ -246,12 +246,18 @@ ur_discrete_mem_handle_t::~ur_discrete_mem_handle_t() {
   if (!activeAllocationDevice || !writeBackPtr)
     return;
 
-  auto srcPtr = ur_cast<char *>(
-      deviceAllocations[activeAllocationDevice->Id.value()].get());
+  auto srcPtr = getActiveDeviceAlloc();
   synchronousZeCopy(hContext, activeAllocationDevice, writeBackPtr, srcPtr,
                     getSize());
 }
 
+void *ur_discrete_mem_handle_t::getActiveDeviceAlloc(size_t offset) {
+  assert(activeAllocationDevice);
+  return ur_cast<char *>(
+             deviceAllocations[activeAllocationDevice->Id.value()].get()) +
+         offset;
+}
+
 void *ur_discrete_mem_handle_t::getDevicePtr(
     ur_device_handle_t hDevice, device_access_mode_t access, size_t offset,
     size_t size, std::function<void(void *src, void *dst, size_t)> migrate) {
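
The getActiveDeviceAlloc helper added above centralizes the repeated "base of the active device allocation plus byte offset" arithmetic behind a single assert-guarded call. A standalone sketch of the same pattern, using only standard containers and illustrative names rather than the adapter's types:

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

struct discrete_buffer_sketch {
  std::vector<std::unique_ptr<char[]>> deviceAllocations; // one slot per device
  int activeDeviceId = -1; // -1: no device allocation is active yet

  // Analogue of getActiveDeviceAlloc: base pointer of the active device
  // allocation plus a byte offset; callers must ensure one is active.
  void *getActiveDeviceAlloc(std::size_t offset = 0) {
    assert(activeDeviceId >= 0 && "no active device allocation");
    return deviceAllocations[activeDeviceId].get() + offset;
  }
};
```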
@@ -272,10 +278,8 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
     hDevice = activeAllocationDevice;
   }
 
-  char *ptr;
   if (activeAllocationDevice == hDevice) {
-    ptr = ur_cast<char *>(deviceAllocations[hDevice->Id.value()].get());
-    return ptr + offset;
+    return getActiveDeviceAlloc(offset);
   }
 
   auto &p2pDevices = hContext->getP2PDevices(hDevice);
@@ -288,9 +292,7 @@ void *ur_discrete_mem_handle_t::getDevicePtr(
   }
 
   // TODO: see if it's better to migrate the memory to the specified device
-  return ur_cast<char *>(
-             deviceAllocations[activeAllocationDevice->Id.value()].get()) +
-         offset;
+  return getActiveDeviceAlloc(offset);
 }
 
 void *ur_discrete_mem_handle_t::mapHostPtr(
@@ -299,55 +301,60 @@ void *ur_discrete_mem_handle_t::mapHostPtr(
   TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::mapHostPtr");
   // TODO: use async alloc?
 
-  void *ptr;
-  UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
-      hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr));
+  void *ptr = mapToPtr;
+  if (!ptr) {
+    UR_CALL_THROWS(hContext->getDefaultUSMPool()->allocate(
+        hContext, nullptr, nullptr, UR_USM_TYPE_HOST, size, &ptr));
+  }
 
-  hostAllocations.emplace_back(ptr, size, offset, flags);
+  usm_unique_ptr_t mappedPtr =
+      usm_unique_ptr_t(ptr, [ownsAlloc = bool(mapToPtr), this](void *p) {
+        if (ownsAlloc) {
+          UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(p));
+        }
+      });
+
+  hostAllocations.emplace_back(std::move(mappedPtr), size, offset, flags);
 
   if (activeAllocationDevice && (flags & UR_MAP_FLAG_READ)) {
-    auto srcPtr =
-        ur_cast<char *>(
-            deviceAllocations[activeAllocationDevice->Id.value()].get()) +
-        offset;
-    migrate(srcPtr, hostAllocations.back().ptr, size);
+    auto srcPtr = getActiveDeviceAlloc(offset);
+    migrate(srcPtr, hostAllocations.back().ptr.get(), size);
   }
 
-  return hostAllocations.back().ptr;
+  return hostAllocations.back().ptr.get();
 }
 
 void ur_discrete_mem_handle_t::unmapHostPtr(
     void *pMappedPtr,
     std::function<void(void *src, void *dst, size_t)> migrate) {
   TRACK_SCOPE_LATENCY("ur_discrete_mem_handle_t::unmapHostPtr");
 
-  for (auto &hostAllocation : hostAllocations) {
-    if (hostAllocation.ptr == pMappedPtr) {
-      void *devicePtr = nullptr;
-      if (activeAllocationDevice) {
-        devicePtr =
-            ur_cast<char *>(
-                deviceAllocations[activeAllocationDevice->Id.value()].get()) +
-            hostAllocation.offset;
-      } else if (!(hostAllocation.flags &
-                   UR_MAP_FLAG_WRITE_INVALIDATE_REGION)) {
-        devicePtr = ur_cast<char *>(getDevicePtr(
-            hContext->getDevices()[0], device_access_mode_t::read_only,
-            hostAllocation.offset, hostAllocation.size, migrate));
-      }
+  auto hostAlloc =
+      std::find_if(hostAllocations.begin(), hostAllocations.end(),
+                   [pMappedPtr](const host_allocation_desc_t &desc) {
+                     return desc.ptr.get() == pMappedPtr;
+                   });
 
-      if (devicePtr) {
-        migrate(hostAllocation.ptr, devicePtr, hostAllocation.size);
-      }
+  if (hostAlloc == hostAllocations.end()) {
+    throw UR_RESULT_ERROR_INVALID_ARGUMENT;
+  }
 
-      // TODO: use async free here?
-      UR_CALL_THROWS(hContext->getDefaultUSMPool()->free(hostAllocation.ptr));
-      return;
-    }
+  bool shouldMigrateToDevice =
+      !(hostAlloc->flags & UR_MAP_FLAG_WRITE_INVALIDATE_REGION);
+
+  if (!activeAllocationDevice && shouldMigrateToDevice) {
+    allocateOnDevice(hContext->getDevices()[0], getSize());
+  }
+
+  // TODO: tests require that memory is migrated even for
+  // UR_MAP_FLAG_WRITE_INVALIDATE_REGION when there is an active device
+  // allocation. is this correct?
+  if (activeAllocationDevice) {
+    migrate(hostAlloc->ptr.get(), getActiveDeviceAlloc(hostAlloc->offset),
+            hostAlloc->size);
   }
 
-  // No mapping found
-  throw UR_RESULT_ERROR_INVALID_ARGUMENT;
+  hostAllocations.erase(hostAlloc);
 }
 
 static bool useHostBuffer(ur_context_handle_t hContext) {
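
Taken together, the mapHostPtr/unmapHostPtr changes above wrap each mapping's host pointer in a usm_unique_ptr_t whose deleter releases memory only when the mapping owns it, and unmap now erases the matching host_allocation_desc_t entry instead of freeing a raw pointer. A minimal standalone sketch of that ownership pattern, assuming nothing beyond the standard library (std::malloc/std::free stand in for the USM host pool; names are illustrative, not the adapter's API):

```cpp
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>

using host_unique_ptr = std::unique_ptr<void, std::function<void(void *)>>;

// Wrap a mapping's host pointer so that destruction frees it only when this
// mapping allocated it itself rather than borrowing a caller-provided pointer.
host_unique_ptr makeMappedPtr(void *callerHostPtr, std::size_t size) {
  void *ptr = callerHostPtr;
  const bool ownsAlloc = (ptr == nullptr);
  if (ownsAlloc)
    ptr = std::malloc(size); // stand-in for the USM host-pool allocation
  return host_unique_ptr(ptr, [ownsAlloc](void *p) {
    if (ownsAlloc)
      std::free(p); // release only memory this mapping allocated
  });
}
```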