-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[Offload] Add global variable address/size queries #147972
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-offload Author: Ross Brunton (RossBrunton) ChangesAdd two new symbol info types for getting the bounds of a global Full diff: https://github.com/llvm/llvm-project/pull/147972.diff 6 Files Affected:
diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td
index 8b1bef1b2b6e4..7310c772d757e 100644
--- a/offload/liboffload/API/Symbol.td
+++ b/offload/liboffload/API/Symbol.td
@@ -38,7 +38,9 @@ def : Enum {
let desc = "Supported symbol info.";
let is_typed = 1;
let etors = [
- TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
+ TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
+ TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
+ TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
];
}
diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp
index 5aefafb1e57ea..ecdd1e76dfeff 100644
--- a/offload/liboffload/src/OffloadImpl.cpp
+++ b/offload/liboffload/src/OffloadImpl.cpp
@@ -751,9 +751,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
void *PropValue, size_t *PropSizeRet) {
InfoWriter Info(PropSize, PropValue, PropSizeRet);
+ auto CheckKind = [&](ol_symbol_kind_t Required) {
+ if (Symbol->Kind != Required) {
+ std::string ErrBuffer;
+ llvm::raw_string_ostream(ErrBuffer)
+ << PropName << ": Expected a symbol of Kind " << Required
+ << " but given a symbol of Kind " << Symbol->Kind;
+ return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
+ }
+ return Plugin::success();
+ };
+
switch (PropName) {
case OL_SYMBOL_INFO_KIND:
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
+ case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
+ if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+ return Err;
+ return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
+ case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
+ if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
+ return Err;
+ return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"olGetSymbolInfo enum '%i' is invalid", PropName);
diff --git a/offload/tools/offload-tblgen/PrintGen.cpp b/offload/tools/offload-tblgen/PrintGen.cpp
index d1189688a90a3..89d7c820426cf 100644
--- a/offload/tools/offload-tblgen/PrintGen.cpp
+++ b/offload/tools/offload-tblgen/PrintGen.cpp
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
if (Type == "char[]") {
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
} else {
- OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
- Type);
+ if (Type == "void *")
+ OS << formatv(TAB_2 "void * const * const tptr = (void * "
+ "const * const)ptr;\n");
+ else
+ OS << formatv(
+ TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
// TODO: Handle other cases here
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
if (Type.ends_with("*")) {
diff --git a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
index c1762b451b81d..204ed8fa22fc8 100644
--- a/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
+++ b/offload/unittests/OffloadAPI/memory/olMemcpy.cpp
@@ -13,6 +13,30 @@
using olMemcpyTest = OffloadQueueTest;
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
+struct olMemcpyGlobalTest : OffloadGlobalTest {
+ void SetUp() override {
+ RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
+ ASSERT_SUCCESS(olGetKernel(Program, "read", &ReadKernel));
+ ASSERT_SUCCESS(olGetKernel(Program, "write", &WriteKernel));
+ ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
+ ASSERT_SUCCESS(olGetSymbolInfo(
+ Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
+
+ LaunchArgs.Dimensions = 1;
+ LaunchArgs.GroupSize = {64, 1, 1};
+ LaunchArgs.NumGroups = {1, 1, 1};
+
+ LaunchArgs.DynSharedMemory = 0;
+ }
+
+ ol_kernel_launch_size_args_t LaunchArgs{};
+ void *Addr;
+ ol_symbol_handle_t ReadKernel;
+ ol_symbol_handle_t WriteKernel;
+ ol_queue_handle_t Queue;
+};
+OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
+
TEST_P(olMemcpyTest, SuccessHtoD) {
constexpr size_t Size = 1024;
void *Alloc;
@@ -105,3 +129,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
ASSERT_SUCCESS(
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
}
+
+TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ 64 * sizeof(uint32_t), &DestMem));
+
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessWrite) {
+ void *SourceMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &SourceMem));
+ uint32_t *SourceData = (uint32_t *)SourceMem;
+ for (auto I = 0; I < 64; I++)
+ SourceData[I] = I;
+
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+ struct {
+ void *Mem;
+ } Args{DestMem};
+
+ ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+ ASSERT_SUCCESS(olMemFree(SourceMem));
+}
+
+TEST_P(olMemcpyGlobalTest, SuccessRead) {
+ void *DestMem;
+ ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
+ LaunchArgs.GroupSize.x * sizeof(uint32_t),
+ &DestMem));
+
+ ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
+ &LaunchArgs, nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+ ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
+ 64 * sizeof(uint32_t), nullptr));
+ ASSERT_SUCCESS(olWaitQueue(Queue));
+
+ uint32_t *DestData = (uint32_t *)DestMem;
+ for (uint32_t I = 0; I < 64; I++)
+ ASSERT_EQ(DestData[I], I * 2);
+
+ ASSERT_SUCCESS(olMemFree(DestMem));
+}
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
index 100a374430372..ed8f4716974cd 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
}
+TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
+ void *RetrievedAddr;
+ ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+ olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+ sizeof(RetrievedAddr), &RetrievedAddr));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
+ void *RetrievedAddr = nullptr;
+ ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
+ sizeof(RetrievedAddr), &RetrievedAddr));
+ ASSERT_NE(RetrievedAddr, nullptr);
+}
+
+TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
+ size_t RetrievedSize;
+ ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
+ olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+ sizeof(RetrievedSize), &RetrievedSize));
+}
+
+TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
+ size_t RetrievedSize = 0;
+ ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
+ sizeof(RetrievedSize), &RetrievedSize));
+ ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
+}
+
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
ol_symbol_kind_t RetrievedKind;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
index aa7a061a9ef7a..ec011865cc6ad 100644
--- a/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
+++ b/offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
}
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(olGetSymbolInfoSize(
+ Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
+ ASSERT_EQ(Size, sizeof(void *));
+}
+
+TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
+ size_t Size = 0;
+ ASSERT_SUCCESS(
+ olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
+ ASSERT_EQ(Size, sizeof(size_t));
+}
+
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
size_t Size = 0;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
|
jhuber6
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if we could just make it getSymbol instead of getKernel. We need to disambiguate for the underlying runtime calls (AMDGPU needs .kd and CUDA uses a different one) but that should be easily doable with a quick lookup in the ELF to see if the symbol is a function or variable.
8c44953 to
f2001d4
Compare
b2406a5 to
77a4183
Compare
Add two new symbol info types for getting the bounds of a global variable. As well as a number of tests for reading/writing to it.
77a4183 to
69089e4
Compare
Add two new symbol info types for getting the bounds of a global
variable. As well as a number of tests for reading/writing to it.