 #include <ur/ur.hpp>

 #include "common.hpp"
+#include "usm.hpp"
+
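+// Registered with clSetEventCallback below: releases the completed event and
+// deletes the AllocDeleterCallbackInfo passed as user data, which cleans up
+// the USM allocation it owns (see usm.hpp).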
+void AllocDeleterCallback(cl_event event, cl_int, void *pUserData) {
+  clReleaseEvent(event);
+  auto Info = static_cast<AllocDeleterCallbackInfo *>(pUserData);
+  delete Info;
+}

 namespace umf {
 ur_result_t getProviderNativeError(const char *, int32_t) {
@@ -312,32 +319,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
       numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
       &CopyEvent));

-  struct DeleteCallbackInfo {
-    DeleteCallbackInfo(clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
-                       void *HostBuffer)
-        : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
-      clRetainContext(CLContext);
-    }
-    ~DeleteCallbackInfo() {
-      USMFree(CLContext, HostBuffer);
-      clReleaseContext(CLContext);
-    }
-    DeleteCallbackInfo(const DeleteCallbackInfo &) = delete;
-    DeleteCallbackInfo &operator=(const DeleteCallbackInfo &) = delete;
-
-    clMemBlockingFreeINTEL_fn USMFree;
-    cl_context CLContext;
-    void *HostBuffer;
-  };
-
-  auto Info = new DeleteCallbackInfo(USMFree, CLContext, HostBuffer);
+  // This self-destructs, taking the event and allocation with it.
+  auto Info = new AllocDeleterCallbackInfo(USMFree, CLContext, HostBuffer);

-  auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
-    auto Info = static_cast<DeleteCallbackInfo *>(pUserData);
-    delete Info;
-  };
-
-  ClErr = clSetEventCallback(CopyEvent, CL_COMPLETE, DeleteCallback, Info);
+  ClErr =
+      clSetEventCallback(CopyEvent, CL_COMPLETE, AllocDeleterCallback, Info);
   if (ClErr != CL_SUCCESS) {
     // We can attempt to recover gracefully by waiting for the copy to finish
     // and deleting the info struct here.
@@ -347,9 +333,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
     CL_RETURN_ON_FAILURE(ClErr);
   }
   if (phEvent) {
+    // Since we're releasing this event in the callback above, we need to
+    // retain it here to keep the user's copy alive.
+    CL_RETURN_ON_FAILURE(clRetainEvent(CopyEvent));
     *phEvent = cl_adapter::cast<ur_event_handle_t>(CopyEvent);
-  } else {
-    CL_RETURN_ON_FAILURE(clReleaseEvent(CopyEvent));
   }

   return UR_RESULT_SUCCESS;
@@ -369,20 +356,110 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
     return mapCLErrorToUR(CLErr);
   }

-  clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
-  ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
+  clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
+      cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
+
+  clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
       CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
-      cl_ext::EnqueueMemcpyName, &FuncPtr);
+      cl_ext::EnqueueMemcpyName, &USMMemcpy));

-  if (FuncPtr) {
-    RetVal = mapCLErrorToUR(
-        FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
-                pSrc, size, numEventsInWaitList,
-                cl_adapter::cast<const cl_event *>(phEventWaitList),
-                cl_adapter::cast<cl_event *>(phEvent)));
+  clMemBlockingFreeINTEL_fn USMFree = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache,
+      cl_ext::MemBlockingFreeName, &USMFree));
+
+  // Check whether the two allocations are DEVICE allocations on different
+  // devices; if they are, we need to do the copy indirectly via a host
+  // allocation.
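+  // CL_MEM_ALLOC_DEVICE_INTEL reports the device an allocation is associated
+  // with, or NULL if it has none (e.g. host allocations or unknown pointers).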
+  cl_device_id SrcDevice = 0, DstDevice = 0;
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &SrcDevice, nullptr));
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &DstDevice, nullptr));
+
+  if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
+    cl_event HostCopyEvent = nullptr, FinalCopyEvent = nullptr;
+    clHostMemAllocINTEL_fn HostMemAlloc = nullptr;
+    UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
+        CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
+        cl_ext::HostMemAllocName, &HostMemAlloc));
+
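+    // Stage the copy through a host USM allocation, which is accessible to
+    // every device in the context.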
+    auto HostAlloc = HostMemAlloc(CLContext, nullptr, size, 0, &CLErr);
+    CL_RETURN_ON_FAILURE(CLErr);
+
+    // Now that we've successfully allocated, we should try to clean it up if
+    // we hit an error anywhere below.
+    auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
+      if (CLErr != CL_SUCCESS) {
+        if (HostCopyEvent) {
+          clReleaseEvent(HostCopyEvent);
+        }
+        if (FinalCopyEvent) {
+          clReleaseEvent(FinalCopyEvent);
+        }
+        USMFree(CLContext, HostAlloc);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+      return UR_RESULT_SUCCESS;
+    };
+
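+    // Copy the source into the host staging buffer first, then from the
+    // staging buffer into the destination; the second copy waits on
+    // HostCopyEvent so it is ordered after the first.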
+    UR_RETURN_ON_FAILURE(checkCLErr(USMMemcpy(
+        cl_adapter::cast<cl_command_queue>(hQueue), blocking, HostAlloc, pSrc,
+        size, numEventsInWaitList,
+        cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
+
+    UR_RETURN_ON_FAILURE(checkCLErr(
+        USMMemcpy(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
+                  HostAlloc, size, 1, &HostCopyEvent, &FinalCopyEvent)));
+
+    // If this is a blocking operation we can do our cleanup immediately;
+    // otherwise we need to defer it to an event callback.
+    if (blocking) {
+      CL_RETURN_ON_FAILURE(USMFree(CLContext, HostAlloc));
+      CL_RETURN_ON_FAILURE(clReleaseEvent(HostCopyEvent));
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+      } else {
+        CL_RETURN_ON_FAILURE(clReleaseEvent(FinalCopyEvent));
+      }
+    } else {
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+        // We are going to release this event in our callback, so we need to
+        // retain it if the user wants a copy.
+        CL_RETURN_ON_FAILURE(clRetainEvent(FinalCopyEvent));
+      }
+
+      // This self-destructs, taking the event and allocation with it.
+      auto DeleterInfo =
+          new AllocDeleterCallbackInfo{USMFree, CLContext, HostAlloc};
+
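+      // Attaching the deleter to HostCopyEvent is safe even though the second
+      // copy still reads HostAlloc: USMFree is clMemBlockingFreeINTEL, which
+      // waits for any commands still using the allocation before freeing it.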
+      CLErr = clSetEventCallback(HostCopyEvent, CL_COMPLETE,
+                                 AllocDeleterCallback, DeleterInfo);
+
+      if (CLErr != CL_SUCCESS) {
+        // We can attempt to recover gracefully by waiting for the copy to
+        // finish and deleting the info struct here.
+        clWaitForEvents(1, &HostCopyEvent);
+        delete DeleterInfo;
+        clReleaseEvent(HostCopyEvent);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+    }
+  } else {
+    CL_RETURN_ON_FAILURE(
+        USMMemcpy(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
+                  pSrc, size, numEventsInWaitList,
+                  cl_adapter::cast<const cl_event *>(phEventWaitList),
+                  cl_adapter::cast<cl_event *>(phEvent)));
   }

-  return RetVal;
+  return UR_RESULT_SUCCESS;
 }

 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(