diff --git a/unified-runtime/source/adapters/offload/enqueue.cpp b/unified-runtime/source/adapters/offload/enqueue.cpp index 87bd45eb3f817..bef76e396b839 100644 --- a/unified-runtime/source/adapters/offload/enqueue.cpp +++ b/unified-runtime/source/adapters/offload/enqueue.cpp @@ -192,6 +192,21 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch( return UR_RESULT_SUCCESS; } +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill( + ur_queue_handle_t hQueue, void *pMem, size_t patternSize, + const void *pPattern, size_t size, uint32_t numEventsInWaitList, + const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) { + ol_queue_handle_t Queue; + OL_RETURN_ON_ERR(hQueue->nextQueue(Queue)); + OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList)); + + OL_RETURN_ON_ERR( + olMemFill(Queue, pMem, patternSize, const_cast(pPattern), size)); + OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_USM_FILL, Queue, hQueue, phEvent)); + + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueUSMFill2D( ur_queue_handle_t, void *, size_t, size_t, const void *, size_t, size_t, uint32_t, const ur_event_handle_t *, ur_event_handle_t *) { @@ -279,6 +294,25 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy( phEventWaitList, phEvent); } +UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill( + ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, const void *pPattern, + size_t patternSize, size_t offset, size_t size, + uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList, + ur_event_handle_t *phEvent) { + ol_queue_handle_t Queue; + OL_RETURN_ON_ERR(hQueue->nextQueue(Queue)); + OL_RETURN_ON_ERR(waitOnEvents(Queue, phEventWaitList, numEventsInWaitList)); + + char *DevPtr = + reinterpret_cast(std::get(hBuffer->Mem).Ptr); + + OL_RETURN_ON_ERR(olMemFill(Queue, DevPtr + offset, patternSize, + const_cast(pPattern), size)); + OL_RETURN_ON_ERR(makeEvent(UR_COMMAND_USM_FILL, Queue, hQueue, phEvent)); + + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead( ur_queue_handle_t hQueue, ur_program_handle_t hProgram, const char *name, bool blockingRead, size_t count, size_t offset, void *pDst, diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index 0324e3cedb48f..d08875d73401c 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -175,7 +175,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( pDdiTable->pfnKernelLaunch = urEnqueueKernelLaunch; pDdiTable->pfnMemBufferCopy = urEnqueueMemBufferCopy; pDdiTable->pfnMemBufferCopyRect = nullptr; - pDdiTable->pfnMemBufferFill = nullptr; + pDdiTable->pfnMemBufferFill = urEnqueueMemBufferFill; pDdiTable->pfnMemBufferMap = urEnqueueMemBufferMap; pDdiTable->pfnMemBufferRead = urEnqueueMemBufferRead; pDdiTable->pfnMemBufferReadRect = nullptr; @@ -186,7 +186,7 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetEnqueueProcAddrTable( pDdiTable->pfnMemImageWrite = nullptr; pDdiTable->pfnMemUnmap = urEnqueueMemUnmap; pDdiTable->pfnUSMFill2D = urEnqueueUSMFill2D; - pDdiTable->pfnUSMFill = nullptr; + pDdiTable->pfnUSMFill = urEnqueueUSMFill; pDdiTable->pfnUSMAdvise = nullptr; pDdiTable->pfnUSMMemcpy2D = urEnqueueUSMMemcpy2D; pDdiTable->pfnUSMMemcpy = urEnqueueUSMMemcpy;