Skip to content

Commit b2406a5

Browse files
committed
[Offload] Add global variable address/size queries
Add two new symbol info types for getting the bounds of a global variable. As well as a number of tests for reading/writing to it.
1 parent 8c44953 commit b2406a5

File tree

6 files changed

+173
-3
lines changed

6 files changed

+173
-3
lines changed

offload/liboffload/API/Symbol.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,9 @@ def : Enum {
3838
let desc = "Supported symbol info.";
3939
let is_typed = 1;
4040
let etors = [
41-
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">
41+
TaggedEtor<"KIND", "ol_symbol_kind_t", "The kind of this symbol.">,
42+
TaggedEtor<"GLOBAL_VARIABLE_ADDRESS", "void *", "The address in memory for this global variable.">,
43+
TaggedEtor<"GLOBAL_VARIABLE_SIZE", "size_t", "The size in bytes for this global variable.">,
4244
];
4345
}
4446

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -751,9 +751,28 @@ Error olGetSymbolInfoImplDetail(ol_symbol_handle_t Symbol,
751751
void *PropValue, size_t *PropSizeRet) {
752752
InfoWriter Info(PropSize, PropValue, PropSizeRet);
753753

754+
auto CheckKind = [&](ol_symbol_kind_t Required) {
755+
if (Symbol->Kind != Required) {
756+
std::string ErrBuffer;
757+
llvm::raw_string_ostream(ErrBuffer)
758+
<< PropName << ": Expected a symbol of Kind " << Required
759+
<< " but given a symbol of Kind " << Symbol->Kind;
760+
return Plugin::error(ErrorCode::SYMBOL_KIND, ErrBuffer.c_str());
761+
}
762+
return Plugin::success();
763+
};
764+
754765
switch (PropName) {
755766
case OL_SYMBOL_INFO_KIND:
756767
return Info.write<ol_symbol_kind_t>(Symbol->Kind);
768+
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS:
769+
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
770+
return Err;
771+
return Info.write<void *>(std::get<GlobalTy>(Symbol->PluginImpl).getPtr());
772+
case OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE:
773+
if (auto Err = CheckKind(OL_SYMBOL_KIND_GLOBAL_VARIABLE))
774+
return Err;
775+
return Info.write<size_t>(std::get<GlobalTy>(Symbol->PluginImpl).getSize());
757776
default:
758777
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
759778
"olGetSymbolInfo enum '%i' is invalid", PropName);

offload/tools/offload-tblgen/PrintGen.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,12 @@ inline void printTagged(llvm::raw_ostream &os, const void *ptr, {0} value, size_
7474
if (Type == "char[]") {
7575
OS << formatv(TAB_2 "printPtr(os, (const char*) ptr);\n");
7676
} else {
77-
OS << formatv(TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n",
78-
Type);
77+
if (Type == "void *")
78+
OS << formatv(TAB_2 "void * const * const tptr = (void * "
79+
"const * const)ptr;\n");
80+
else
81+
OS << formatv(
82+
TAB_2 "const {0} * const tptr = (const {0} * const)ptr;\n", Type);
7983
// TODO: Handle other cases here
8084
OS << TAB_2 "os << (const void *)tptr << \" (\";\n";
8185
if (Type.ends_with("*")) {

offload/unittests/OffloadAPI/memory/olMemcpy.cpp

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,30 @@
1313
using olMemcpyTest = OffloadQueueTest;
1414
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyTest);
1515

16+
struct olMemcpyGlobalTest : OffloadGlobalTest {
17+
void SetUp() override {
18+
RETURN_ON_FATAL_FAILURE(OffloadGlobalTest::SetUp());
19+
ASSERT_SUCCESS(olGetKernel(Program, "read", &ReadKernel));
20+
ASSERT_SUCCESS(olGetKernel(Program, "write", &WriteKernel));
21+
ASSERT_SUCCESS(olCreateQueue(Device, &Queue));
22+
ASSERT_SUCCESS(olGetSymbolInfo(
23+
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, sizeof(Addr), &Addr));
24+
25+
LaunchArgs.Dimensions = 1;
26+
LaunchArgs.GroupSize = {64, 1, 1};
27+
LaunchArgs.NumGroups = {1, 1, 1};
28+
29+
LaunchArgs.DynSharedMemory = 0;
30+
}
31+
32+
ol_kernel_launch_size_args_t LaunchArgs{};
33+
void *Addr;
34+
ol_symbol_handle_t ReadKernel;
35+
ol_symbol_handle_t WriteKernel;
36+
ol_queue_handle_t Queue;
37+
};
38+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olMemcpyGlobalTest);
39+
1640
TEST_P(olMemcpyTest, SuccessHtoD) {
1741
constexpr size_t Size = 1024;
1842
void *Alloc;
@@ -105,3 +129,82 @@ TEST_P(olMemcpyTest, SuccessSizeZero) {
105129
ASSERT_SUCCESS(
106130
olMemcpy(nullptr, Output.data(), Host, Input.data(), Host, 0, nullptr));
107131
}
132+
133+
TEST_P(olMemcpyGlobalTest, SuccessRoundTrip) {
134+
void *SourceMem;
135+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
136+
64 * sizeof(uint32_t), &SourceMem));
137+
uint32_t *SourceData = (uint32_t *)SourceMem;
138+
for (auto I = 0; I < 64; I++)
139+
SourceData[I] = I;
140+
141+
void *DestMem;
142+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
143+
64 * sizeof(uint32_t), &DestMem));
144+
145+
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
146+
64 * sizeof(uint32_t), nullptr));
147+
ASSERT_SUCCESS(olWaitQueue(Queue));
148+
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
149+
64 * sizeof(uint32_t), nullptr));
150+
ASSERT_SUCCESS(olWaitQueue(Queue));
151+
152+
uint32_t *DestData = (uint32_t *)DestMem;
153+
for (uint32_t I = 0; I < 64; I++)
154+
ASSERT_EQ(DestData[I], I);
155+
156+
ASSERT_SUCCESS(olMemFree(DestMem));
157+
ASSERT_SUCCESS(olMemFree(SourceMem));
158+
}
159+
160+
TEST_P(olMemcpyGlobalTest, SuccessWrite) {
161+
void *SourceMem;
162+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
163+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
164+
&SourceMem));
165+
uint32_t *SourceData = (uint32_t *)SourceMem;
166+
for (auto I = 0; I < 64; I++)
167+
SourceData[I] = I;
168+
169+
void *DestMem;
170+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
171+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
172+
&DestMem));
173+
struct {
174+
void *Mem;
175+
} Args{DestMem};
176+
177+
ASSERT_SUCCESS(olMemcpy(Queue, Addr, Device, SourceMem, Host,
178+
64 * sizeof(uint32_t), nullptr));
179+
ASSERT_SUCCESS(olWaitQueue(Queue));
180+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, ReadKernel, &Args, sizeof(Args),
181+
&LaunchArgs, nullptr));
182+
ASSERT_SUCCESS(olWaitQueue(Queue));
183+
184+
uint32_t *DestData = (uint32_t *)DestMem;
185+
for (uint32_t I = 0; I < 64; I++)
186+
ASSERT_EQ(DestData[I], I);
187+
188+
ASSERT_SUCCESS(olMemFree(DestMem));
189+
ASSERT_SUCCESS(olMemFree(SourceMem));
190+
}
191+
192+
TEST_P(olMemcpyGlobalTest, SuccessRead) {
193+
void *DestMem;
194+
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
195+
LaunchArgs.GroupSize.x * sizeof(uint32_t),
196+
&DestMem));
197+
198+
ASSERT_SUCCESS(olLaunchKernel(Queue, Device, WriteKernel, nullptr, 0,
199+
&LaunchArgs, nullptr));
200+
ASSERT_SUCCESS(olWaitQueue(Queue));
201+
ASSERT_SUCCESS(olMemcpy(Queue, DestMem, Host, Addr, Device,
202+
64 * sizeof(uint32_t), nullptr));
203+
ASSERT_SUCCESS(olWaitQueue(Queue));
204+
205+
uint32_t *DestData = (uint32_t *)DestMem;
206+
for (uint32_t I = 0; I < 64; I++)
207+
ASSERT_EQ(DestData[I], I * 2);
208+
209+
ASSERT_SUCCESS(olMemFree(DestMem));
210+
}

