1111#include < ur/ur.hpp>
1212
1313#include " common.hpp"
14+ #include " usm.hpp"
15+
16+ template <class T >
17+ void AllocDeleterCallback (cl_event event, cl_int, void *pUserData) {
18+ clReleaseEvent (event);
19+ auto Info = static_cast <T *>(pUserData);
20+ delete Info;
21+ }
1422
1523namespace umf {
1624ur_result_t getProviderNativeError (const char *, int32_t ) {
@@ -312,32 +320,19 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
312320 numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
313321 &CopyEvent));
314322
315- struct DeleteCallbackInfo {
316- DeleteCallbackInfo (clMemBlockingFreeINTEL_fn USMFree, cl_context CLContext,
317- void *HostBuffer)
318- : USMFree(USMFree), CLContext(CLContext), HostBuffer(HostBuffer) {
319- clRetainContext (CLContext);
320- }
321- ~DeleteCallbackInfo () {
322- USMFree (CLContext, HostBuffer);
323- clReleaseContext (CLContext);
324- }
325- DeleteCallbackInfo (const DeleteCallbackInfo &) = delete ;
326- DeleteCallbackInfo &operator =(const DeleteCallbackInfo &) = delete ;
327-
328- clMemBlockingFreeINTEL_fn USMFree;
329- cl_context CLContext;
330- void *HostBuffer;
331- };
332-
333- auto Info = new DeleteCallbackInfo (USMFree, CLContext, HostBuffer);
323+ if (phEvent) {
324+ // Since we're releasing this in the callback above we need to retain it
325+ // here to keep the user copy alive.
326+ CL_RETURN_ON_FAILURE (clRetainEvent (CopyEvent));
327+ *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
328+ }
334329
335- auto DeleteCallback = [](cl_event, cl_int, void *pUserData) {
336- auto Info = static_cast <DeleteCallbackInfo *>(pUserData);
337- delete Info;
338- };
330+ // This self destructs taking the event and allocation with it.
331+ auto Info = new AllocDeleterCallbackInfo (USMFree, CLContext, HostBuffer);
339332
340- ClErr = clSetEventCallback (CopyEvent, CL_COMPLETE, DeleteCallback, Info);
333+ ClErr =
334+ clSetEventCallback (CopyEvent, CL_COMPLETE,
335+ AllocDeleterCallback<AllocDeleterCallbackInfo>, Info);
341336 if (ClErr != CL_SUCCESS) {
342337 // We can attempt to recover gracefully by attempting to wait for the copy
343338 // to finish and deleting the info struct here.
@@ -346,11 +341,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill(
346341 clReleaseEvent (CopyEvent);
347342 CL_RETURN_ON_FAILURE (ClErr);
348343 }
349- if (phEvent) {
350- *phEvent = cl_adapter::cast<ur_event_handle_t >(CopyEvent);
351- } else {
352- CL_RETURN_ON_FAILURE (clReleaseEvent (CopyEvent));
353- }
354344
355345 return UR_RESULT_SUCCESS;
356346}
@@ -369,20 +359,131 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMMemcpy(
369359 return mapCLErrorToUR (CLErr);
370360 }
371361
372- clEnqueueMemcpyINTEL_fn FuncPtr = nullptr ;
373- ur_result_t RetVal = cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
362+ clGetMemAllocInfoINTEL_fn GetMemAllocInfo = nullptr ;
363+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clGetMemAllocInfoINTEL_fn>(
364+ CLContext, cl_ext::ExtFuncPtrCache->clGetMemAllocInfoINTELCache ,
365+ cl_ext::GetMemAllocInfoName, &GetMemAllocInfo));
366+
367+ clEnqueueMemcpyINTEL_fn USMMemcpy = nullptr ;
368+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clEnqueueMemcpyINTEL_fn>(
374369 CLContext, cl_ext::ExtFuncPtrCache->clEnqueueMemcpyINTELCache ,
375- cl_ext::EnqueueMemcpyName, &FuncPtr );
370+ cl_ext::EnqueueMemcpyName, &USMMemcpy) );
376371
377- if (FuncPtr) {
378- RetVal = mapCLErrorToUR (
379- FuncPtr (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
380- pSrc, size, numEventsInWaitList,
381- cl_adapter::cast<const cl_event *>(phEventWaitList),
382- cl_adapter::cast<cl_event *>(phEvent)));
372+ clMemBlockingFreeINTEL_fn USMFree = nullptr ;
373+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clMemBlockingFreeINTEL_fn>(
374+ CLContext, cl_ext::ExtFuncPtrCache->clMemBlockingFreeINTELCache ,
375+ cl_ext::MemBlockingFreeName, &USMFree));
376+
377+ // Check if the two allocations are DEVICE allocations from different
378+ // devices, if they are we need to do the copy indirectly via a host
379+ // allocation.
380+ cl_device_id SrcDevice = 0 , DstDevice = 0 ;
381+ CL_RETURN_ON_FAILURE (
382+ GetMemAllocInfo (CLContext, pSrc, CL_MEM_ALLOC_DEVICE_INTEL,
383+ sizeof (cl_device_id), &SrcDevice, nullptr ));
384+ CL_RETURN_ON_FAILURE (
385+ GetMemAllocInfo (CLContext, pDst, CL_MEM_ALLOC_DEVICE_INTEL,
386+ sizeof (cl_device_id), &DstDevice, nullptr ));
387+
388+ if ((SrcDevice && DstDevice) && SrcDevice != DstDevice) {
389+ // We need a queue associated with each device, so first figure out which
390+ // one we weren't given.
391+ cl_device_id QueueDevice = nullptr ;
392+ CL_RETURN_ON_FAILURE (clGetCommandQueueInfo (
393+ cl_adapter::cast<cl_command_queue>(hQueue), CL_QUEUE_DEVICE,
394+ sizeof (QueueDevice), &QueueDevice, nullptr ));
395+
396+ cl_command_queue MissingQueue = nullptr , SrcQueue = nullptr ,
397+ DstQueue = nullptr ;
398+ if (QueueDevice == SrcDevice) {
399+ MissingQueue = clCreateCommandQueue (CLContext, DstDevice, 0 , &CLErr);
400+ SrcQueue = cl_adapter::cast<cl_command_queue>(hQueue);
401+ DstQueue = MissingQueue;
402+ } else {
403+ MissingQueue = clCreateCommandQueue (CLContext, SrcDevice, 0 , &CLErr);
404+ DstQueue = cl_adapter::cast<cl_command_queue>(hQueue);
405+ SrcQueue = MissingQueue;
406+ }
407+ CL_RETURN_ON_FAILURE (CLErr);
408+
409+ cl_event HostCopyEvent = nullptr , FinalCopyEvent = nullptr ;
410+ clHostMemAllocINTEL_fn HostMemAlloc = nullptr ;
411+ UR_RETURN_ON_FAILURE (cl_ext::getExtFuncFromContext<clHostMemAllocINTEL_fn>(
412+ CLContext, cl_ext::ExtFuncPtrCache->clHostMemAllocINTELCache ,
413+ cl_ext::HostMemAllocName, &HostMemAlloc));
414+
415+ auto HostAlloc = HostMemAlloc (CLContext, nullptr , size, 0 , &CLErr);
416+ CL_RETURN_ON_FAILURE (CLErr);
417+
418+ // Now that we've successfully allocated we should try to clean it up if we
419+ // hit an error somewhere.
420+ auto checkCLErr = [&](cl_int CLErr) -> ur_result_t {
421+ if (CLErr != CL_SUCCESS) {
422+ if (HostCopyEvent) {
423+ clReleaseEvent (HostCopyEvent);
424+ }
425+ if (FinalCopyEvent) {
426+ clReleaseEvent (FinalCopyEvent);
427+ }
428+ USMFree (CLContext, HostAlloc);
429+ CL_RETURN_ON_FAILURE (CLErr);
430+ }
431+ return UR_RESULT_SUCCESS;
432+ };
433+
434+ UR_RETURN_ON_FAILURE (checkCLErr (USMMemcpy (
435+ SrcQueue, blocking, HostAlloc, pSrc, size, numEventsInWaitList,
436+ cl_adapter::cast<const cl_event *>(phEventWaitList), &HostCopyEvent)));
437+
438+ UR_RETURN_ON_FAILURE (
439+ checkCLErr (USMMemcpy (DstQueue, blocking, pDst, HostAlloc, size, 1 ,
440+ &HostCopyEvent, &FinalCopyEvent)));
441+
442+ // If this is a blocking operation we can do our cleanup immediately,
443+ // otherwise we need to defer it to an event callback.
444+ if (blocking) {
445+ CL_RETURN_ON_FAILURE (USMFree (CLContext, HostAlloc));
446+ CL_RETURN_ON_FAILURE (clReleaseEvent (HostCopyEvent));
447+ CL_RETURN_ON_FAILURE (clReleaseCommandQueue (MissingQueue));
448+ if (phEvent) {
449+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
450+ } else {
451+ CL_RETURN_ON_FAILURE (clReleaseEvent (FinalCopyEvent));
452+ }
453+ } else {
454+ if (phEvent) {
455+ *phEvent = cl_adapter::cast<ur_event_handle_t >(FinalCopyEvent);
456+ // We are going to release this event in our callback so we need to
457+ // retain if the user wants a copy.
458+ CL_RETURN_ON_FAILURE (clRetainEvent (FinalCopyEvent));
459+ }
460+
461+ // This self destructs taking the event and allocation with it.
462+ auto DeleterInfo = new AllocDeleterCallbackInfoWithQueue (
463+ USMFree, CLContext, HostAlloc, MissingQueue);
464+
465+ CLErr = clSetEventCallback (
466+ HostCopyEvent, CL_COMPLETE,
467+ AllocDeleterCallback<AllocDeleterCallbackInfoWithQueue>, DeleterInfo);
468+
469+ if (CLErr != CL_SUCCESS) {
470+ // We can attempt to recover gracefully by attempting to wait for the
471+ // copy to finish and deleting the info struct here.
472+ clWaitForEvents (1 , &HostCopyEvent);
473+ delete DeleterInfo;
474+ clReleaseEvent (HostCopyEvent);
475+ CL_RETURN_ON_FAILURE (CLErr);
476+ }
477+ }
478+ } else {
479+ CL_RETURN_ON_FAILURE (
480+ USMMemcpy (cl_adapter::cast<cl_command_queue>(hQueue), blocking, pDst,
481+ pSrc, size, numEventsInWaitList,
482+ cl_adapter::cast<const cl_event *>(phEventWaitList),
483+ cl_adapter::cast<cl_event *>(phEvent)));
383484 }
384485
385- return RetVal ;
486+ return UR_RESULT_SUCCESS ;
386487}
387488
388489UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMPrefetch (
0 commit comments