From 6b164a07f3f753ccf3b88dc01b82f58e804dc91d Mon Sep 17 00:00:00 2001 From: Ross Brunton Date: Thu, 10 Jul 2025 12:53:30 +0100 Subject: [PATCH] [Offload] Replace `GetKernel` with `GetSymbol` with global support `olGetKernel` has been replaced by `olGetSymbol` which accepts a `Kind` parameter. As well as loading information about kernels, it can now also load information about global variables. --- offload/liboffload/API/Kernel.td | 17 +--- offload/liboffload/API/Symbol.td | 16 ++++ offload/liboffload/src/OffloadImpl.cpp | 60 ++++++++---- offload/unittests/OffloadAPI/CMakeLists.txt | 4 +- .../unittests/OffloadAPI/common/Fixtures.hpp | 2 +- .../unittests/OffloadAPI/device_code/global.c | 1 + .../OffloadAPI/kernel/olGetKernel.cpp | 38 -------- .../OffloadAPI/kernel/olLaunchKernel.cpp | 15 ++- .../OffloadAPI/symbol/olGetSymbol.cpp | 93 +++++++++++++++++++ 9 files changed, 169 insertions(+), 77 deletions(-) delete mode 100644 offload/unittests/OffloadAPI/kernel/olGetKernel.cpp create mode 100644 offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp diff --git a/offload/liboffload/API/Kernel.td b/offload/liboffload/API/Kernel.td index 7cb3016afd597..1e9537452820d 100644 --- a/offload/liboffload/API/Kernel.td +++ b/offload/liboffload/API/Kernel.td @@ -6,25 +6,10 @@ // //===----------------------------------------------------------------------===// // -// This file contains Offload API definitions related to loading and launching -// kernels +// This file contains Offload API definitions related to launching kernels // //===----------------------------------------------------------------------===// -def : Function { - let name = "olGetKernel"; - let desc = "Get a kernel from the function identified by `KernelName` in the given program."; - let details = [ - "Symbol handles are owned by the program and do not need to be manually destroyed." - ]; - let params = [ - Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>, - Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>, - Param<"ol_symbol_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT> - ]; - let returns = []; -} - def : Struct { let name = "ol_kernel_launch_size_args_t"; let desc = "Size-related arguments for a kernel launch."; diff --git a/offload/liboffload/API/Symbol.td b/offload/liboffload/API/Symbol.td index cf4d45b09f035..cd6aab36ddb1e 100644 --- a/offload/liboffload/API/Symbol.td +++ b/offload/liboffload/API/Symbol.td @@ -15,5 +15,21 @@ def : Enum { let desc = "The kind of a symbol"; let etors =[ Etor<"KERNEL", "a kernel object">, + Etor<"GLOBAL_VARIABLE", "a global variable">, ]; } + +def : Function { + let name = "olGetSymbol"; + let desc = "Get a symbol (kernel or global variable) identified by `Name` in the given program."; + let details = [ + "Symbol handles are owned by the program and do not need to be manually destroyed." + ]; + let params = [ + Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>, + Param<"const char*", "Name", "name of the symbol to look up", PARAM_IN>, + Param<"ol_symbol_kind_t", "Kind", "symbol kind to look up", PARAM_IN>, + Param<"ol_symbol_handle_t*", "Symbol", "output pointer for the symbol", PARAM_OUT>, + ]; + let returns = []; +} diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index fa5d18c044048..af07a6786cfea 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -91,7 +91,9 @@ struct ol_program_impl_t { struct ol_symbol_impl_t { ol_symbol_impl_t(GenericKernelTy *Kernel) : PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {} - std::variant PluginImpl; + ol_symbol_impl_t(GlobalTy &&Global) + : PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {} + std::variant PluginImpl; ol_symbol_kind_t Kind; }; @@ -660,24 +662,6 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) { return olDestroy(Program); } -Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName, - ol_symbol_handle_t *Kernel) { - - auto &Device = Program->Image->getDevice(); - auto KernelImpl = Device.constructKernel(KernelName); - if (!KernelImpl) - return KernelImpl.takeError(); - - if (auto Err = KernelImpl->init(Device, *Program->Image)) - return Err; - - *Kernel = Program->Symbols - .emplace_back(std::make_unique(&*KernelImpl)) - .get(); - - return Error::success(); -} - Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, ol_symbol_handle_t Kernel, const void *ArgumentsData, size_t ArgumentsSize, @@ -726,5 +710,43 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device, return Error::success(); } +Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name, + ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) { + auto &Device = Program->Image->getDevice(); + + switch (Kind) { + case OL_SYMBOL_KIND_KERNEL: { + auto KernelImpl = Device.constructKernel(Name); + if (!KernelImpl) + return KernelImpl.takeError(); + + if (auto Err = KernelImpl->init(Device, *Program->Image)) + return Err; + + *Symbol = + Program->Symbols + .emplace_back(std::make_unique(&*KernelImpl)) + .get(); + return Error::success(); + } + case OL_SYMBOL_KIND_GLOBAL_VARIABLE: { + GlobalTy GlobalObj{Name}; + if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice( + Device, *Program->Image, GlobalObj)) + return Res; + + *Symbol = Program->Symbols + .emplace_back( + std::make_unique(std::move(GlobalObj))) + .get(); + + return Error::success(); + } + default: + return createOffloadError(ErrorCode::INVALID_ENUMERATION, + "getSymbol kind enum '%i' is invalid", Kind); + } +} + } // namespace offload } // namespace llvm diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt index 93e5fd2f6cd26..ebf2d6b4aeaaf 100644 --- a/offload/unittests/OffloadAPI/CMakeLists.txt +++ b/offload/unittests/OffloadAPI/CMakeLists.txt @@ -19,7 +19,6 @@ add_offload_unittest("init" target_compile_definitions("init.unittests" PRIVATE DISABLE_WRAPPER) add_offload_unittest("kernel" - kernel/olGetKernel.cpp kernel/olLaunchKernel.cpp) add_offload_unittest("memory" @@ -41,3 +40,6 @@ add_offload_unittest("queue" queue/olDestroyQueue.cpp queue/olGetQueueInfo.cpp queue/olGetQueueInfoSize.cpp) + +add_offload_unittest("symbol" + symbol/olGetSymbol.cpp) diff --git a/offload/unittests/OffloadAPI/common/Fixtures.hpp b/offload/unittests/OffloadAPI/common/Fixtures.hpp index e443d9761f30b..0c2bd1e3dae20 100644 --- a/offload/unittests/OffloadAPI/common/Fixtures.hpp +++ b/offload/unittests/OffloadAPI/common/Fixtures.hpp @@ -113,7 +113,7 @@ struct OffloadProgramTest : OffloadDeviceTest { struct OffloadKernelTest : OffloadProgramTest { void SetUp() override { RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUp()); - ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); + ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel)); } void TearDown() override { diff --git a/offload/unittests/OffloadAPI/device_code/global.c b/offload/unittests/OffloadAPI/device_code/global.c index b30e406fb98c7..9f27f9424324f 100644 --- a/offload/unittests/OffloadAPI/device_code/global.c +++ b/offload/unittests/OffloadAPI/device_code/global.c @@ -1,6 +1,7 @@ #include #include +[[gnu::visibility("default")]] uint32_t global[64]; __gpu_kernel void write() { diff --git a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp b/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp deleted file mode 100644 index 34870f1fbf0a3..0000000000000 --- a/offload/unittests/OffloadAPI/kernel/olGetKernel.cpp +++ /dev/null @@ -1,38 +0,0 @@ -//===------- Offload API tests - olGetKernel ---------------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "../common/Fixtures.hpp" -#include -#include - -using olGetKernelTest = OffloadProgramTest; -OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetKernelTest); - -TEST_P(olGetKernelTest, Success) { - ol_symbol_handle_t Kernel = nullptr; - ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel)); - ASSERT_NE(Kernel, nullptr); -} - -TEST_P(olGetKernelTest, InvalidNullProgram) { - ol_symbol_handle_t Kernel = nullptr; - ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, - olGetKernel(nullptr, "foo", &Kernel)); -} - -TEST_P(olGetKernelTest, InvalidNullKernelPointer) { - ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, - olGetKernel(Program, "foo", nullptr)); -} - -// Error code returning from plugin interface not yet supported -TEST_P(olGetKernelTest, InvalidKernelName) { - ol_symbol_handle_t Kernel = nullptr; - ASSERT_ERROR(OL_ERRC_NOT_FOUND, - olGetKernel(Program, "invalid_kernel_name", &Kernel)); -} diff --git a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp index acda4795edec2..e7e608f2a64d4 100644 --- a/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp +++ b/offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp @@ -40,7 +40,8 @@ struct LaunchKernelTestBase : OffloadQueueTest { struct LaunchSingleKernelTestBase : LaunchKernelTestBase { void SetUpKernel(const char *kernel) { RETURN_ON_FATAL_FAILURE(SetUpProgram(kernel)); - ASSERT_SUCCESS(olGetKernel(Program, kernel, &Kernel)); + ASSERT_SUCCESS( + olGetSymbol(Program, kernel, OL_SYMBOL_KIND_KERNEL, &Kernel)); } ol_symbol_handle_t Kernel = nullptr; @@ -67,7 +68,8 @@ struct LaunchMultipleKernelTestBase : LaunchKernelTestBase { Kernels.resize(kernels.size()); size_t I = 0; for (auto K : kernels) - ASSERT_SUCCESS(olGetKernel(Program, K, &Kernels[I++])); + ASSERT_SUCCESS( + olGetSymbol(Program, K, OL_SYMBOL_KIND_KERNEL, &Kernels[I++])); } std::vector Kernels; @@ -223,6 +225,15 @@ TEST_P(olLaunchKernelGlobalTest, Success) { ASSERT_SUCCESS(olMemFree(Mem)); } +TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) { + ol_symbol_handle_t Global = nullptr; + ASSERT_SUCCESS( + olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global)); + ASSERT_ERROR( + OL_ERRC_SYMBOL_KIND, + olLaunchKernel(Queue, Device, Global, nullptr, 0, &LaunchArgs, nullptr)); +} + TEST_P(olLaunchKernelGlobalCtorTest, Success) { void *Mem; ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, diff --git a/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp new file mode 100644 index 0000000000000..5e87ab5b29621 --- /dev/null +++ b/offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp @@ -0,0 +1,93 @@ +//===------- Offload API tests - olGetSymbol ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olGetSymbolKernelTest = OffloadProgramTest; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolKernelTest); + +struct olGetSymbolGlobalTest : OffloadQueueTest { + void SetUp() override { + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp()); + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("global", Device, DeviceBin)); + ASSERT_GE(DeviceBin->getBufferSize(), 0lu); + ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &Program)); + } + + void TearDown() override { + if (Program) { + olDestroyProgram(Program); + } + RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown()); + } + + std::unique_ptr DeviceBin; + ol_program_handle_t Program = nullptr; + ol_kernel_launch_size_args_t LaunchArgs{}; +}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolGlobalTest); + +TEST_P(olGetSymbolKernelTest, Success) { + ol_symbol_handle_t Kernel = nullptr; + ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel)); + ASSERT_NE(Kernel, nullptr); +} + +TEST_P(olGetSymbolKernelTest, InvalidNullProgram) { + ol_symbol_handle_t Kernel = nullptr; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olGetSymbol(nullptr, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel)); +} + +TEST_P(olGetSymbolKernelTest, InvalidNullKernelPointer) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, nullptr)); +} + +TEST_P(olGetSymbolKernelTest, InvalidKernelName) { + ol_symbol_handle_t Kernel = nullptr; + ASSERT_ERROR(OL_ERRC_NOT_FOUND, olGetSymbol(Program, "invalid_kernel_name", + OL_SYMBOL_KIND_KERNEL, &Kernel)); +} + +TEST_P(olGetSymbolKernelTest, InvalidKind) { + ol_symbol_handle_t Kernel = nullptr; + ASSERT_ERROR( + OL_ERRC_INVALID_ENUMERATION, + olGetSymbol(Program, "foo", OL_SYMBOL_KIND_FORCE_UINT32, &Kernel)); +} + +TEST_P(olGetSymbolGlobalTest, Success) { + ol_symbol_handle_t Global = nullptr; + ASSERT_SUCCESS( + olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global)); + ASSERT_NE(Global, nullptr); +} + +TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) { + ol_symbol_handle_t Global = nullptr; + ASSERT_ERROR( + OL_ERRC_INVALID_NULL_HANDLE, + olGetSymbol(nullptr, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global)); +} + +TEST_P(olGetSymbolGlobalTest, InvalidNullGlobalPointer) { + ASSERT_ERROR( + OL_ERRC_INVALID_NULL_POINTER, + olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, nullptr)); +} + +TEST_P(olGetSymbolGlobalTest, InvalidGlobalName) { + ol_symbol_handle_t Global = nullptr; + ASSERT_ERROR(OL_ERRC_NOT_FOUND, + olGetSymbol(Program, "invalid_global", + OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global)); +}