offload/unittests/OffloadAPI/symbol/olGetSymbolInfo.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,34 @@ TEST_P(olGetSymbolInfoGlobalTest, SuccessKind) {
3030
ASSERT_EQ(RetrievedKind, OL_SYMBOL_KIND_GLOBAL_VARIABLE);
3131
}
3232

33+
TEST_P(olGetSymbolInfoKernelTest, InvalidAddress) {
34+
void *RetrievedAddr;
35+
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
36+
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
37+
sizeof(RetrievedAddr), &RetrievedAddr));
38+
}
39+
40+
TEST_P(olGetSymbolInfoGlobalTest, SuccessAddress) {
41+
void *RetrievedAddr = nullptr;
42+
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS,
43+
sizeof(RetrievedAddr), &RetrievedAddr));
44+
ASSERT_NE(RetrievedAddr, nullptr);
45+
}
46+
47+
TEST_P(olGetSymbolInfoKernelTest, InvalidSize) {
48+
size_t RetrievedSize;
49+
ASSERT_ERROR(OL_ERRC_SYMBOL_KIND,
50+
olGetSymbolInfo(Kernel, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
51+
sizeof(RetrievedSize), &RetrievedSize));
52+
}
53+
54+
TEST_P(olGetSymbolInfoGlobalTest, SuccessSize) {
55+
size_t RetrievedSize = 0;
56+
ASSERT_SUCCESS(olGetSymbolInfo(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE,
57+
sizeof(RetrievedSize), &RetrievedSize));
58+
ASSERT_EQ(RetrievedSize, 64 * sizeof(uint32_t));
59+
}
60+
3361
TEST_P(olGetSymbolInfoKernelTest, InvalidNullHandle) {
3462
ol_symbol_kind_t RetrievedKind;
3563
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

offload/unittests/OffloadAPI/symbol/olGetSymbolInfoSize.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,20 @@ TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessKind) {
2828
ASSERT_EQ(Size, sizeof(ol_symbol_kind_t));
2929
}
3030

31+
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessAddress) {
32+
size_t Size = 0;
33+
ASSERT_SUCCESS(olGetSymbolInfoSize(
34+
Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_ADDRESS, &Size));
35+
ASSERT_EQ(Size, sizeof(void *));
36+
}
37+
38+
TEST_P(olGetSymbolInfoSizeGlobalTest, SuccessSize) {
39+
size_t Size = 0;
40+
ASSERT_SUCCESS(
41+
olGetSymbolInfoSize(Global, OL_SYMBOL_INFO_GLOBAL_VARIABLE_SIZE, &Size));
42+
ASSERT_EQ(Size, sizeof(size_t));
43+
}
44+
3145
TEST_P(olGetSymbolInfoSizeKernelTest, InvalidNullHandle) {
3246
size_t Size = 0;
3347
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,

0 commit comments

Comments
 (0)