Skip to content

Commit fcebe6b

Browse files
authored
[Offload] Re-allocate overlapping memory (#159567)
If olMemAlloc happens to return memory that overlaps an existing allocation (possibly one made by another device on another platform), that allocation is now discarded and a new one is attempted. A new `AllocBases` vector is also introduced; it holds an ordered list of allocation start addresses.
1 parent 3826849 commit fcebe6b

File tree

3 files changed

+83
-10
lines changed

3 files changed

+83
-10
lines changed

offload/liboffload/API/Memory.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ def ol_alloc_type_t : Enum {
2121

2222
def olMemAlloc : Function {
2323
let desc = "Creates a memory allocation on the specified device.";
24+
let details = [
25+
"All allocations through olMemAlloc regardless of source share a single virtual address range. There is no risk of multiple devices returning equal pointers to different memory."
26+
];
2427
let params = [
2528
Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>,
2629
Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>,

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,9 @@ namespace offload {
182182
struct AllocInfo {
183183
ol_device_handle_t Device;
184184
ol_alloc_type_t Type;
185+
void *Start;
186+
// One byte past the end
187+
void *End;
185188
};
186189

187190
// Global shared state for liboffload
@@ -200,6 +203,9 @@ struct OffloadContext {
200203
bool ValidationEnabled = true;
201204
DenseMap<void *, AllocInfo> AllocInfoMap{};
202205
std::mutex AllocInfoMapMutex{};
206+
// Partitioned list of memory base addresses. Each element in this list is a
207+
// key in AllocInfoMap
208+
llvm::SmallVector<void *> AllocBases{};
203209
SmallVector<ol_platform_impl_t, 4> Platforms{};
204210
size_t RefCount;
205211

@@ -613,20 +619,61 @@ TargetAllocTy convertOlToPluginAllocTy(ol_alloc_type_t Type) {
613619
}
614620
}
615621

622+
constexpr size_t MAX_ALLOC_TRIES = 50;
616623
Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
617624
size_t Size, void **AllocationOut) {
618-
auto Alloc =
619-
Device->Device->dataAlloc(Size, nullptr, convertOlToPluginAllocTy(Type));
620-
if (!Alloc)
621-
return Alloc.takeError();
625+
SmallVector<void *> Rejects;
626+
627+
// Repeat the allocation up to a certain number of times. If it happens to
628+
// already be allocated (e.g. by a device from another vendor) throw it away
629+
// and try again.
630+
for (size_t Count = 0; Count < MAX_ALLOC_TRIES; Count++) {
631+
auto NewAlloc = Device->Device->dataAlloc(Size, nullptr,
632+
convertOlToPluginAllocTy(Type));
633+
if (!NewAlloc)
634+
return NewAlloc.takeError();
635+
636+
void *NewEnd = &static_cast<char *>(*NewAlloc)[Size];
637+
auto &AllocBases = OffloadContext::get().AllocBases;
638+
auto &AllocInfoMap = OffloadContext::get().AllocInfoMap;
639+
{
640+
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
641+
642+
// Check that this memory region doesn't overlap another one
643+
// That is, the start of this allocation needs to be after another
644+
// allocation's end point, and the end of this allocation needs to be
645+
// before the next one's start.
646+
// `Gap` is the first alloc that ends after the new alloc's start point.
647+
auto Gap =
648+
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc,
649+
[&](const void *Iter, const void *Val) {
650+
return AllocInfoMap.at(Iter).End <= Val;
651+
});
652+
if (Gap == AllocBases.end() || NewEnd <= AllocInfoMap.at(*Gap).Start) {
653+
// Success, no conflict
654+
AllocInfoMap.insert_or_assign(
655+
*NewAlloc, AllocInfo{Device, Type, *NewAlloc, NewEnd});
656+
AllocBases.insert(
657+
std::lower_bound(AllocBases.begin(), AllocBases.end(), *NewAlloc),
658+
*NewAlloc);
659+
*AllocationOut = *NewAlloc;
660+
661+
for (void *R : Rejects)
662+
if (auto Err =
663+
Device->Device->dataDelete(R, convertOlToPluginAllocTy(Type)))
664+
return Err;
665+
return Error::success();
666+
}
622667

623-
*AllocationOut = *Alloc;
624-
{
625-
std::lock_guard<std::mutex> Lock(OffloadContext::get().AllocInfoMapMutex);
626-
OffloadContext::get().AllocInfoMap.insert_or_assign(
627-
*Alloc, AllocInfo{Device, Type});
668+
// To avoid the next attempt allocating the same memory we just rejected, we
669+
// hold onto it until we complete the allocation
670+
Rejects.push_back(*NewAlloc);
671+
}
628672
}
629-
return Error::success();
673+
674+
// We've tried multiple times, and can't allocate a non-overlapping region.
675+
return createOffloadError(ErrorCode::BACKEND_FAILURE,
676+
"failed to allocate non-overlapping memory");
630677
}
631678

632679
Error olMemFree_impl(void *Address) {
@@ -642,6 +689,9 @@ Error olMemFree_impl(void *Address) {
642689
Device = AllocInfo.Device;
643690
Type = AllocInfo.Type;
644691
OffloadContext::get().AllocInfoMap.erase(Address);
692+
693+
auto &Bases = OffloadContext::get().AllocBases;
694+
Bases.erase(std::lower_bound(Bases.begin(), Bases.end(), Address));
645695
}
646696

647697
if (auto Res =

offload/unittests/OffloadAPI/memory/olMemAlloc.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,26 @@ TEST_P(olMemAllocTest, SuccessAllocDevice) {
3434
olMemFree(Alloc);
3535
}
3636

37+
TEST_P(olMemAllocTest, SuccessAllocMany) {
38+
std::vector<void *> Allocs;
39+
Allocs.reserve(1000);
40+
41+
constexpr ol_alloc_type_t TYPES[3] = {
42+
OL_ALLOC_TYPE_DEVICE, OL_ALLOC_TYPE_MANAGED, OL_ALLOC_TYPE_HOST};
43+
44+
for (size_t I = 1; I < 1000; I++) {
45+
void *Alloc = nullptr;
46+
ASSERT_SUCCESS(olMemAlloc(Device, TYPES[I % 3], 1024 * I, &Alloc));
47+
ASSERT_NE(Alloc, nullptr);
48+
49+
Allocs.push_back(Alloc);
50+
}
51+
52+
for (auto *A : Allocs) {
53+
olMemFree(A);
54+
}
55+
}
56+
3757
TEST_P(olMemAllocTest, InvalidNullDevice) {
3858
void *Alloc = nullptr;
3959
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

0 commit comments

Comments
 (0)