 #include <ur/ur.hpp>

 #include "common.hpp"
+#include "usm.hpp"
+
+void AllocDeleterCallback(cl_event event, cl_int, void *pUserData) {
+  clReleaseEvent(event);
+  auto Info = static_cast<AllocDeleterCallbackInfo *>(pUserData);
+  delete Info;
+}

 inline cl_mem_alloc_flags_intel
 hostDescToClFlags(const ur_usm_host_desc_t &desc) {
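Note: the callback added above relies on AllocDeleterCallbackInfo, which this change pulls in through the new "usm.hpp" include rather than defining inline. That helper is not shown in this diff; assuming it mirrors the DeleteCallbackInfo struct removed from urEnqueueUSMFill below (same members, RAII free of the allocation and the retained context), a minimal sketch would be:

struct AllocDeleterCallbackInfo {
  AllocDeleterCallbackInfo(clMemBlockingFreeINTEL_fn USMFree,
                           cl_context CLContext, void *Allocation)
      : USMFree(USMFree), CLContext(CLContext), Allocation(Allocation) {
    // Keep the context alive while the deferred free is pending.
    clRetainContext(CLContext);
  }
  ~AllocDeleterCallbackInfo() {
    // Blocking free of the USM allocation, then drop the context reference.
    USMFree(CLContext, Allocation);
    clReleaseContext(CLContext);
  }
  AllocDeleterCallbackInfo(const AllocDeleterCallbackInfo &) = delete;
  AllocDeleterCallbackInfo &
  operator=(const AllocDeleterCallbackInfo &) = delete;

  clMemBlockingFreeINTEL_fn USMFree;
  cl_context CLContext;
  void *Allocation; // placeholder name; the diff passes HostBuffer/HostAlloc here
};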
@@ -290,32 +297,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
       numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
       &CopyEvent));

-  struct DeleteCallbackInfo {
-    DeleteCallbackInfo(clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
-                       void *HostBuffer)
-        : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
-      clRetainContext(CLContext);
-    }
-    ~DeleteCallbackInfo() {
-      USMFree(CLContext, HostBuffer);
-      clReleaseContext(CLContext);
-    }
-    DeleteCallbackInfo(const DeleteCallbackInfo &) = delete;
-    DeleteCallbackInfo &operator=(const DeleteCallbackInfo &) = delete;
-
-    clMemBlockingFreeINTEL_fn USMFree;
-    cl_context CLContext;
-    void *HostBuffer;
-  };
-
-  auto Info = new DeleteCallbackInfo(USMFree, CLContext, HostBuffer);
+  // This self destructs taking the event and allocation with it.
+  auto Info = new AllocDeleterCallbackInfo(USMFree, CLContext, HostBuffer);

-  auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
-    auto Info = static_cast<DeleteCallbackInfo *>(pUserData);
-    delete Info;
-  };
-
-  ClErr = clSetEventCallback(CopyEvent, CL_COMPLETE, DeleteCallback, Info);
+  ClErr =
+      clSetEventCallback(CopyEvent, CL_COMPLETE, AllocDeleterCallback, Info);
   if (ClErr != CL_SUCCESS) {
     // We can attempt to recover gracefully by attempting to wait for the copy
     // to finish and deleting the info struct here.
@@ -325,9 +311,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
     CL_RETURN_ON_FAILURE(ClErr);
   }
   if (phEvent) {
+    // Since we're releasing this in the callback above we need to retain it
+    // here to keep the user copy alive.
+    CL_RETURN_ON_FAILURE(clRetainEvent(CopyEvent));
     *phEvent = cl_adapter::cast<ur_event_handle_t>(CopyEvent);
-  } else {
-    CL_RETURN_ON_FAILURE(clReleaseEvent(CopyEvent));
   }

   return UR_RESULT_SUCCESS;
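A caller-side sketch of the resulting event ownership may help review. This is illustrative only: Queue, Ptr, and Size are hypothetical, and the urEnqueueUSMFill parameter order should be checked against ur_api.h rather than taken from here.

uint32_t Pattern = 0;
ur_event_handle_t FillEvent = nullptr;
// The adapter attaches AllocDeleterCallback to the underlying cl_event and
// retains it once more before returning, so FillEvent stays valid even after
// the callback has dropped the adapter's own reference on completion.
urEnqueueUSMFill(Queue, Ptr, sizeof(Pattern), &Pattern, Size,
                 0, nullptr, &FillEvent);
// The caller owns the returned reference and releases it when done.
urEventRelease(FillEvent);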
@@ -347,20 +334,110 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
     return mapCLErrorToUR(CLErr);
   }

-  clEnqueueMemcpyINTEL_fn FuncPtr = nullptr;
-  ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
+  clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache,
+      cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
+
+  clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
       CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache,
-      cl_ext::EnqueueMemcpyName, &FuncPtr);
+      cl_ext::EnqueueMemcpyName, &USMMemcpy));

-  if (FuncPtr) {
-    RetVal = mapCLErrorToUR(
-        FuncPtr(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
-                pSrc, size, numEventsInWaitList,
-                cl_adapter::cast<const cl_event *>(phEventWaitList),
-                cl_adapter::cast<cl_event *>(phEvent)));
+  clMemBlockingFreeINTEL_fn USMFree = nullptr;
+  UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
+      CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache,
+      cl_ext::MemBlockingFreeName, &USMFree));
+
+  // Check if the two allocations are DEVICE allocations from different
+  // devices; if they are, we need to do the copy indirectly via a host
+  // allocation.
+  cl_device_id SrcDevice = 0, DstDevice = 0;
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &SrcDevice, nullptr));
+  CL_RETURN_ON_FAILURE(
+      GetMemAllocInfo(CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
+                      sizeof(cl_device_id), &DstDevice, nullptr));
+
+  if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
+    cl_event HostCopyEvent = nullptr, FinalCopyEvent = nullptr;
+    clHostMemAllocINTEL_fn HostMemAlloc = nullptr;
+    UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
+        CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache,
+        cl_ext::HostMemAllocName, &HostMemAlloc));
+
+    auto HostAlloc = HostMemAlloc(CLContext, nullptr, size, 0, &CLErr);
+    CL_RETURN_ON_FAILURE(CLErr);
+
+    // Now that we've successfully allocated we should try to clean it up if we
+    // hit an error somewhere.
+    auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
+      if (CLErr != CL_SUCCESS) {
+        if (HostCopyEvent) {
+          clReleaseEvent(HostCopyEvent);
+        }
+        if (FinalCopyEvent) {
+          clReleaseEvent(FinalCopyEvent);
+        }
+        USMFree(CLContext, HostAlloc);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+      return UR_RESULT_SUCCESS;
+    };
+
+    UR_RETURN_ON_FAILURE(checkCLErr(USMMemcpy(
+        cl_adapter::cast<cl_command_queue>(hQueue), blocking, HostAlloc, pSrc,
+        size, numEventsInWaitList,
+        cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
+
+    UR_RETURN_ON_FAILURE(checkCLErr(
+        USMMemcpy(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
+                  HostAlloc, size, 1, &HostCopyEvent, &FinalCopyEvent)));
+
+    // If this is a blocking operation we can do our cleanup immediately;
+    // otherwise we need to defer it to an event callback.
+    if (blocking) {
+      CL_RETURN_ON_FAILURE(USMFree(CLContext, HostAlloc));
+      CL_RETURN_ON_FAILURE(clReleaseEvent(HostCopyEvent));
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+      } else {
+        CL_RETURN_ON_FAILURE(clReleaseEvent(FinalCopyEvent));
+      }
+    } else {
+      if (phEvent) {
+        *phEvent = cl_adapter::cast<ur_event_handle_t>(FinalCopyEvent);
+        // We are going to release this event in our callback so we need to
+        // retain it if the user wants a copy.
+        CL_RETURN_ON_FAILURE(clRetainEvent(FinalCopyEvent));
+      }
+
+      // This self destructs taking the event and allocation with it.
+      auto DeleterInfo =
+          new AllocDeleterCallbackInfo{USMFree, CLContext, HostAlloc};
+
+      CLErr = clSetEventCallback(HostCopyEvent, CL_COMPLETE,
+                                 AllocDeleterCallback, DeleterInfo);
+
+      if (CLErr != CL_SUCCESS) {
+        // We can attempt to recover gracefully by attempting to wait for the
+        // copy to finish and deleting the info struct here.
+        clWaitForEvents(1, &HostCopyEvent);
+        delete DeleterInfo;
+        clReleaseEvent(HostCopyEvent);
+        CL_RETURN_ON_FAILURE(CLErr);
+      }
+    }
+  } else {
+    CL_RETURN_ON_FAILURE(
+        USMMemcpy(cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
+                  pSrc, size, numEventsInWaitList,
+                  cl_adapter::cast<const cl_event *>(phEventWaitList),
+                  cl_adapter::cast<cl_event *>(phEvent)));
   }

-  return RetVal;
+  return UR_RESULT_SUCCESS;
 }

 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch(