@RossBrunton
Contributor

In a future change, most of the allocation tracking will be removed from
liboffload itself and be delegated to the plugins. Therefore, we will
need to know which plugin is in charge of the allocation.
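To make the new contract concrete, here is a self-contained mock of a two-argument free, plus an RAII owner that keeps the platform next to the pointer, in the spirit of the ManagedBuffer change in the diff. MockPlatform, mockMemFree and OwnedAlloc are hypothetical stand-ins, not liboffload API. The error values only loosely mirror OL_ERRC_NOT_FOUND, OL_ERRC_INVALID_NULL_HANDLE and OL_ERRC_INVALID_NULL_POINTER as used in this PR. Real callers would obtain the platform with olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM, ...) and pass it to olMemFree(Platform, Address).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <set>
#include <utility>

// Hypothetical error codes; NOT the real ol_errc_t values.
enum class MockErr { Success, NotFound, NullHandle, NullPointer };

// A mock "platform" that remembers the allocations it made.
struct MockPlatform {
  std::set<void *> Live; // addresses allocated here and not yet freed

  void *alloc(std::size_t Size) {
    void *P = std::malloc(Size);
    if (P)
      Live.insert(P);
    return P;
  }
};

// Same shape as olMemFree(Platform, Address): the caller names the owning
// platform, and freeing through the wrong one reports "not found".
MockErr mockMemFree(MockPlatform *Platform, void *Address) {
  if (!Platform)
    return MockErr::NullHandle;
  if (!Address)
    return MockErr::NullPointer;
  auto It = Platform->Live.find(Address);
  if (It == Platform->Live.end())
    return MockErr::NotFound;
  Platform->Live.erase(It);
  std::free(Address);
  return MockErr::Success;
}

// RAII owner carrying the platform alongside the address.
// Moves transfer BOTH fields so the destructor frees through the right platform.
class OwnedAlloc {
public:
  OwnedAlloc(MockPlatform *Platform, std::size_t Size)
      : Platform(Platform), Address(Platform->alloc(Size)) {}
  OwnedAlloc(OwnedAlloc &&Other) noexcept
      : Platform(std::exchange(Other.Platform, nullptr)),
        Address(std::exchange(Other.Address, nullptr)) {}
  OwnedAlloc &operator=(OwnedAlloc &&Other) noexcept {
    if (this != &Other) {
      reset();
      Platform = std::exchange(Other.Platform, nullptr);
      Address = std::exchange(Other.Address, nullptr);
    }
    return *this;
  }
  OwnedAlloc(const OwnedAlloc &) = delete;
  OwnedAlloc &operator=(const OwnedAlloc &) = delete;
  ~OwnedAlloc() { reset(); }

  void *get() const { return Address; }

private:
  void reset() {
    if (Address)
      mockMemFree(Platform, Address);
    Address = nullptr;
  }

  MockPlatform *Platform = nullptr;
  void *Address = nullptr;
};
```

The design point is that the platform handle travels with the pointer: any wrapper that moves ownership has to move both, or a later free could be dispatched to the wrong platform.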

@RossBrunton RossBrunton marked this pull request as ready for review September 9, 2025 09:25

offload/unittests/Conformance/include/mathtest/DeviceContext.hpp
   template <typename T>
-  ManagedBuffer<T> createManagedBuffer(std::size_t Size) const noexcept {
+  ManagedBuffer<T> createManagedBuffer(std::size_t Size) noexcept {
Contributor Author

@leandrolcampos I've needed to touch a lot of the conformance tests here, but for some reason I'm unable to get them to build and run.

Can you have a look and let me know if all of this looks okay?

Contributor

Hi @RossBrunton,
I'll take a look by the end of the day.

Contributor

@leandrolcampos leandrolcampos Sep 10, 2025

@RossBrunton, could you take a look at #157773? Merging that patch will ensure the conformance tests run correctly.

I also prepared a Gist containing a step-by-step guide to setting up an environment to build and run the conformance tests on Windows Subsystem for Linux (WSL 2) with NVIDIA GPUs. I think it can be easily adapted to other operating systems and/or platform configurations.

Finally, thanks for your care with the conformance tests! If you're interested in the broader context of this work, I recently wrote about it in this post on the LLVM Blog.

@llvmbot llvmbot added the offload label Sep 9, 2025
@llvmbot
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-offload

Author: Ross Brunton (RossBrunton)

Changes

In a future change, most of the allocation tracking will be removed from
liboffload itself and be delegated to the plugins. Therefore, we will
need to know which plugin is in charge of the allocation.


Full diff: https://github.com/llvm/llvm-project/pull/157478.diff

15 Files Affected:

  • (modified) offload/liboffload/API/Memory.td (+4-1)
  • (modified) offload/liboffload/src/OffloadImpl.cpp (+2-1)
  • (modified) offload/unittests/Conformance/include/mathtest/DeviceContext.hpp (+6-2)
  • (modified) offload/unittests/Conformance/include/mathtest/DeviceResources.hpp (+7-5)
  • (modified) offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp (+2-2)
  • (modified) offload/unittests/Conformance/include/mathtest/OffloadForward.hpp (+3)
  • (modified) offload/unittests/Conformance/lib/DeviceContext.cpp (+23)
  • (modified) offload/unittests/Conformance/lib/DeviceResources.cpp (+3-2)
  • (modified) offload/unittests/OffloadAPI/common/Fixtures.hpp (+4-4)
  • (modified) offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp (+8-8)
  • (modified) offload/unittests/OffloadAPI/memory/olMemAlloc.cpp (+3-3)
  • (modified) offload/unittests/OffloadAPI/memory/olMemFill.cpp (+6-6)
  • (modified) offload/unittests/OffloadAPI/memory/olMemFree.cpp (+12-5)
  • (modified) offload/unittests/OffloadAPI/memory/olMemcpy.cpp (+10-10)
  • (modified) offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp (+1-1)
diff --git a/offload/liboffload/API/Memory.td b/offload/liboffload/API/Memory.td
index cc98b672a26a9..a24f05e72f5be 100644
--- a/offload/liboffload/API/Memory.td
+++ b/offload/liboffload/API/Memory.td
@@ -37,9 +37,12 @@ def olMemAlloc : Function {
 def olMemFree : Function {
   let desc = "Frees a memory allocation previously made by olMemAlloc.";
   let params = [
+    Param<"ol_platform_handle_t", "Platform", "handle of the platform that allocated this memory", PARAM_IN>,
     Param<"void*", "Address", "address of the allocation to free", PARAM_IN>,
   ];
-  let returns = [];
+  let returns = [
+    Return<"OL_ERRC_NOT_FOUND", ["memory was not allocated by this platform"]>
+  ];
 }
 
 def olMemcpy : Function {
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 7e8e297831f45..fef3a5669e0d5 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -632,7 +632,7 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type,
   return Error::success();
 }
 
-Error olMemFree_impl(void *Address) {
+Error olMemFree_impl(ol_platform_handle_t Platform, void *Address) {
   ol_device_handle_t Device;
   ol_alloc_type_t Type;
   {
@@ -646,6 +646,7 @@ Error olMemFree_impl(void *Address) {
     Type = AllocInfo.Type;
     OffloadContext::get().AllocInfoMap.erase(Address);
   }
+  assert(Platform == Device->Platform);
 
   if (auto Res =
           Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type)))
diff --git a/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp b/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp
index 5c31fc3da53cd..7a11798856550 100644
--- a/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp
+++ b/offload/unittests/Conformance/include/mathtest/DeviceContext.hpp
@@ -57,13 +57,13 @@ class DeviceContext {
   explicit DeviceContext(llvm::StringRef Platform, std::size_t DeviceId = 0);
 
   template <typename T>
-  ManagedBuffer<T> createManagedBuffer(std::size_t Size) const noexcept {
+  ManagedBuffer<T> createManagedBuffer(std::size_t Size) noexcept {
     void *UntypedAddress = nullptr;
 
     detail::allocManagedMemory(DeviceHandle, Size * sizeof(T), &UntypedAddress);
     T *TypedAddress = static_cast<T *>(UntypedAddress);
 
-    return ManagedBuffer<T>(TypedAddress, Size);
+    return ManagedBuffer<T>(getPlatformHandle(), TypedAddress, Size);
   }
 
   [[nodiscard]] llvm::Expected<std::shared_ptr<DeviceImage>>
@@ -120,6 +120,9 @@ class DeviceContext {
 
   [[nodiscard]] llvm::StringRef getPlatform() const noexcept;
 
+  [[nodiscard]] llvm::Expected<ol_platform_handle_t>
+  getPlatformHandle() noexcept;
+
 private:
   [[nodiscard]] llvm::Expected<ol_symbol_handle_t>
   getKernelHandle(ol_program_handle_t ProgramHandle,
@@ -131,6 +134,7 @@ class DeviceContext {
 
   std::size_t GlobalDeviceId;
   ol_device_handle_t DeviceHandle;
+  ol_platform_handle_t PlatformHandle = nullptr;
 };
 } // namespace mathtest
 
diff --git a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
index 860448afa3a01..6084732baf6ee 100644
--- a/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
+++ b/offload/unittests/Conformance/include/mathtest/DeviceResources.hpp
@@ -29,7 +29,7 @@ class DeviceContext;
 
 namespace detail {
 
-void freeDeviceMemory(void *Address) noexcept;
+void freeDeviceMemory(ol_platform_handle_t Platform, void *Address) noexcept;
 } // namespace detail
 
 //===----------------------------------------------------------------------===//
@@ -40,7 +40,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
 public:
   ~ManagedBuffer() noexcept {
     if (Address)
-      detail::freeDeviceMemory(Address);
+      detail::freeDeviceMemory(Platform, Address);
   }
 
   ManagedBuffer(const ManagedBuffer &) = delete;
@@ -57,7 +57,7 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
       return *this;
 
     if (Address)
-      detail::freeDeviceMemory(Address);
+      detail::freeDeviceMemory(Platform, Address);
 
     Address = Other.Address;
     Size = Other.Size;
@@ -85,9 +85,11 @@ template <typename T> class [[nodiscard]] ManagedBuffer {
 private:
   friend class DeviceContext;
 
-  explicit ManagedBuffer(T *Address, std::size_t Size) noexcept
-      : Address(Address), Size(Size) {}
+  explicit ManagedBuffer(ol_platform_handle_t Platform, T *Address,
+                         std::size_t Size) noexcept
+      : Platform(Platform), Address(Address), Size(Size) {}
 
+  ol_platform_handle_t Platform;
   T *Address = nullptr;
   std::size_t Size = 0;
 };
diff --git a/offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp b/offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp
index b88d6e9aebdc8..fdf30d58ae1e7 100644
--- a/offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp
+++ b/offload/unittests/Conformance/include/mathtest/GpuMathTest.hpp
@@ -75,7 +75,7 @@ class [[nodiscard]] GpuMathTest final {
 
   ResultType run(GeneratorType &Generator,
                  std::size_t BufferSize = DefaultBufferSize,
-                 uint32_t GroupSize = DefaultGroupSize) const noexcept {
+                 uint32_t GroupSize = DefaultGroupSize) noexcept {
     assert(BufferSize > 0 && "Buffer size must be a positive value");
     assert(GroupSize > 0 && "Group size must be a positive value");
 
@@ -128,7 +128,7 @@ class [[nodiscard]] GpuMathTest final {
     return *ExpectedKernel;
   }
 
-  [[nodiscard]] auto createBuffers(std::size_t BufferSize) const {
+  [[nodiscard]] auto createBuffers(std::size_t BufferSize) {
     auto InBuffersTuple = std::apply(
         [&](auto... InTypeIdentities) {
           return std::make_tuple(
diff --git a/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp b/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp
index 788989a0d4211..44c4ab72c9be5 100644
--- a/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp
+++ b/offload/unittests/Conformance/include/mathtest/OffloadForward.hpp
@@ -32,6 +32,9 @@ typedef struct ol_program_impl_t *ol_program_handle_t;
 struct ol_symbol_impl_t;
 typedef struct ol_symbol_impl_t *ol_symbol_handle_t;
 
+struct ol_platform_impl_t;
+typedef struct ol_platform_impl_t *ol_platform_handle_t;
+
 #ifdef __cplusplus
 }
 #endif // __cplusplus
diff --git a/offload/unittests/Conformance/lib/DeviceContext.cpp b/offload/unittests/Conformance/lib/DeviceContext.cpp
index 6c3425f1e17c2..d72f56ca1f175 100644
--- a/offload/unittests/Conformance/lib/DeviceContext.cpp
+++ b/offload/unittests/Conformance/lib/DeviceContext.cpp
@@ -286,6 +286,29 @@ DeviceContext::getKernelHandle(ol_program_handle_t ProgramHandle,
   return Handle;
 }
 
+llvm::Expected<ol_platform_handle_t>
+DeviceContext::getPlatformHandle() noexcept {
+  if (!PlatformHandle) {
+    const ol_result_t OlResult =
+        olGetDeviceInfo(DeviceHandle, OL_DEVICE_INFO_PLATFORM,
+                        sizeof(PlatformHandle), &PlatformHandle);
+
+    if (OlResult != OL_SUCCESS) {
+      PlatformHandle = nullptr;
+      llvm::StringRef Details =
+          OlResult->Details ? OlResult->Details : "No details provided";
+
+      // clang-format off
+      return llvm::createStringError(
+        llvm::Twine(Details) +
+        " (code " + llvm::Twine(OlResult->Code) + ")");
+      // clang-format on
+    }
+  }
+
+  return PlatformHandle;
+}
+
 void DeviceContext::launchKernelImpl(
     ol_symbol_handle_t KernelHandle, uint32_t NumGroups, uint32_t GroupSize,
     const void *KernelArgs, std::size_t KernelArgsSize) const noexcept {
diff --git a/offload/unittests/Conformance/lib/DeviceResources.cpp b/offload/unittests/Conformance/lib/DeviceResources.cpp
index d1c7b90e751e6..3271256917e45 100644
--- a/offload/unittests/Conformance/lib/DeviceResources.cpp
+++ b/offload/unittests/Conformance/lib/DeviceResources.cpp
@@ -24,9 +24,10 @@ using namespace mathtest;
 // Helpers
 //===----------------------------------------------------------------------===//
 
-void detail::freeDeviceMemory(void *Address) noexcept {
+void detail::freeDeviceMemory(ol_platform_handle_t Platform,
+                              void *Address) noexcept {
   if (Address)
-    OL_CHECK(olMemFree(Address));
+    OL_CHECK(olMemFree(Platform, Address));
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp
index 0538e60f276e3..db06a714c59a5 100644
--- a/offload/unittests/OffloadAPI/common/Fixtures.hpp
+++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp
@@ -137,13 +137,12 @@ struct OffloadDeviceTest
     Device = DeviceParam.Handle;
     if (Device == nullptr)
       GTEST_SKIP() << "No available devices.";
+
+    ASSERT_SUCCESS(olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
+                                   sizeof(ol_platform_handle_t), &Platform));
   }
 
   ol_platform_backend_t getPlatformBackend() const {
-    ol_platform_handle_t Platform = nullptr;
-    if (olGetDeviceInfo(Device, OL_DEVICE_INFO_PLATFORM,
-                        sizeof(ol_platform_handle_t), &Platform))
-      return OL_PLATFORM_BACKEND_UNKNOWN;
     ol_platform_backend_t Backend;
     if (olGetPlatformInfo(Platform, OL_PLATFORM_INFO_BACKEND,
                           sizeof(ol_platform_backend_t), &Backend))
@@ -151,6 +150,7 @@ struct OffloadDeviceTest
     return Backend;
   }
 
+  ol_platform_handle_t Platform = nullptr;
   ol_device_handle_t Device = nullptr;
 };
 
diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
index 1dac8c50271b5..222c98d3bdc3f 100644
--- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
+++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
@@ -101,7 +101,7 @@ TEST_P(olLaunchKernelFooTest, Success) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
@@ -123,7 +123,7 @@ TEST_P(olLaunchKernelFooTest, SuccessThreaded) {
       ASSERT_EQ(Data[i], i);
     }
 
-    ASSERT_SUCCESS(olMemFree(Mem));
+    ASSERT_SUCCESS(olMemFree(Platform, Mem));
   });
 }
 
@@ -151,7 +151,7 @@ TEST_P(olLaunchKernelFooTest, SuccessSynchronous) {
     ASSERT_EQ(Data[i], i);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemTest, Success) {
@@ -176,7 +176,7 @@ TEST_P(olLaunchKernelLocalMemTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.GroupSize.x * LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], (i % 64) * 2);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
@@ -199,7 +199,7 @@ TEST_P(olLaunchKernelLocalMemReductionTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
@@ -222,7 +222,7 @@ TEST_P(olLaunchKernelLocalMemStaticTest, Success) {
   for (uint32_t i = 0; i < LaunchArgs.NumGroups.x; i++)
     ASSERT_EQ(Data[i], 2 * LaunchArgs.GroupSize.x);
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, Success) {
@@ -245,7 +245,7 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
     ASSERT_EQ(Data[i], i * 2);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
@@ -273,7 +273,7 @@ TEST_P(olLaunchKernelGlobalCtorTest, Success) {
     ASSERT_EQ(Data[i], i + 100);
   }
 
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchKernelGlobalDtorTest, Success) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
index 00e428ec2abc7..46d382da61075 100644
--- a/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemAlloc.cpp
@@ -17,21 +17,21 @@ TEST_P(olMemAllocTest, SuccessAllocManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemAllocTest, SuccessAllocHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemAllocTest, SuccessAllocDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
   ASSERT_NE(Alloc, nullptr);
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemAllocTest, InvalidNullDevice) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemFill.cpp b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
index a84ed3d78eccf..e7098031e2ed3 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFill.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFill.cpp
@@ -39,7 +39,7 @@ struct olMemFillTest : OffloadQueueTest {
       ASSERT_EQ(AllocPtr[i], Pattern);
     }
 
-    olMemFree(Alloc);
+    olMemFree(Platform, Alloc);
   }
 };
 OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFillTest);
@@ -92,7 +92,7 @@ TEST_P(olMemFillTest, SuccessLarge) {
     ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeEnqueue) {
@@ -120,7 +120,7 @@ TEST_P(olMemFillTest, SuccessLargeEnqueue) {
     ASSERT_EQ(AllocPtr[i].B, UINT64_MAX);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeByteAligned) {
@@ -146,7 +146,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAligned) {
     ASSERT_EQ(AllocPtr[i].C, 255);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
@@ -176,7 +176,7 @@ TEST_P(olMemFillTest, SuccessLargeByteAlignedEnqueue) {
     ASSERT_EQ(AllocPtr[i].C, 255);
   }
 
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemFillTest, InvalidPatternSize) {
@@ -189,5 +189,5 @@ TEST_P(olMemFillTest, InvalidPatternSize) {
                olMemFill(Queue, Alloc, sizeof(Pattern), &Pattern, Size));
 
   olSyncQueue(Queue);
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
diff --git a/offload/unittests/OffloadAPI/memory/olMemFree.cpp b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
index dfaf9bdef3189..9c602190f7814 100644
--- a/offload/unittests/OffloadAPI/memory/olMemFree.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemFree.cpp
@@ -16,24 +16,31 @@ OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemFreeTest);
 TEST_P(olMemFreeTest, SuccessFreeManaged) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeHost) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_HOST, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
 
 TEST_P(olMemFreeTest, SuccessFreeDevice) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
 
 TEST_P(olMemFreeTest, InvalidNullPtr) {
   void *Alloc = nullptr;
   ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
-  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(nullptr));
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, olMemFree(Platform, nullptr));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
+}
+
+TEST_P(olMemFreeTest, InvalidPlatformPtr) {
+  void *Alloc = nullptr;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_DEVICE, 1024, &Alloc));
+  ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, olMemFree(nullptr, Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index cc67d782ef403..3f15a957fa201 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -46,7 +46,7 @@ TEST_P(olMemcpyTest, SuccessHtoD) {
   std::vector<uint8_t> Input(Size, 42);
   ASSERT_SUCCESS(olMemcpy(Queue, Alloc, Device, Input.data(), Host, Size));
   olSyncQueue(Queue);
-  olMemFree(Alloc);
+  olMemFree(Platform, Alloc);
 }
 
 TEST_P(olMemcpyTest, SuccessDtoH) {
@@ -62,7 +62,7 @@ TEST_P(olMemcpyTest, SuccessDtoH) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessDtoD) {
@@ -81,8 +81,8 @@ TEST_P(olMemcpyTest, SuccessDtoD) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(AllocA));
-  ASSERT_SUCCESS(olMemFree(AllocB));
+  ASSERT_SUCCESS(olMemFree(Platform, AllocA));
+  ASSERT_SUCCESS(olMemFree(Platform, AllocB));
 }
 
 TEST_P(olMemcpyTest, SuccessHtoHSync) {
@@ -110,7 +110,7 @@ TEST_P(olMemcpyTest, SuccessDtoHSync) {
   for (uint8_t Val : Output) {
     ASSERT_EQ(Val, 42);
   }
-  ASSERT_SUCCESS(olMemFree(Alloc));
+  ASSERT_SUCCESS(olMemFree(Platform, Alloc));
 }
 
 TEST_P(olMemcpyTest, SuccessSizeZero) {
@@ -146,8 +146,8 @@ TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Platform, DestMem));
+  ASSERT_SUCCESS(olMemFree(Platform, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessWrite) {
@@ -178,8 +178,8 @@ TEST_P(olMemcpyGlobalTest, SuccessWrite) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
-  ASSERT_SUCCESS(olMemFree(SourceMem));
+  ASSERT_SUCCESS(olMemFree(Platform, DestMem));
+  ASSERT_SUCCESS(olMemFree(Platform, SourceMem));
 }
 
 TEST_P(olMemcpyGlobalTest, SuccessRead) {
@@ -199,5 +199,5 @@ TEST_P(olMemcpyGlobalTest, SuccessRead) {
   for (uint32_t I = 0; I < 64; I++)
     ASSERT_EQ(DestData[I], I * 2);
 
-  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(Platform, DestMem));
 }
diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
index aa86750f6adf9..b45ca6977b4dc 100644
--- a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
+++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp
@@ -93,7 +93,7 @@ TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
   }
 
   ASSERT_SUCCESS(olDestroyQueue(Queue));
-  ASSERT_SUCCESS(olMemFree(Mem));
+  ASSERT_SUCCESS(olMemFree(Platform, Mem));
 }
 
 TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {

…157773)

This PR is a follow-up to the change introduced in #157478, which added
a `platform` parameter to the `olMemFree` function.
@RossBrunton
Contributor Author

@jhuber6 @callumfare Can you give this another look over when you get the chance?

Contributor

@pbalcer pbalcer left a comment

This looks good to me, but one problem I've noticed in the interface is that a single plugin can only expose a single platform. This is not true in general. For example, in Level Zero there can be multiple separate drivers in the OS, each handling different devices. These need to be separate platforms.

Contributor

@jhuber6 jhuber6 left a comment

I would still much prefer not to change the free interface. As far as I understand, the general problem is that when we start working with heterogeneous devices, two devices could return the same address and thus confound our attempts to map pointers to devices.

The only time this can occur is for allocations on the device itself; managed allocations go through the shared Linux kernel. If I understand correctly, after allocation we will put something into some kind of map. If we already have a map, then we should be able to detect a conflict. If there is a conflict, we should just allocate the memory again, try adding it again, then free the old pointer. It's slightly inefficient, but this is a one-in-a-million edge case, so it does not need to be fast. Does this make sense to you?
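A rough sketch of the conflict handling proposed here, assuming an ordered map from base address to the owning device. AllocRegistry, allocUnique and the Alloc/Free callbacks are hypothetical names for illustration, and the integer device IDs stand in for real handles.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iterator>
#include <map>
#include <optional>

// Hypothetical sketch; none of these names exist in liboffload.
using Addr = std::uintptr_t;

struct Entry {
  Addr Size;
  int Device; // stand-in for ol_device_handle_t
};

class AllocRegistry {
public:
  // Records [Base, Base + Size) and returns true only if it overlaps no
  // existing entry. Assumes Size > 0.
  bool tryInsert(Addr Base, Addr Size, int Device) {
    auto Next = Map.lower_bound(Base); // first entry with start >= Base
    if (Next != Map.end() && Next->first < Base + Size)
      return false; // an existing entry starts inside the new range
    if (Next != Map.begin()) {
      auto Prev = std::prev(Next);
      if (Prev->first + Prev->second.Size > Base)
        return false; // an earlier entry runs into the new range
    }
    Map.emplace(Base, Entry{Size, Device});
    return true;
  }

  std::size_t size() const { return Map.size(); }

private:
  std::map<Addr, Entry> Map;
};

// On a collision, allocate again while still holding the colliding pointer
// (so the allocator cannot hand back the same live address), then release
// the old one. Assumes Alloc never fails; a real version would propagate errors.
std::optional<Addr> allocUnique(AllocRegistry &Reg, int Device, Addr Size,
                                const std::function<Addr(Addr)> &Alloc,
                                const std::function<void(Addr)> &Free,
                                int MaxTries = 8) {
  Addr Base = Alloc(Size);
  for (int I = 0; I < MaxTries; ++I) {
    if (Reg.tryInsert(Base, Size, Device))
      return Base;
    Addr Retry = Alloc(Size);
    Free(Base);
    Base = Retry;
  }
  Free(Base);
  return std::nullopt;
}
```

For example, with device 0 already registered at 0x1000, a device-1 allocator that first returns 0x1000 and then 0x2000 ends up registered at 0x2000, and its colliding 0x1000 allocation is released. The common, collision-free path costs one ordered-map insert.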

@RossBrunton
Contributor Author

@jhuber6 @pbalcer To summarize the discussion in the meeting:

We need the ability to, given a pointer to somewhere in memory, figure out which platform allocated it so that we can dispatch to the appropriate platform for olMemFree/olGetMemInfo and friends.

There are (at least) two possible approaches:

  1. We store a list of allocations in liboffload, along with their size, device and maybe some other information. This is an ordered list, and we can do a binary search to find which memory allocation a pointer is in. When we allocate, we check to see if the new allocation overlaps another one. If it does, we make a new allocation and throw away the old one. This ensures that every address is associated with at most one device/platform.
  2. We require the liboffload user to track the device/platform, and pass that as a parameter to olMemFree and olGetMemInfo. This means that liboffload doesn't need to keep its own list of allocations (they will de facto be managed by the user, which may have been tracking them anyway). Info queries themselves can be done in a platform-specific way, which will hopefully be more performant.
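Option 1's lookup step, mapping a possibly interior pointer back to the allocation (and platform) that contains it, can be sketched with an ordered map keyed by base address, with std::map::upper_bound doing the binary search. AllocTable and AllocInfo are hypothetical, the int platform field is a stand-in for ol_platform_handle_t, and overlap rejection is assumed to have happened at insert time.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <map>

// Hypothetical record; a real table would store liboffload handles instead.
struct AllocInfo {
  std::size_t Size;
  int Platform; // stand-in for ol_platform_handle_t
};

class AllocTable {
public:
  // Assumes the caller already rejected overlapping ranges.
  void insert(std::uintptr_t Base, std::size_t Size, int Platform) {
    Allocs.emplace(Base, AllocInfo{Size, Platform});
  }

  // O(log n): find the allocation whose [Base, Base + Size) contains Ptr.
  const AllocInfo *find(std::uintptr_t Ptr) const {
    auto It = Allocs.upper_bound(Ptr); // first Base strictly greater than Ptr
    if (It == Allocs.begin())
      return nullptr; // Ptr lies before every allocation
    --It;             // greatest Base <= Ptr
    if (Ptr - It->first < It->second.Size)
      return &It->second;
    return nullptr; // Ptr falls in a gap between allocations
  }

  bool erase(std::uintptr_t Base) { return Allocs.erase(Base) == 1; }

private:
  std::map<std::uintptr_t, AllocInfo> Allocs;
};
```

Under option 2 a table like this is unnecessary, because the caller passes the platform to olMemFree directly, as the updated tests in this PR do.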

Performance-wise, option 1 requires a binary search on every allocation, free and info query. Hopefully collisions are rare enough that they don't have an impact on performance. Option 2's performance will depend on how the user keeps track of platforms, and could be faster if they can access the platform handle in constant time.

Ergonomics-wise, option 1 is much simpler, since the user doesn't have to track metadata about their allocations.

I don't remember if this got mentioned in the meeting, but I think there was an idea of having the runtime query each platform in turn to see whether it owned the pointer. I don't think this will work, because unused platforms might be in an uninitialized or sleeping state, so sending them information queries might provoke the driver into waking up a GPU. And that's assuming the platform driver isn't buggy in a way that makes it "claim" memory it hasn't allocated.

@RossBrunton
Copy link
Contributor Author

Replaced by #159567
