Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion offload/liboffload/API/Memory.td
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,12 @@ def olMemAlloc : Function {
def olMemFree : Function {
let desc = "Frees a memory allocation previously made by olMemAlloc.";
let params = [
Param<"ol_platform_handle_t", "Platform", "handle of the platform that allocated this memory", PARAM_IN>,
Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
];
let returns = [];
let returns = [
Return<"OL_ERRC_NOT_FOUND", ["memory was not allocated by this platform"]>
];
}

def olMemcpy : Function {
Expand Down
3 changes: 2 additions & 1 deletion offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
return Error::success();
}

Error olMemFree_impl(void *Address) {
Error olMemFree_impl(ol_platform_handle_t Platform, void *Address) {
ol_device_handle_t Device;
ol_alloc_type_t Type;
{
Expand All @@ -646,6 +646,7 @@ Error olMemFree_impl(void *Address) {
Type = AllocInfo.Type;
OffloadContext::get().AllocInfoMap.erase(Address);
}
assert(Platform == Device->Platform);

if (auto Res =
Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class DeviceContext {
detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress);
T *TypedAddress = static_cast<T *>(UntypedAddress);

return ManagedBuffer<T>(TypedAddress, Size);
return ManagedBuffer<T>(PlatformHandle, TypedAddress, Size);
}

[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
Expand Down Expand Up @@ -131,6 +131,7 @@ class DeviceContext {

std::size_t GlobalDeviceId;
ol_device_handle_t DeviceHandle;
ol_platform_handle_t PlatformHandle;
};
} // namespace mathtest

Expand Down
17 changes: 11 additions & 6 deletions offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DeviceContext;

namespace detail {

void freeDeviceMemory(void *Address) noexcept;
void freeDeviceMemory(ol_platform_handle_t Platform, void *Address) noexcept;
} // namespace detail

//===----------------------------------------------------------------------===//
Expand All @@ -40,14 +40,15 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
public:
~ManagedBuffer() noexcept {
if (Address)
detail::freeDeviceMemory(Address);
detail::freeDeviceMemory(Platform, Address);
}

ManagedBuffer(const ManagedBuffer &) = delete;
ManagedBuffer &operator=(const ManagedBuffer &) = delete;

ManagedBuffer(ManagedBuffer &&Other) noexcept
: Address(Other.Address), Size(Other.Size) {
: Platform(Other.Platform), Address(Other.Address), Size(Other.Size) {
Other.Platform = nullptr;
Other.Address = nullptr;
Other.Size = 0;
}
Expand All @@ -57,11 +58,13 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
return *this;

if (Address)
detail::freeDeviceMemory(Address);
detail::freeDeviceMemory(Platform, Address);

Platform = Other.Platform;
Address = Other.Address;
Size = Other.Size;

Other.Platform = nullptr;
Other.Address = nullptr;
Other.Size = 0;

Expand All @@ -85,9 +88,11 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
private:
friend class DeviceContext;

explicit ManagedBuffer(T *Address, std::size_t Size) noexcept
: Address(Address), Size(Size) {}
explicit ManagedBuffer(ol_platform_handle_t Platform, T *Address,
std::size_t Size) noexcept
: Platform(Platform), Address(Address), Size(Size) {}

ol_platform_handle_t Platform = nullptr;
T *Address = nullptr;
std::size_t Size = 0;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ typedef struct ol_program_impl_t *ol_program_handle_t;
struct ol_symbol_impl_t;
typedef struct ol_symbol_impl_t *ol_symbol_handle_t;

struct ol_platform_impl_t;
typedef struct ol_platform_impl_t *ol_platform_handle_t;

#ifdef __cplusplus
}
#endif // __cplusplus
Expand Down
5 changes: 4 additions & 1 deletion offload/unittests/Conformance/lib/DeviceContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ getPlatformBackend(ol_platform_handle_t PlatformHandle) noexcept {

struct Device {
ol_device_handle_t Handle;
ol_platform_handle_t PlatformHandle;
std::string Name;
std::string Platform;
ol_platform_backend_t Backend;
Expand All @@ -124,7 +125,7 @@ const std::vector<Device> &getDevices() {
auto Platform = getPlatformName(PlatformHandle);

static_cast<std::vector<Device> *>(Data)->push_back(
{DeviceHandle, Name, Platform, Backend});
{DeviceHandle, PlatformHandle, Name, Platform, Backend});
}

return true;
Expand Down Expand Up @@ -175,6 +176,7 @@ DeviceContext::DeviceContext(std::size_t GlobalDeviceId)
llvm::Twine(Devices.size()));

DeviceHandle = Devices[GlobalDeviceId].Handle;
PlatformHandle = Devices[GlobalDeviceId].PlatformHandle;
}

DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)
Expand Down Expand Up @@ -210,6 +212,7 @@ DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId)

GlobalDeviceId = *FoundGlobalDeviceId;
DeviceHandle = Devices[GlobalDeviceId].Handle;
PlatformHandle = Devices[GlobalDeviceId].PlatformHandle;
}

