diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td index cc98b672a26a9..adca91dbbe2a3 100644 --- a/offload/liboffload/API/Memory.td +++ b/offload/liboffload/API/Memory.td @@ -13,14 +13,17 @@ def ol_alloc_type_t : Enum { let desc = "Represents the type of allocation made with olMemAlloc."; let etors = [ - Etor<"HOST", "Host allocation">, - Etor<"DEVICE", "Device allocation">, - Etor<"MANAGED", "Managed allocation"> + Etor<"HOST", "Host allocation. Allocated on the host and visible to the host and all devices sharing the same platform.">, + Etor<"DEVICE", "Device allocation. Allocated on a specific device and visible only to that device.">, + Etor<"MANAGED", "Managed allocation. Allocated on a specific device and visible to the host and all devices sharing the same platform."> ]; } def olMemAlloc : Function { let desc = "Creates a memory allocation on the specified device."; + let details = [ + "`DEVICE` allocations do not share the same address space as the host or other devices. The `AllocationOut` pointer cannot be used to uniquely identify the allocation in these cases.", + ]; let params = [ Param<"ol_device_handle_t", "Device", "handle of the device to allocate on", PARAM_IN>, Param<"ol_alloc_type_t", "Type", "type of the allocation", PARAM_IN>, @@ -36,10 +39,18 @@ def olMemAlloc : Function { def olMemFree : Function { let desc = "Frees a memory allocation previously made by olMemAlloc."; + let details = [ + "`Address` must be the beginning of the allocation.", + "`Device` must be provided for memory allocated as `OL_ALLOC_TYPE_DEVICE`, and may be provided for other types.", + "If `Device` is provided, it must match the device used to allocate the memory with `olMemAlloc`.", + ]; let params = [ + Param<"ol_device_handle_t", "Device", "handle of the device this allocation was allocated on", PARAM_IN_OPTIONAL>, Param<"void*", "Address", "address of the allocation to free", PARAM_IN>, ]; - let returns = []; + let returns = [ + Return<"OL_ERRC_NOT_FOUND", ["The address was not found in the list of allocations"]> + ]; } def olMemcpy : Function { diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 7e8e297831f45..b20549250220e 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -183,6 +183,7 @@ namespace llvm { namespace offload { struct AllocInfo { + void *Base; ol_device_handle_t Device; ol_alloc_type_t Type; }; @@ -201,8 +202,8 @@ struct OffloadContext { bool TracingEnabled = false; bool ValidationEnabled = true; - DenseMap AllocInfoMap{}; - std::mutex AllocInfoMapMutex{}; + SmallVector AllocInfoList{}; + std::mutex AllocInfoListMutex{}; SmallVector Platforms{}; size_t RefCount; @@ -625,30 +626,37 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, *AllocationOut = *Alloc; { - std::lock_guard Lock(OffloadContext::get().AllocInfoMapMutex); - OffloadContext::get().AllocInfoMap.insert_or_assign( - *Alloc, AllocInfo{Device, Type}); + std::lock_guard Lock(OffloadContext::get().AllocInfoListMutex); + OffloadContext::get().AllocInfoList.emplace_back( + AllocInfo{*AllocationOut, Device, Type}); } return Error::success(); } -Error olMemFree_impl(void *Address) { - ol_device_handle_t Device; - ol_alloc_type_t Type; +Error olMemFree_impl(ol_device_handle_t Device, void *Address) { + AllocInfo Removed; { - std::lock_guard Lock(OffloadContext::get().AllocInfoMapMutex); - if (!OffloadContext::get().AllocInfoMap.contains(Address)) - return createOffloadError(ErrorCode::INVALID_ARGUMENT, - "address is not a known allocation"); - - auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address); - Device = AllocInfo.Device; - Type = AllocInfo.Type; - OffloadContext::get().AllocInfoMap.erase(Address); + std::lock_guard Lock(OffloadContext::get().AllocInfoListMutex); + + auto &List = OffloadContext::get().AllocInfoList; + auto Entry = std::find_if(List.begin(), List.end(), [&](AllocInfo &Entry) { + return Address == Entry.Base && (!Device || Entry.Device == Device); + }); + + if (Entry == List.end()) + return Plugin::error(ErrorCode::NOT_FOUND, + "could not find memory allocated by olMemAlloc"); + if (!Device && Entry->Type == OL_ALLOC_TYPE_DEVICE) + return Plugin::error( + ErrorCode::NOT_FOUND, + "specifying the Device parameter is required to query device memory"); + + Removed = std::move(*Entry); + *Entry = List.pop_back_val(); } - if (auto Res = - Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type))) + if (auto Res = Removed.Device->Device->dataDelete( + Removed.Base, convertOlToPluginAllocTy(Removed.Type))) return Res; return Error::success(); diff --git a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp index 860448afa3a01..9650678b53066 100644 --- a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp +++ b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp @@ -29,7 +29,7 @@ class DeviceContext; namespace detail { -void freeDeviceMemory(void *Address) noexcept; +void freeDeviceMemory(ol_device_handle_t Device, void *Address) noexcept; } // namespace detail //===----------------------------------------------------------------------===// @@ -40,7 +40,7 @@ template class [[nodiscard]] ManagedBuffer { public: ~ManagedBuffer() noexcept { if (Address) - detail::freeDeviceMemory(Address); + detail::freeDeviceMemory(nullptr, Address); } ManagedBuffer(const ManagedBuffer &) = delete; diff --git a/offload/unittests/Conformance/lib/DeviceResources.cpp b/offload/unittests/Conformance/lib/DeviceResources.cpp index d1c7b90e751e6..29c9efa4852a1 100644 --- a/offload/unittests/Conformance/lib/DeviceResources.cpp +++ b/offload/unittests/Conformance/lib/DeviceResources.cpp @@ -24,9 +24,10 @@ using namespace mathtest; // Helpers //===----------------------------------------------------------------------===// -void detail::freeDeviceMemory(void *Address) noexcept { +void detail::freeDeviceMemory(ol_device_handle_t Device, + void *Address) noexcept { if (Address) - OL_CHECK(olMemFree(Address)); + OL_CHECK(olMemFree(Device, Address)); } //===----------------------------------------------------------------------===// diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp index 1dac8c50271b5..a7f4881bcc709 100644 --- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp +++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp @@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelFooTest, SuccessThreaded) { @@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); }); } @@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) { ASSERT_EQ(Data[i], i); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelLocalMemTest, Success) { @@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) { for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], (i % 64) * 2); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelLocalMemReductionTest, Success) { @@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) { for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelLocalMemStaticTest, Success) { @@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) { for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++) ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelGlobalTest, Success) { @@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) { ASSERT_EQ(Data[i], i * 2); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) { @@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) { ASSERT_EQ(Data[i], i + 100); } - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchKernelGlobalDtorTest, Success) { diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp index 00e428ec2abc7..c1d585d7271f3 100644 --- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp @@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemAllocTest, SuccessAllocHost) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemAllocTest, SuccessAllocDevice) { void *Alloc = nullptr; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); ASSERT_NE(Alloc, nullptr); - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemAllocTest, InvalidNullDevice) { diff --git a/offload/unittests/OffloadAPI/memory/olMemFill.cpp b/offload/unittests/OffloadAPI/memory/olMemFill.cpp index a84ed3d78eccf..e22b0001ca838 100644 --- a/offload/unittests/OffloadAPI/memory/olMemFill.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemFill.cpp @@ -39,7 +39,7 @@ struct olMemFillTest : OffloadQueueTest { ASSERT_EQ(AllocPtr[i], Pattern); } - olMemFree(Alloc); + olMemFree(Device, Alloc); } }; OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest); @@ -92,7 +92,7 @@ TEST_P(olMemFillTest, SuccessLarge) { ASSERT_EQ(AllocPtr[i].B, UINT64_MAX); } - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemFillTest, SuccessLargeEnqueue) { @@ -120,7 +120,7 @@ TEST_P(olMemFillTest, SuccessLargeEnqueue) { ASSERT_EQ(AllocPtr[i].B, UINT64_MAX); } - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemFillTest, SuccessLargeByteAligned) { @@ -146,7 +146,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) { ASSERT_EQ(AllocPtr[i].C, 255); } - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) { @@ -176,7 +176,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) { ASSERT_EQ(AllocPtr[i].C, 255); } - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemFillTest, InvalidPatternSize) { @@ -189,5 +189,5 @@ TEST_P(olMemFillTest, InvalidPatternSize) { olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size)); olSyncQueue(Queue); - olMemFree(Alloc); + olMemFree(Device, Alloc); } diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp index dfaf9bdef3189..a68ae2514adfe 100644 --- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp @@ -10,30 +10,63 @@ #include #include -using olMemFreeTest = OffloadDeviceTest; -OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest); +template struct olMemFreeTestBase : OffloadDeviceTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadDeviceTest::SetUp()); + ASSERT_SUCCESS(olMemAlloc(Device, Type, 0x1000, &Alloc)); + } -TEST_P(olMemFreeTest, SuccessFreeManaged) { - void *Alloc = nullptr; - ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); + void *Alloc; +}; + +struct olMemFreeDeviceTest : olMemFreeTestBase {}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeDeviceTest); + +struct olMemFreeHostTest : olMemFreeTestBase {}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeHostTest); + +struct olMemFreeManagedTest : olMemFreeTestBase {}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeManagedTest); + +TEST_P(olMemFreeManagedTest, SuccessFree) { + ASSERT_SUCCESS(olMemFree(Device, Alloc)); +} + +TEST_P(olMemFreeManagedTest, SuccessFreeNull) { + ASSERT_SUCCESS(olMemFree(nullptr, Alloc)); +} + +TEST_P(olMemFreeHostTest, SuccessFree) { + ASSERT_SUCCESS(olMemFree(Device, Alloc)); +} + +TEST_P(olMemFreeHostTest, SuccessFreeNull) { + ASSERT_SUCCESS(olMemFree(nullptr, Alloc)); +} + +TEST_P(olMemFreeDeviceTest, SuccessFree) { + ASSERT_SUCCESS(olMemFree(Device, Alloc)); +} + +TEST_P(olMemFreeDeviceTest, InvalidNullPtr) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Device, nullptr)); +} + +TEST_P(olMemFreeDeviceTest, InvalidNullDevice) { + ASSERT_ERROR(OL_ERRC_NOT_FOUND, olMemFree(nullptr, Alloc)); } -TEST_P(olMemFreeTest, SuccessFreeHost) { - void *Alloc = nullptr; - ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); +TEST_P(olMemFreeDeviceTest, InvalidFreeWrongDevice) { + ASSERT_ERROR(OL_ERRC_NOT_FOUND, + olMemFree(TestEnvironment::getHostDevice(), Alloc)); } -TEST_P(olMemFreeTest, SuccessFreeDevice) { - void *Alloc = nullptr; - ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); - ASSERT_SUCCESS(olMemFree(Alloc)); +TEST_P(olMemFreeHostTest, InvalidFreeWrongDevice) { + ASSERT_ERROR(OL_ERRC_NOT_FOUND, + olMemFree(TestEnvironment::getHostDevice(), Alloc)); } -TEST_P(olMemFreeTest, InvalidNullPtr) { - void *Alloc = nullptr; - ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc)); - ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr)); - ASSERT_SUCCESS(olMemFree(Alloc)); +TEST_P(olMemFreeManagedTest, InvalidFreeWrongDevice) { + ASSERT_ERROR(OL_ERRC_NOT_FOUND, + olMemFree(TestEnvironment::getHostDevice(), Alloc)); } diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp index cc67d782ef403..d028099916848 100644 --- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp +++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp @@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) { std::vector Input(Size, 42); ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size)); olSyncQueue(Queue); - olMemFree(Alloc); + olMemFree(Device, Alloc); } TEST_P(olMemcpyTest, SuccessDtoH) { @@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Device, Alloc)); } TEST_P(olMemcpyTest, SuccessDtoD) { @@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(AllocA)); - ASSERT_SUCCESS(olMemFree(AllocB)); + ASSERT_SUCCESS(olMemFree(Device, AllocA)); + ASSERT_SUCCESS(olMemFree(Device, AllocB)); } TEST_P(olMemcpyTest, SuccessHtoHSync) { @@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) { for (uint8_t Val : Output) { ASSERT_EQ(Val, 42); } - ASSERT_SUCCESS(olMemFree(Alloc)); + ASSERT_SUCCESS(olMemFree(Device, Alloc)); } TEST_P(olMemcpyTest, SuccessSizeZero) { @@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I); - ASSERT_SUCCESS(olMemFree(DestMem)); - ASSERT_SUCCESS(olMemFree(SourceMem)); + ASSERT_SUCCESS(olMemFree(Device, DestMem)); + ASSERT_SUCCESS(olMemFree(Device, SourceMem)); } TEST_P(olMemcpyGlobalTest, SuccessWrite) { @@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I); - ASSERT_SUCCESS(olMemFree(DestMem)); - ASSERT_SUCCESS(olMemFree(SourceMem)); + ASSERT_SUCCESS(olMemFree(Device, DestMem)); + ASSERT_SUCCESS(olMemFree(Device, SourceMem)); } TEST_P(olMemcpyGlobalTest, SuccessRead) { @@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) { for (uint32_t I = 0; I < 64; I++) ASSERT_EQ(DestData[I], I * 2); - ASSERT_SUCCESS(olMemFree(DestMem)); + ASSERT_SUCCESS(olMemFree(Device, DestMem)); } diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp index aa86750f6adf9..2f3fda6fb729b 100644 --- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp +++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp @@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) { } ASSERT_SUCCESS(olDestroyQueue(Queue)); - ASSERT_SUCCESS(olMemFree(Mem)); + ASSERT_SUCCESS(olMemFree(Device, Mem)); } TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {