1111#include < ur/ur.hpp>
1212
1313#include " common.hpp"
14+ #include " usm.hpp"
15+
16+ void AllocDeleterCallback (cl_event event, cl_int, void *pUserData) {
17+ clReleaseEvent (event);
18+ auto Info = static_cast <AllocDeleterCallbackInfo *>(pUserData);
19+ delete Info;
20+ }
1421
1522inline cl_mem_alloc_flags_intel
1623hostDescToClFlags (const ur_usm_host_desc_t &desc) {
@@ -305,32 +312,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
305312 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
306313 &CopyEvent));
307314
308- struct DeleteCallbackInfo {
309- DeleteCallbackInfo (clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
310- void *HostBuffer)
311- : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
312- clRetainContext (CLContext);
313- }
314- ~DeleteCallbackInfo () {
315- USMFree (CLContext, HostBuffer);
316- clReleaseContext (CLContext);
317- }
318- DeleteCallbackInfo (const DeleteCallbackInfo &) = delete ;
319- DeleteCallbackInfo &operator =(const DeleteCallbackInfo &) = delete ;
320-
321- clMemBlockingFreeINTEL_fn USMFree;
322- cl_context CLContext;
323- void *HostBuffer;
324- };
325-
326- auto Info = new DeleteCallbackInfo (USMFree, CLContext, HostBuffer);
315+ // This self destructs taking the event and allocation with it.
316+ auto Info = new AllocDeleterCallbackInfo (USMFree, CLContext, HostBuffer);
327317
328- auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
329- auto Info = static_cast <DeleteCallbackInfo *>(pUserData);
330- delete Info;
331- };
332-
333- ClErr = clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info);
318+ ClErr =
319+ clSetEventCallback (CopyEvent, CL_COMPLETE, AllocDeleterCallback, Info);
334320 if (ClErr != CL_SUCCESS) {
335321 // We can attempt to recover gracefully by attempting to wait for the copy
336322 // to finish and deleting the info struct here.
@@ -340,9 +326,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
340326 CL_RETURN_ON_FAILURE (ClErr);
341327 }
342328 if (phEvent) {
329+ // Since we're releasing this in the callback above we need to retain it
330+ // here to keep the user copy alive.
331+ CL_RETURN_ON_FAILURE (clRetainEvent (CopyEvent));
343332 *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
344- } else {
345- CL_RETURN_ON_FAILURE (clReleaseEvent (CopyEvent));
346333 }
347334
348335 return UR_RESULT_SUCCESS;
@@ -362,20 +349,110 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
362349 return mapCLErrorToUR (CLErr);
363350 }
364351
365- clEnqueueMemcpyINTEL_fn FuncPtr = nullptr ;
366- ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
352+ clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr ;
353+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
354+ CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache ,
355+ cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
356+
357+ clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr ;
358+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
367359 CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache ,
368- cl_ext::EnqueueMemcpyName, &FuncPtr );
360+ cl_ext::EnqueueMemcpyName, &USMMemcpy) );
369361
370- if (FuncPtr) {
371- RetVal = mapCLErrorToUR (
372- FuncPtr (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
373- pSrc, size, numEventsInWaitList,
374- cl_adapter::cast<const cl_event *>(phEventWaitList),
375- cl_adapter::cast<cl_event *>(phEvent)));
362+ clMemBlockingFreeINTEL_fn USMFree = nullptr ;
363+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
364+ CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache ,
365+ cl_ext::MemBlockingFreeName, &USMFree));
366+
367+ // Check if the two allocations are DEVICE allocations from different
368+ // devices, if they are we need to do the copy indirectly via a host
369+ // allocation.
370+ cl_device_id SrcDevice = 0 , DstDevice = 0 ;
371+ CL_RETURN_ON_FAILURE (
372+ GetMemAllocInfo (CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
373+ sizeof (cl_device_id), &SrcDevice, nullptr ));
374+ CL_RETURN_ON_FAILURE (
375+ GetMemAllocInfo (CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
376+ sizeof (cl_device_id), &SrcDevice, nullptr ));
377+
378+ if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
379+ cl_event HostCopyEvent = nullptr , FinalCopyEvent = nullptr ;
380+ clHostMemAllocINTEL_fn HostMemAlloc = nullptr ;
381+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
382+ CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache ,
383+ cl_ext::HostMemAllocName, &HostMemAlloc));
384+
385+ auto HostAlloc = HostMemAlloc (CLContext, nullptr , size, 0 , &CLErr);
386+ CL_RETURN_ON_FAILURE (CLErr);
387+
388+ // Now that we've successfully allocated we should try to clean it up if we
389+ // hit an error somewhere.
390+ auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
391+ if (CLErr != CL_SUCCESS) {
392+ if (HostCopyEvent) {
393+ clReleaseEvent (HostCopyEvent);
394+ }
395+ if (FinalCopyEvent) {
396+ clReleaseEvent (FinalCopyEvent);
397+ }
398+ USMFree (CLContext, HostAlloc);
399+ CL_RETURN_ON_FAILURE (CLErr);
400+ }
401+ return UR_RESULT_SUCCESS;
402+ };
403+
404+ UR_RETURN_ON_FAILURE (checkCLErr (USMMemcpy (
405+ cl_adapter::cast<cl_command_queue>(hQueue), blocking, HostAlloc, pSrc,
406+ size, numEventsInWaitList,
407+ cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
408+
409+ UR_RETURN_ON_FAILURE (checkCLErr (
410+ USMMemcpy (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
411+ HostAlloc, size, 1 , &HostCopyEvent, &FinalCopyEvent)));
412+
413+ // If this is a blocking operation we can do our cleanup immediately,
414+ // otherwise we need to defer it to an event callback.
415+ if (blocking) {
416+ CL_RETURN_ON_FAILURE (USMFree (CLContext, HostAlloc));
417+ CL_RETURN_ON_FAILURE (clReleaseEvent (HostCopyEvent));
418+ if (phEvent) {
419+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
420+ } else {
421+ CL_RETURN_ON_FAILURE (clReleaseEvent (FinalCopyEvent));
422+ }
423+ } else {
424+ if (phEvent) {
425+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
426+ // We are going to release this event in our callback so we need to
427+ // retain if the user wants a copy.
428+ CL_RETURN_ON_FAILURE (clRetainEvent (FinalCopyEvent));
429+ }
430+
431+ // This self destructs taking the event and allocation with it.
432+ auto DeleterInfo =
433+ new AllocDeleterCallbackInfo{USMFree, CLContext, HostAlloc};
434+
435+ CLErr = clSetEventCallback (HostCopyEvent, CL_COMPLETE,
436+ AllocDeleterCallback, DeleterInfo);
437+
438+ if (CLErr != CL_SUCCESS) {
439+ // We can attempt to recover gracefully by attempting to wait for the
440+ // copy to finish and deleting the info struct here.
441+ clWaitForEvents (1 , &HostCopyEvent);
442+ delete DeleterInfo;
443+ clReleaseEvent (HostCopyEvent);
444+ CL_RETURN_ON_FAILURE (CLErr);
445+ }
446+ }
447+ } else {
448+ CL_RETURN_ON_FAILURE (
449+ USMMemcpy (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
450+ pSrc, size, numEventsInWaitList,
451+ cl_adapter::cast<const cl_event *>(phEventWaitList),
452+ cl_adapter::cast<cl_event *>(phEvent)));
376453 }
377454
378- return RetVal ;
455+ return UR_RESULT_SUCCESS ;
379456}
380457
381458UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch (
0 commit comments