[[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
Expand Down
5 changes: 3 additions & 2 deletions offload/unittests/Conformance/lib/DeviceResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@ using namespace mathtest;
// Helpers
//===----------------------------------------------------------------------===//

void detail::freeDeviceMemory(void *Address) noexcept {
void detail::freeDeviceMemory(ol_platform_handle_t Platform,
void *Address) noexcept {
if (Address)
OL_CHECK(olMemFree(Address));
OL_CHECK(olMemFree(Platform, Address));
}

//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions offload/unittests/OffloadAPI/common/Fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,20 +137,20 @@ struct OffloadDeviceTest
Device = DeviceParam.Handle;
if (Device == nullptr)
GTEST_SKIP() << "No available devices.";

ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
sizeof(ol_platform_handle_t), &Platform));
}

ol_platform_backend_t getPlatformBackend() const {
ol_platform_handle_t Platform = nullptr;
if (olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
sizeof(ol_platform_handle_t), &Platform))
return OL_PLATFORM_BACKEND_UNKNOWN;
ol_platform_backend_t Backend;
if (olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND,
sizeof(ol_platform_backend_t), &Backend))
return OL_PLATFORM_BACKEND_UNKNOWN;
return Backend;
}

ol_platform_handle_t Platform = nullptr;
ol_device_handle_t Device = nullptr;
};

Expand Down
16 changes: 8 additions & 8 deletions offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
ASSERT_EQ(Data[i], i);
}

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
Expand All @@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
ASSERT_EQ(Data[i], i);
}

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
});
}

Expand Down Expand Up @@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
ASSERT_EQ(Data[i], i);
}

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelLocalMemTest, Success) {
Expand All @@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], (i % 64) * 2);

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
Expand All @@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
Expand All @@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelGlobalTest, Success) {
Expand All @@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
ASSERT_EQ(Data[i], i * 2);
}

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
Expand Down Expand Up @@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
ASSERT_EQ(Data[i], i + 100);
}

ASSERT_SUCCESS(olMemFree(Mem));
ASSERT_SUCCESS(olMemFree(Platform, Mem));
}

TEST_P(olLaunchKernelGlobalDtorTest, Success) {
Expand Down
6 changes: 3 additions & 3 deletions offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemAllocTest, SuccessAllocHost) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemAllocTest, SuccessAllocDevice) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
ASSERT_NE(Alloc, nullptr);
olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemAllocTest, InvalidNullDevice) {
Expand Down
12 changes: 6 additions & 6 deletions offload/unittests/OffloadAPI/memory/olMemFill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct olMemFillTest : OffloadQueueTest {
ASSERT_EQ(AllocPtr[i], Pattern);
}

olMemFree(Alloc);
olMemFree(Platform, Alloc);
}
};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest);
Expand Down Expand Up @@ -92,7 +92,7 @@ TEST_P(olMemFillTest, SuccessLarge) {
ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
}

olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemFillTest, SuccessLargeEnqueue) {
Expand Down Expand Up @@ -120,7 +120,7 @@ TEST_P(olMemFillTest, SuccessLargeEnqueue) {
ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
}

olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemFillTest, SuccessLargeByteAligned) {
Expand All @@ -146,7 +146,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) {
ASSERT_EQ(AllocPtr[i].C, 255);
}

olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
Expand Down Expand Up @@ -176,7 +176,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
ASSERT_EQ(AllocPtr[i].C, 255);
}

olMemFree(Alloc);
olMemFree(Platform, Alloc);
}

TEST_P(olMemFillTest, InvalidPatternSize) {
Expand All @@ -189,5 +189,5 @@ TEST_P(olMemFillTest, InvalidPatternSize) {
olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));

olSyncQueue(Queue);
olMemFree(Alloc);
olMemFree(Platform, Alloc);
}
17 changes: 12 additions & 5 deletions offload/unittests/OffloadAPI/memory/olMemFree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,31 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
TEST_P(olMemFreeTest, SuccessFreeManaged) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
ASSERT_SUCCESS(olMemFree(Alloc));
ASSERT_SUCCESS(olMemFree(Platform, Alloc));
}

TEST_P(olMemFreeTest, SuccessFreeHost) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
ASSERT_SUCCESS(olMemFree(Alloc));
ASSERT_SUCCESS(olMemFree(Platform, Alloc));
}

TEST_P(olMemFreeTest, SuccessFreeDevice) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
ASSERT_SUCCESS(olMemFree(Alloc));
ASSERT_SUCCESS(olMemFree(Platform, Alloc));
}

TEST_P(olMemFreeTest, InvalidNullPtr) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
ASSERT_SUCCESS(olMemFree(Alloc));
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Platform, nullptr));
ASSERT_SUCCESS(olMemFree(Platform, Alloc));
}

TEST_P(olMemFreeTest, InvalidPlatformPtr) {
void *Alloc = nullptr;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olMemFree(nullptr, Alloc));
ASSERT_SUCCESS(olMemFree(Platform, Alloc));
}
Loading
Loading