Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions offload/liboffload/API/Memory.td
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def ol_alloc_type_t : Enum {

def olMemAlloc : Function {
let desc = "Creates a memory allocation on the specified device.";
let details = [
"All allocations through olMemAlloc regardless of source share a single virtual address range. There is no risk of multiple devices returning equal pointers to different memory."
];
let params = [
Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,
Expand Down
69 changes: 59 additions & 10 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,9 @@ namespace offload {
// Bookkeeping record for a single allocation made through olMemAlloc. Records
// are stored in OffloadContext::AllocInfoMap, keyed by the allocation's base
// address.
struct AllocInfo {
// Device handle the allocation was requested on.
ol_device_handle_t Device;
// Allocation type that was passed to olMemAlloc.
ol_alloc_type_t Type;
// Base address of the allocation (the pointer handed back to the caller).
void *Start;
// One byte past the end, i.e. Start advanced by the allocation size, so the
// allocation covers the half-open range [Start, End).
void *End;
};

// Global shared state for liboffload
Expand All @@ -200,6 +203,9 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
std::mutex AllocInfoMapMutex{};
// Partitioned list of memory base addresses. Each element in this list is a
// key in AllocInfoMap
llvm::SmallVector<void *> AllocBases{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;

Expand Down Expand Up @@ -615,18 +621,58 @@ TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {

Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
size_t Size, void **AllocationOut) {
auto Alloc =
Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
if (!Alloc)
return Alloc.takeError();
void *OldAlloc = nullptr;

// Repeat the allocation up to a certain amount of times. If it happens to
// already be allocated (e.g. by a device from another vendor) throw it away
// and try again.
for (size_t Count = 0; Count < 10; Count++) {
auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
convertOlToPluginAllocTy(Type));
if (!NewAlloc)
return NewAlloc.takeError();

if (OldAlloc)
if (auto Err = Device->Device->dataDelete(OldAlloc,
convertOlToPluginAllocTy(Type)))
return Err;

*AllocationOut = *Alloc;
{
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
OffloadContext::get().AllocInfoMap.insert_or_assign(
*Alloc, AllocInfo{Device, Type});
void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];
auto &AllocBases = OffloadContext::get().AllocBases;
auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
{
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);

// Check that this memory region doesn't overlap another one
// That is, the start of this allocation needs to be after another
// allocation's end point, and the end of this allocation needs to be
// before the next one's start.
// `Gap` is the first alloc who ends after the new alloc's start point.
auto Gap =
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need to check bounds again? I thought that for the purposes of olMemFree it only mattered that we had something like this, where all that matters is they're unique.

Map<void *, Platform> Map;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unless I'm misremembering plans, in the future olGetMemInfo will accept a pointer anywhere into any allocation allocated by liboffload. This means that we need to ensure that no part of the buffers overlap, rather than just the start.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, good point. It definitely makes this a bit more restrictive and expensive if virtual addresses cannot overlap at all, versus just a single pointer passed to free. This at least means that we'll need this range-based table anyway, so I suppose it's fine.

[&](const void *Iter, const void *Val) {
return AllocInfoMap.at(Iter).End <= Val;
});
if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
// Success, no conflict
AllocInfoMap.insert_or_assign(
*NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
AllocBases.insert(
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc),
*NewAlloc);
*AllocationOut = *NewAlloc;
return Error::success();
}

// To avoid the next attempt allocating the same memory we just freed, we
// hold onto it until we complete the next allocation
OldAlloc = *NewAlloc;
}
}
return Error::success();

// We've tried multiple times, and can't allocate a non-overlapping region.
return createOffloadError(ErrorCode::BACKEND_FAILURE,
"failed to allocate non-overlapping memory");
}

Error olMemFree_impl(void *Address) {
Expand All @@ -642,6 +688,9 @@ Error olMemFree_impl(void *Address) {
Device = AllocInfo.Device;
Type = AllocInfo.Type;
OffloadContext::get().AllocInfoMap.erase(Address);

auto &Bases = OffloadContext::get().AllocBases;
Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
}

if (auto Res =
Expand Down
20 changes: 20 additions & 0 deletions offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,26 @@ TEST_P(olMemAllocTest, SuccessAllocDevice) {
olMemFree(Alloc);
}

// Stress test: performs 999 allocations of increasing size (1 KiB * I),
// cycling through the device, managed and host allocation types, and checks
// that every allocation succeeds and yields a non-null pointer. All
// allocations are kept alive until the end so they coexist in the runtime's
// allocation tracking, then each one is released and the release is verified.
TEST_P(olMemAllocTest, SuccessAllocMany) {
  std::vector<void *> Allocs;
  Allocs.reserve(1000);

  constexpr ol_alloc_type_t TYPES[3] = {
      OL_ALLOC_TYPE_DEVICE, OL_ALLOC_TYPE_MANAGED, OL_ALLOC_TYPE_HOST};

  // Start at 1 so that no zero-sized allocation is requested.
  for (size_t I = 1; I < 1000; I++) {
    void *Alloc = nullptr;
    ASSERT_SUCCESS(olMemAlloc(Device, TYPES[I % 3], 1024 * I, &Alloc));
    ASSERT_NE(Alloc, nullptr);

    Allocs.push_back(Alloc);
  }

  // Check the result of every free: a failure here means the runtime lost
  // track of an allocation it previously handed out, which must not pass
  // silently.
  for (auto *A : Allocs) {
    ASSERT_SUCCESS(olMemFree(A));
  }
}

TEST_P(olMemAllocTest, InvalidNullDevice) {
void *Alloc = nullptr;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
Expand Down
Loading