Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 1 addition & 11 deletions unified-runtime/source/adapters/cuda/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,18 +348,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
flags = UR_MEM_FLAG_READ_WRITE;
}

UR_ASSERT(!(flags &
(UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER |
UR_MEM_FLAG_ALLOC_HOST_POINTER | UR_MEM_FLAG_USE_HOST_POINTER)),
UR_ASSERT(subBufferFlagsAreLegal(hBuffer->MemFlags, flags),
UR_RESULT_ERROR_INVALID_VALUE);
if (hBuffer->MemFlags & UR_MEM_FLAG_WRITE_ONLY) {
UR_ASSERT(!(flags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_READ_ONLY)),
UR_RESULT_ERROR_INVALID_VALUE);
}
if (hBuffer->MemFlags & UR_MEM_FLAG_READ_ONLY) {
UR_ASSERT(!(flags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)),
UR_RESULT_ERROR_INVALID_VALUE);
}

auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem);
UR_ASSERT(((pRegion->origin + pRegion->size) <= BufferImpl.getSize()),
Expand Down
12 changes: 1 addition & 11 deletions unified-runtime/source/adapters/hip/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
flags = UR_MEM_FLAG_READ_WRITE;
}

UR_ASSERT(!(flags &
(UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER |
UR_MEM_FLAG_ALLOC_HOST_POINTER | UR_MEM_FLAG_USE_HOST_POINTER)),
UR_ASSERT(subBufferFlagsAreLegal(hBuffer->MemFlags, flags),
UR_RESULT_ERROR_INVALID_VALUE);
if (hBuffer->MemFlags & UR_MEM_FLAG_WRITE_ONLY) {
UR_ASSERT(!(flags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_READ_ONLY)),
UR_RESULT_ERROR_INVALID_VALUE);
}
if (hBuffer->MemFlags & UR_MEM_FLAG_READ_ONLY) {
UR_ASSERT(!(flags & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)),
UR_RESULT_ERROR_INVALID_VALUE);
}

auto &BufferImpl = std::get<BufferMem>(hBuffer->Mem);
UR_ASSERT(((pRegion->origin + pRegion->size) <= BufferImpl.getSize()),
Expand Down
43 changes: 39 additions & 4 deletions unified-runtime/source/adapters/offload/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemRelease(ur_mem_handle_t hMem) {
}

std::unique_ptr<ur_mem_handle_t_> MemObjPtr(hMem);
if (hMem->MemType == ur_mem_handle_t_::Type::Buffer) {
// TODO: Handle registered host memory
auto &BufferImpl = std::get<BufferMem>(MemObjPtr->Mem);
OL_RETURN_ON_ERR(olMemFree(BufferImpl.Ptr));
if (auto *BufferImpl = MemObjPtr->AsBufferMem()) {
// Subbuffers should not free their parents
if (!BufferImpl->Parent) {
// TODO: Handle registered host memory
OL_RETURN_ON_ERR(olMemFree(BufferImpl->Ptr));
} else {
return urMemRelease(BufferImpl->Parent);
}
}

return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -107,3 +111,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemGetInfo(ur_mem_handle_t hMemory,
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}

UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
ur_mem_handle_t hBuffer, ur_mem_flags_t flags,
ur_buffer_create_type_t /*BufferCreateType*/,
const ur_buffer_region_t *pRegion, ur_mem_handle_t *phMem) {
auto *SrcBuffer = hBuffer->AsBufferMem();
if (!SrcBuffer || SrcBuffer->Parent) {
return UR_RESULT_ERROR_INVALID_VALUE;
}

// Default value for flags means UR_MEM_FLAG_READ_WRITE.
if (flags == 0) {
flags = UR_MEM_FLAG_READ_WRITE;
}
UR_ASSERT(subBufferFlagsAreLegal(hBuffer->MemFlags, flags),
UR_RESULT_ERROR_INVALID_VALUE);

UR_ASSERT(((pRegion->origin + pRegion->size) <= SrcBuffer->getSize()),
UR_RESULT_ERROR_INVALID_BUFFER_SIZE);

void *DeviceBase =
reinterpret_cast<uint8_t *>(SrcBuffer->Ptr) + pRegion->origin;
void *HostBase =
reinterpret_cast<uint8_t *>(SrcBuffer->HostPtr) + pRegion->origin;
auto URMemObj = std::unique_ptr<ur_mem_handle_t_>(new ur_mem_handle_t_{
hBuffer->getContext(), hBuffer, flags, SrcBuffer->MemAllocMode,
DeviceBase, HostBase, pRegion->size});
*phMem = URMemObj.release();

return urMemRetain(hBuffer);
}
10 changes: 8 additions & 2 deletions unified-runtime/source/adapters/offload/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ struct BufferMem {
struct ur_mem_handle_t_ : RefCounted {
ur_context_handle_t Context;

enum class Type { Buffer } MemType;
ur_mem_flags_t MemFlags;

// For now we only support BufferMem. Eventually we'll support images, so use
Expand All @@ -102,12 +101,19 @@ struct ur_mem_handle_t_ : RefCounted {
ur_mem_handle_t_(ur_context_handle_t Context, ur_mem_handle_t Parent,
ur_mem_flags_t MemFlags, BufferMem::AllocMode Mode,
void *Ptr, void *HostPtr, size_t Size)
: Context{Context}, MemType{Type::Buffer}, MemFlags{MemFlags},
: Context{Context}, MemFlags{MemFlags},
Mem{BufferMem{Parent, Mode, Ptr, HostPtr, Size}} {
urContextRetain(Context);
};

~ur_mem_handle_t_() { urContextRelease(Context); }

ur_context_handle_t getContext() const noexcept { return Context; }

BufferMem *AsBufferMem() noexcept {
if (std::holds_alternative<BufferMem>(Mem)) {
return &std::get<BufferMem>(Mem);
}
return nullptr;
}
};
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ urGetMemProcAddrTable(ur_api_version_t version, ur_mem_dditable_t *pDdiTable) {
return result;
}
pDdiTable->pfnBufferCreate = urMemBufferCreate;
pDdiTable->pfnBufferPartition = nullptr;
pDdiTable->pfnBufferPartition = urMemBufferPartition;
pDdiTable->pfnBufferCreateWithNativeHandle = nullptr;
pDdiTable->pfnImageCreateWithNativeHandle = nullptr;
pDdiTable->pfnGetInfo = urMemGetInfo;
Expand Down
17 changes: 17 additions & 0 deletions unified-runtime/source/common/ur_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,4 +568,21 @@ std::array<T, N> createArrayOf(F &&ctor) {
std::make_index_sequence<N>{});
}

// Helper function for `urMemBufferPartition`
//
// Returns true if and only if `child`'s flags are compatible with `parent`'s.
inline bool subBufferFlagsAreLegal(ur_mem_flags_t parent,
ur_mem_flags_t child) {
if (child & (UR_MEM_FLAG_ALLOC_COPY_HOST_POINTER |
UR_MEM_FLAG_ALLOC_HOST_POINTER | UR_MEM_FLAG_USE_HOST_POINTER))
return false;
if (parent & UR_MEM_FLAG_WRITE_ONLY &&
(child & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_READ_ONLY)))
return false;
if (parent & UR_MEM_FLAG_READ_ONLY &&
(child & (UR_MEM_FLAG_READ_WRITE | UR_MEM_FLAG_WRITE_ONLY)))
return false;
return true;
}

#endif /* UR_UTIL_H */