Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion unified-runtime/source/adapters/offload/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
// Subbuffers should not free their parents
if (!BufferImpl->Parent) {
// TODO: Handle registered host memory
OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr));
if (hMem->IsNativeHandleOwned) {
OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr));
}
} else {
return urMemRelease(BufferImpl->Parent);
}
Expand Down Expand Up @@ -143,6 +145,50 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
return urMemRetain(hBuffer);
}

// Liboffload has no equivalent to buffers. Buffers are implemented as USM
UR_APIEXPORT ur_result_t UR_APICALL urMemGetNativeHandle(
ur_mem_handle_t hMem, ur_device_handle_t, ur_native_handle_t *phNativeMem) {
*phNativeMem = reinterpret_cast<ur_native_handle_t>(hMem->AsBufferMem()->Ptr);
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urMemBufferCreateWithNativeHandle(
ur_native_handle_t hNativeMem, ur_context_handle_t hContext,
const ur_mem_native_properties_t *pProperties, ur_mem_handle_t *phMem) {
void *Ptr = reinterpret_cast<void *>(hNativeMem);
ol_device_handle_t Device;
OL_RETURN_ON_ERR(
olGetMemInfo(Ptr, OL_MEM_INFO_DEVICE, sizeof(Device), &Device));
void *Base;
OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_BASE, sizeof(Base), &Base));

// Check that this pointer is valid
if (Base != Ptr || Device != hContext->Device->OffloadDevice) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

size_t Size;
OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_SIZE, sizeof(Size), &Size));

ol_alloc_type_t Type;
OL_RETURN_ON_ERR(olGetMemInfo(Ptr, OL_MEM_INFO_TYPE, sizeof(Type), &Type));

*phMem = new ur_mem_handle_t_{/*Context=*/hContext,
/*Parent=*/nullptr,
/*MemFlags=*/UR_MEM_FLAG_READ_WRITE,
/*Mode=*/
(Type == OL_ALLOC_TYPE_HOST
? BufferMem::AllocMode::AllocHostPtr
: BufferMem::AllocMode::Default),
/*Ptr=*/Ptr,
/*HostPtr=*/nullptr,
/*Size=*/Size};
(*phMem)->IsNativeHandleOwned =
pProperties ? pProperties->isNativeHandleOwned : false;

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL
urMemImageCreate(ur_context_handle_t, ur_mem_flags_t, const ur_image_format_t *,
const ur_image_desc_t *, void *, ur_mem_handle_t *) {
Expand Down
4 changes: 2 additions & 2 deletions unified-runtime/source/adapters/offload/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
struct BufferMem {
enum class AllocMode {
Default,
UseHostPtr,
CopyIn,
AllocHostPtr,
};
Expand Down Expand Up @@ -93,6 +92,7 @@ struct ur_mem_handle_t_ : RefCounted {
ur_context_handle_t Context;

ur_mem_flags_t MemFlags;
bool IsNativeHandleOwned;

// For now we only support BufferMem. Eventually we'll support images, so use
// a variant to store the underlying object.
Expand All @@ -101,7 +101,7 @@ struct ur_mem_handle_t_ : RefCounted {
ur_mem_handle_t_(ur_context_handle_t Context, ur_mem_handle_t Parent,
ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode,
void *Ptr, void *HostPtr, size_t Size)
: Context{Context}, MemFlags{MemFlags},
: Context{Context}, MemFlags{MemFlags}, IsNativeHandleOwned(true),
Mem{BufferMem{Parent, Mode, Ptr, HostPtr, Size}} {
urContextRetain(Context);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,11 @@ urGetMemProcAddrTable(ur_api_version_t version, ur_mem_dditable_t *pDdiTable) {
}
pDdiTable->pfnBufferCreate = urMemBufferCreate;
pDdiTable->pfnBufferPartition = urMemBufferPartition;
pDdiTable->pfnBufferCreateWithNativeHandle = nullptr;
pDdiTable->pfnBufferCreateWithNativeHandle =
urMemBufferCreateWithNativeHandle;
pDdiTable->pfnImageCreateWithNativeHandle = urMemImageCreateWithNativeHandle;
pDdiTable->pfnGetInfo = urMemGetInfo;
pDdiTable->pfnGetNativeHandle = nullptr;
pDdiTable->pfnGetNativeHandle = urMemGetNativeHandle;
pDdiTable->pfnImageCreate = urMemImageCreate;
pDdiTable->pfnImageGetInfo = urMemImageGetInfo;
pDdiTable->pfnRelease = urMemRelease;
Expand Down
Loading