Skip to content

Conversation

@RossBrunton
Copy link
Contributor

Add two new symbol info types for getting the bounds of a global
variable. As well as a number of tests for reading/writing to it.

@llvmbot
Copy link
Member

llvmbot commented Jul 10, 2025

@llvm/pr-subscribers-offload

Author: Ross Brunton (RossBrunton)

Changes

Add two new symbol info types for getting the bounds of a global
variable. As well as a number of tests for reading/writing to it.


Full diff: https://github.com/llvm/llvm-project/pull/147972.diff

6 Files Affected:

  • (modified) offload/liboffload/API/Symbol.td (+3-1)
  • (modified) offload/liboffload/src/OffloadImpl.cpp (+19)
  • (modified) offload/tools/offload-tblgen/PrintGen.cpp (+6-2)
  • (modified) offload/unittests/OffloadAPI/memory/olMemcpy.cpp (+103)
  • (modified) offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp (+28)
  • (modified) offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp (+14)
diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td
index 8b1bef1b2b6e4..7310c772d757e 100644
--- a/offload/liboffload/API/Symbol.td
+++ b/offload/liboffload/API/Symbol.td
@@ -38,7 +38,9 @@ def : Enum {
   let desc = "Supported symbol info.";
   let is_typed = 1;
   let etors = [
-    TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
+    TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
+    TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
+    TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
   ];
 }
 
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 5aefafb1e57ea..ecdd1e76dfeff 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -751,9 +751,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
                                 void *PropValue, size_t *PropSizeRet) {
   InfoWriter Info(PropSize, PropValue, PropSizeRet);
 
+  auto CheckKind = [&](ol_symbol_kind_t Required) {
+    if (Symbol->Kind != Required) {
+      std::string ErrBuffer;
+      llvm::raw_string_ostream(ErrBuffer)
+          << PropName << ": Expected a symbol of Kind " << Required
+          << " but given a symbol of Kind " << Symbol->Kind;
+      return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
+    }
+    return Plugin::success();
+  };
+
   switch (PropName) {
   case OL_SYMBOL_INFO_KIND:
     return Info.write<ol_symbol_kind_t>(Symbol->Kind);
+  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
+    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+      return Err;
+    return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
+  case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
+    if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+      return Err;
+    return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
   default:
     return createOffloadError(ErrorCode::INVALID_ENUMERATION,
                               "olGetSymbolInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index d1189688a90a3..89d7c820426cf 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
     if (Type == "char[]") {
       OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
     } else {
-      OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
-                    Type);
+      if (Type == "void *")
+        OS << formatv(TAB_2 "void * const * const tptr = (void * "
+                            "const * const)ptr;\n");
+      else
+        OS << formatv(
+            TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
       // TODO: Handle other cases here
       OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
       if (Type.ends_with("*")) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index c1762b451b81d..204ed8fa22fc8 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -13,6 +13,30 @@
 using olMemcpyTest = OffloadQueueTest;
 OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
 
+struct olMemcpyGlobalTest : OffloadGlobalTest {
+  void SetUp() override {
+    RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
+    ASSERT_SUCCESS(olGetKernel(Program, "read", &ReadKernel));
+    ASSERT_SUCCESS(olGetKernel(Program, "write", &WriteKernel));
+    ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+    ASSERT_SUCCESS(olGetSymbolInfo(
+        Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
+
+    LaunchArgs.Dimensions = 1;
+    LaunchArgs.GroupSize = {64, 1, 1};
+    LaunchArgs.NumGroups = {1, 1, 1};
+
+    LaunchArgs.DynSharedMemory = 0;
+  }
+
+  ol_kernel_launch_size_args_t LaunchArgs{};
+  void *Addr;
+  ol_symbol_handle_t ReadKernel;
+  ol_symbol_handle_t WriteKernel;
+  ol_queue_handle_t Queue;
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
+
 TEST_P(olMemcpyTest, SuccessHtoD) {
   constexpr size_t Size = 1024;
   void *Alloc;
@@ -105,3 +129,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
   ASSERT_SUCCESS(
       olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
 }
+
+TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
+  void *SourceMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            64 * sizeof(uint32_t), &SourceMem));
+  uint32_t *SourceData = (uint32_t *)SourceMem;
+  for (auto I = 0; I < 64; I++)
+    SourceData[I] = I;
+
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            64 * sizeof(uint32_t), &DestMem));
+
+  ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessWrite) {
+  void *SourceMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &SourceMem));
+  uint32_t *SourceData = (uint32_t *)SourceMem;
+  for (auto I = 0; I < 64; I++)
+    SourceData[I] = I;
+
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &DestMem));
+  struct {
+    void *Mem;
+  } Args{DestMem};
+
+  ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
+                                &LaunchArgs, nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+  ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessRead) {
+  void *DestMem;
+  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+                            LaunchArgs.GroupSize.x * sizeof(uint32_t),
+                            &DestMem));
+
+  ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
+                                &LaunchArgs, nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+  ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+                          64 * sizeof(uint32_t), nullptr));
+  ASSERT_SUCCESS(olWaitQueue(Queue));
+
+  uint32_t *DestData = (uint32_t *)DestMem;
+  for (uint32_t I = 0; I < 64; I++)
+    ASSERT_EQ(DestData[I], I * 2);
+
+  ASSERT_SUCCESS(olMemFree(DestMem));
+}
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
index 100a374430372..ed8f4716974cd 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
   ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
 }
 
+TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
+  void *RetrievedAddr;
+  ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+               olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+                               sizeof(RetrievedAddr), &RetrievedAddr));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
+  void *RetrievedAddr = nullptr;
+  ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+                                 sizeof(RetrievedAddr), &RetrievedAddr));
+  ASSERT_NE(RetrievedAddr, nullptr);
+}
+
+TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
+  size_t RetrievedSize;
+  ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+               olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+                               sizeof(RetrievedSize), &RetrievedSize));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
+  size_t RetrievedSize = 0;
+  ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+                                 sizeof(RetrievedSize), &RetrievedSize));
+  ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
+}
+
 TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
   ol_symbol_kind_t RetrievedKind;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
index aa7a061a9ef7a..ec011865cc6ad 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
   ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
 }
 
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
+  size_t Size = 0;
+  ASSERT_SUCCESS(olGetSymbolInfoSize(
+      Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
+  ASSERT_EQ(Size, sizeof(void *));
+}
+
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
+  size_t Size = 0;
+  ASSERT_SUCCESS(
+      olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
+  ASSERT_EQ(Size, sizeof(size_t));
+}
+
 TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
   size_t Size = 0;
   ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

Copy link
Contributor

@jhuber6 jhuber6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if we could just make it getSymbol instead of getKernel. We need to disambiguate for the underlying runtime calls (AMDGPU needs .kd and CUDA uses a different one) but that should be easily doable with a quick lookup in the ELF to see if the symbol is a function or variable.

@RossBrunton RossBrunton force-pushed the users/RossBrunton/symbol3 branch from 8c44953 to f2001d4 Compare July 11, 2025 14:14
@RossBrunton RossBrunton force-pushed the users/RossBrunton/symbol4 branch from b2406a5 to 77a4183 Compare July 11, 2025 14:19
Base automatically changed from users/RossBrunton/symbol3 to main July 11, 2025 14:29
Add two new symbol info types for getting the bounds of a global
variable. As well as a number of tests for reading/writing to it.
@RossBrunton RossBrunton force-pushed the users/RossBrunton/symbol4 branch from 77a4183 to 69089e4 Compare July 11, 2025 14:30
@RossBrunton RossBrunton merged commit 2fdeeef into main Jul 11, 2025
9 checks passed
@RossBrunton RossBrunton deleted the users/RossBrunton/symbol4 branch July 11, 2025 15:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants