diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td index cc98b672a26a9..a24f05e72f5be 100644 --- a/offload/liboffload/API/Memory.td +++ b/offload/liboffload/API/Memory.td @@ -37,9 +37,12 @@ def olMemAlloc : Function { def olMemFree : Function { let desc = "Frees a memory allocation previously made by olMemAlloc."; let params = [ + Param<"ol_platform_handle_t", "Platform", "handle of the platform that allocated this memory", PARAM_IN>, Param<"void*", "Address", "address of the allocation to free", PARAM_IN>, ]; - let returns = []; + let returns = [ + Return<"OL_ERRC_NOT_FOUND", ["memory was not allocated by this platform"]> + ]; } def olMemcpy : Function { diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 7e8e297831f45..fef3a5669e0d5 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -632,7 +632,7 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, return Error::success(); } -Error olMemFree_impl(void *Address) { +Error olMemFree_impl(ol_platform_handle_t Platform, void *Address) { ol_device_handle_t Device; ol_alloc_type_t Type; { @@ -646,6 +646,7 @@ Error olMemFree_impl(void *Address) { Type = AllocInfo.Type; OffloadContext::get().AllocInfoMap.erase(Address); } + assert(Platform == Device->Platform); if (auto Res = Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type))) diff --git a/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp b/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp index 5c31fc3da53cd..95e90139593f9 100644 --- a/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp +++ b/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp @@ -63,7 +63,7 @@ class DeviceContext { detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress); T *TypedAddress = static_cast(UntypedAddress); - return ManagedBuffer(TypedAddress, Size); + return ManagedBuffer(PlatformHandle, TypedAddress, Size); } [[nodiscard]] llvm::Expected> @@ -131,6 +131,7 @@ class DeviceContext { std::size_t GlobalDeviceId; ol_device_handle_t DeviceHandle; + ol_platform_handle_t PlatformHandle; }; } // namespace mathtest diff --git a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp index 860448afa3a01..d6d9be6525f5e 100644 --- a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp +++ b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp @@ -29,7 +29,7 @@ class DeviceContext; namespace detail { -void freeDeviceMemory(void *Address) noexcept; +void freeDeviceMemory(ol_platform_handle_t Platform, void *Address) noexcept; } // namespace detail //===----------------------------------------------------------------------===// @@ -40,14 +40,15 @@ template class [[nodiscard]] ManagedBuffer { public: ~ManagedBuffer() noexcept { if (Address) - detail::freeDeviceMemory(Address); + detail::freeDeviceMemory(Platform, Address); } ManagedBuffer(const ManagedBuffer &) = delete; ManagedBuffer &operator=(const ManagedBuffer &) = delete; ManagedBuffer(ManagedBuffer &&Other) noexcept - : Address(Other.Address), Size(Other.Size) { + : Platform(Other.Platform), Address(Other.Address), Size(Other.Size) { + Other.Platform = nullptr; Other.Address = nullptr; Other.Size = 0; } @@ -57,11 +58,13 @@ template class [[nodiscard]] ManagedBuffer { return *this; if (Address) - detail::freeDeviceMemory(Address); + detail::freeDeviceMemory(Platform, Address); + Platform = Other.Platform; Address = Other.Address; Size = Other.Size; + Other.Platform = nullptr; Other.Address = nullptr; Other.Size = 0; @@ -85,9 +88,11 @@ template class [[nodiscard]] ManagedBuffer { private: friend class DeviceContext; - explicit ManagedBuffer(T *Address, std::size_t Size) noexcept - : Address(Address), Size(Size) {} + explicit ManagedBuffer(ol_platform_handle_t Platform, T *Address, + std::size_t Size) noexcept + : Platform(Platform), Address(Address), Size(Size) {} + ol_platform_handle_t Platform = nullptr; T *Address = nullptr; std::size_t Size = 0; }; diff --git a/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp b/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp index 788989a0d4211..44c4ab72c9be5 100644 --- a/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp +++ b/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp @@ -32,6 +32,9 @@ typedef struct ol_program_impl_t *ol_program_handle_t; struct ol_symbol_impl_t; typedef struct ol_symbol_impl_t *ol_symbol_handle_t; +struct ol_platform_impl_t; +typedef struct ol_platform_impl_t *ol_platform_handle_t; + #ifdef __cplusplus } #endif // __cplusplus diff --git a/offload/unittests/Conformance/lib/DeviceContext.cpp b/offload/unittests/Conformance/lib/DeviceContext.cpp index 6c3425f1e17c2..987d7841fa763 100644 --- a/offload/unittests/Conformance/lib/DeviceContext.cpp +++ b/offload/unittests/Conformance/lib/DeviceContext.cpp @@ -103,6 +103,7 @@ getPlatformBackend(ol_platform_handle_t PlatformHandle) noexcept { struct Device { ol_device_handle_t Handle; + ol_platform_handle_t PlatformHandle; std::string Name; std::string Platform; ol_platform_backend_t Backend; @@ -124,7 +125,7 @@ const std::vector &getDevices() { auto Platform = getPlatformName(PlatformHandle); static_cast *>(Data)->push_back( - {DeviceHandle, Name, Platform, Backend}); + {DeviceHandle, PlatformHandle, Name, Platform, Backend}); } return true; @@ -175,6 +176,7 @@ DeviceContext::DeviceContext(std::size_t GlobalDeviceId) llvm::Twine(Devices.size())); DeviceHandle = Devices[GlobalDeviceId].Handle; + PlatformHandle = Devices[GlobalDeviceId].PlatformHandle; } DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId) @@ -210,6 +212,7 @@ DeviceContext::DeviceContext(llvm::StringRef Platform, std::size_t DeviceId) GlobalDeviceId = *FoundGlobalDeviceId; DeviceHandle = Devices[GlobalDeviceId].Handle; + PlatformHandle = Devices[GlobalDeviceId].PlatformHandle; } [[nodiscard]] llvm::Expected> diff --git a/offload/unittests/Conformance/lib/DeviceResources.cpp b/offload/unittests/Conformance/lib/DeviceResources.cpp index d1c7b90e751e6..3271256917e45 100644 --- a/offload/unittests/Conformance/lib/DeviceResources.cpp +++ b/offload/unittests/Conformance/lib/DeviceResources.cpp @@ -24,9 +24,10 @@ using namespace mathtest; // Helpers //===----------------------------------------------------------------------===// -void detail::freeDeviceMemory(void *Address) noexcept { +void detail::freeDeviceMemory(ol_platform_handle_t Platform, + void *Address) noexcept { if (Address) - OL_CHECK(olMemFree(Address)); + OL_CHECK(olMemFree(Platform, Address)); } //===----------------------------------------------------------------------===// diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp index 0538e60f276e3..db06a714c59a5 100644 --- a/offload/unittests/OffloadAPI/common/Fixtures.hpp +++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp @@ -137,13 +137,12 @@ struct OffloadDeviceTest Device = DeviceParam.Handle; if (Device == nullptr) GTEST_SKIP() << "No available devices."; + + ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, + sizeof(ol_platform_handle_t), &Platform)); } ol_platform_backend_t getPlatformBackend() const { - ol_platform_handle_t Platform = nullptr; - if (olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, - sizeof(ol_platform_handle_t), &Platform)) - return OL_PLATFORM_BACKEND_UNKNOWN; ol_platform_backend_t Backend; if (olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND, sizeof(ol_platform_backend_t), &Backend)) @@ -151,6 +150,7 @@ struct OffloadDeviceTest return Backend; } + ol_platform_handle_t Platform = nullptr; ol_device_handle_t Device = nullptr; }; diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp index 1dac8c50271b5..222c98d3bdc3f 100644 --- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp +++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp @@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelFooTest, SuccessThreaded) { @@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); }); } @@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelLocalMemTest, Success) { @@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) { for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], (i % 64) * 2); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelLocalMemReductionTest, Success) { @@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) { for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelLocalMemStaticTest, Success) { @@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) { for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelGlobalTest, Success) { @@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) { ASSERT_EQ(Data[i], i * 2); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) { @@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) { ASSERT_EQ(Data[i], i + 100); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchKernelGlobalDtorTest, Success) { diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp index 00e428ec2abc7..46d382da61075 100644 --- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp @@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemAllocTest, SuccessAllocHost) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemAllocTest, SuccessAllocDevice) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemAllocTest, InvalidNullDevice) { diff --git a/offload/unittests/OffloadAPI/memory/olMemFill.cpp b/offload/unittests/OffloadAPI/memory/olMemFill.cpp index a84ed3d78eccf..e7098031e2ed3 100644 --- a/offload/unittests/OffloadAPI/memory/olMemFill.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemFill.cpp @@ -39,7 +39,7 @@ struct olMemFillTest : OffloadQueueTest { ASSERT_EQ(AllocPtr[i], Pattern); } - olMemFree(Alloc); + olMemFree(Platform, Alloc); } }; OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest); @@ -92,7 +92,7 @@ TEST_P(olMemFillTest, SuccessLarge) { ASSERT_EQ(AllocPtr[i].B, UINT64_MAX); } - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemFillTest, SuccessLargeEnqueue) { @@ -120,7 +120,7 @@ TEST_P(olMemFillTest, SuccessLargeEnqueue) { ASSERT_EQ(AllocPtr[i].B, UINT64_MAX); } - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemFillTest, SuccessLargeByteAligned) { @@ -146,7 +146,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) { ASSERT_EQ(AllocPtr[i].C, 255); } - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) { @@ -176,7 +176,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) { ASSERT_EQ(AllocPtr[i].C, 255); } - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemFillTest, InvalidPatternSize) { @@ -189,5 +189,5 @@ TEST_P(olMemFillTest, InvalidPatternSize) { olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size)); olSyncQueue(Queue); - olMemFree(Alloc); + olMemFree(Platform, Alloc); } diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp index dfaf9bdef3189..9c602190f7814 100644 --- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp @@ -16,24 +16,31 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest); TEST_P(olMemFreeTest, SuccessFreeManaged) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } TEST_P(olMemFreeTest, SuccessFreeHost) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } TEST_P(olMemFreeTest, SuccessFreeDevice) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } TEST_P(olMemFreeTest, InvalidNullPtr) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); - ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr)); - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Platform, nullptr)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); +} + +TEST_P(olMemFreeTest, InvalidPlatformPtr) { + void *Alloc = nullptr; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olMemFree(nullptr, Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp index cc67d782ef403..3f15a957fa201 100644 --- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp @@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) { std::vector Input(Size, 42); ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size)); olSyncQueue(Queue); - olMemFree(Alloc); + olMemFree(Platform, Alloc); } TEST_P(olMemcpyTest, SuccessDtoH) { @@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } TEST_P(olMemcpyTest, SuccessDtoD) { @@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(AllocA)); - ASSERT_SUCCESS(olMemFree(AllocB)); + ASSERT_SUCCESS(olMemFree(Platform, AllocA)); + ASSERT_SUCCESS(olMemFree(Platform, AllocB)); } TEST_P(olMemcpyTest, SuccessHtoHSync) { @@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Platform, Alloc)); } TEST_P(olMemcpyTest, SuccessSizeZero) { @@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I); - ASSERT_SUCCESS(olMemFree(DestMem)); - ASSERT_SUCCESS(olMemFree(SourceMem)); + ASSERT_SUCCESS(olMemFree(Platform, DestMem)); + ASSERT_SUCCESS(olMemFree(Platform, SourceMem)); } TEST_P(olMemcpyGlobalTest, SuccessWrite) { @@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I); - ASSERT_SUCCESS(olMemFree(DestMem)); - ASSERT_SUCCESS(olMemFree(SourceMem)); + ASSERT_SUCCESS(olMemFree(Platform, DestMem)); + ASSERT_SUCCESS(olMemFree(Platform, SourceMem)); } TEST_P(olMemcpyGlobalTest, SuccessRead) { @@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I * 2); - ASSERT_SUCCESS(olMemFree(DestMem)); + ASSERT_SUCCESS(olMemFree(Platform, DestMem)); } diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp index aa86750f6adf9..b45ca6977b4dc 100644 --- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp +++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp @@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) { } ASSERT_SUCCESS(olDestroyQueue(Queue)); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Platform, Mem)); } TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {