diff --git a/unified-runtime/source/adapters/offload/memory.cpp b/unified-runtime/source/adapters/offload/memory.cpp index 1a5b0df330c23..569e244037cbc 100644 --- a/unified-runtime/source/adapters/offload/memory.cpp +++ b/unified-runtime/source/adapters/offload/memory.cpp @@ -80,7 +80,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) { // Subbuffers should not free their parents if (!BufferImpl->Parent) { // TODO: Handle registered host memory - OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr)); + if (hMem->IsNativeHandleOwned) { + OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr)); + } } else { return urMemRelease(BufferImpl->Parent); } @@ -143,6 +145,50 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition( return urMemRetain(hBuffer); } +// Liboffload has no equivalent to buffers. Buffers are implemented as USM +UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle( + ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) { + *phNativeMem = reinterpret_cast(hMem->AsBufferMem()->Ptr); + return UR_RESULT_SUCCESS; +} + +UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle( + ur_native_handle_t hNativeMem, ur_context_handle_t hContext, + const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) { + void *Ptr = reinterpret_cast(hNativeMem); + ol_device_handle_t Device; + OL_RETURN_ON_ERR( + olGetMemInfo(Ptr, OL_MEM_INFO_DEVICE, sizeof(Device), &Device)); + void *Base; + OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_BASE, sizeof(Base), &Base)); + + // Check that this pointer is valid + if (Base != Ptr || Device != hContext->Device->OffloadDevice) { + return UR_RESULT_ERROR_UNSUPPORTED_FEATURE; + } + + size_t Size; + OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_SIZE, sizeof(Size), &Size)); + + ol_alloc_type_t Type; + OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_TYPE, sizeof(Type), &Type)); + + *phMem = new ur_mem_handle_t_{/*Context=*/hContext, + /*Parent=*/nullptr, + /*MemFlags=*/UR_MEM_FLAG_READ_WRITE, + /*Mode=*/ + (Type == OL_ALLOC_TYPE_HOST + ? BufferMem::AllocMode::AllocHostPtr + : BufferMem::AllocMode::Default), + /*Ptr=*/Ptr, + /*HostPtr=*/nullptr, + /*Size=*/Size}; + (*phMem)->IsNativeHandleOwned = + pProperties ? pProperties->isNativeHandleOwned : false; + + return UR_RESULT_SUCCESS; +} + UR_APIEXPORT ur_result_t UR_APICALL urMemImageCreate(ur_context_handle_t, ur_mem_flags_t, const ur_image_format_t *, const ur_image_desc_t *, void *, ur_mem_handle_t *) { diff --git a/unified-runtime/source/adapters/offload/memory.hpp b/unified-runtime/source/adapters/offload/memory.hpp index 59b62ea12961a..c2850355c7e72 100644 --- a/unified-runtime/source/adapters/offload/memory.hpp +++ b/unified-runtime/source/adapters/offload/memory.hpp @@ -17,7 +17,6 @@ struct BufferMem { enum class AllocMode { Default, - UseHostPtr, CopyIn, AllocHostPtr, }; @@ -93,6 +92,7 @@ struct ur_mem_handle_t_ : RefCounted { ur_context_handle_t Context; ur_mem_flags_t MemFlags; + bool IsNativeHandleOwned; // For now we only support BufferMem. Eventually we'll support images, so use // a variant to store the underlying object. @@ -101,7 +101,7 @@ struct ur_mem_handle_t_ : RefCounted { ur_mem_handle_t_(ur_context_handle_t Context, ur_mem_handle_t Parent, ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode, void *Ptr, void *HostPtr, size_t Size) - : Context{Context}, MemFlags{MemFlags}, + : Context{Context}, MemFlags{MemFlags}, IsNativeHandleOwned(true), Mem{BufferMem{Parent, Mode, Ptr, HostPtr, Size}} { urContextRetain(Context); }; diff --git a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp index fa3adfdaf92bd..04259b22661a8 100644 --- a/unified-runtime/source/adapters/offload/ur_interface_loader.cpp +++ b/unified-runtime/source/adapters/offload/ur_interface_loader.cpp @@ -151,10 +151,11 @@ urGetMemProcAddrTable(ur_api_version_t version, ur_mem_dditable_t *pDdiTable) { } pDdiTable->pfnBufferCreate = urMemBufferCreate; pDdiTable->pfnBufferPartition = urMemBufferPartition; - pDdiTable->pfnBufferCreateWithNativeHandle = nullptr; + pDdiTable->pfnBufferCreateWithNativeHandle = + urMemBufferCreateWithNativeHandle; pDdiTable->pfnImageCreateWithNativeHandle = urMemImageCreateWithNativeHandle; pDdiTable->pfnGetInfo = urMemGetInfo; - pDdiTable->pfnGetNativeHandle = nullptr; + pDdiTable->pfnGetNativeHandle = urMemGetNativeHandle; pDdiTable->pfnImageCreate = urMemImageCreate; pDdiTable->pfnImageGetInfo = urMemImageGetInfo; pDdiTable->pfnRelease = urMemRelease;