Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 1 addition & 16 deletions offload/liboffload/API/Kernel.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,10 @@
//
//===----------------------------------------------------------------------===//
//
// This file contains Offload API definitions related to loading and launching
// kernels
// This file contains Offload API definitions related to launching kernels
//
//===----------------------------------------------------------------------===//

def : Function {
let name = "olGetKernel";
let desc = "Get a kernel from the function identified by `KernelName` in the given program.";
let details = [
"Symbol handles are owned by the program and do not need to be manually destroyed."
];
let params = [
Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
Param<"const char*", "KernelName", "name of the kernel entry point in the program", PARAM_IN>,
Param<"ol_symbol_handle_t*", "Kernel", "output pointer for the fetched kernel", PARAM_OUT>
];
let returns = [];
}

def : Struct {
let name = "ol_kernel_launch_size_args_t";
let desc = "Size-related arguments for a kernel launch.";
Expand Down
16 changes: 16 additions & 0 deletions offload/liboffload/API/Symbol.td
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,21 @@ def : Enum {
let desc = "The kind of a symbol";
let etors =[
Etor<"KERNEL", "a kernel object">,
Etor<"GLOBAL_VARIABLE", "a global variable">,
];
}

def : Function {
let name = "olGetSymbol";
let desc = "Get a symbol (kernel or global variable) identified by `Name` in the given program.";
let details = [
"Symbol handles are owned by the program and do not need to be manually destroyed."
];
let params = [
Param<"ol_program_handle_t", "Program", "handle of the program", PARAM_IN>,
Param<"const char*", "Name", "name of the symbol to look up", PARAM_IN>,
Param<"ol_symbol_kind_t", "Kind", "symbol kind to look up", PARAM_IN>,
Param<"ol_symbol_handle_t*", "Symbol", "output pointer for the symbol", PARAM_OUT>,
];
let returns = [];
}
60 changes: 41 additions & 19 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ struct ol_program_impl_t {
struct ol_symbol_impl_t {
ol_symbol_impl_t(GenericKernelTy *Kernel)
: PluginImpl(Kernel), Kind(OL_SYMBOL_KIND_KERNEL) {}
std::variant<GenericKernelTy *> PluginImpl;
ol_symbol_impl_t(GlobalTy &&Global)
: PluginImpl(Global), Kind(OL_SYMBOL_KIND_GLOBAL_VARIABLE) {}
std::variant<GenericKernelTy *, GlobalTy> PluginImpl;
ol_symbol_kind_t Kind;
};

Expand Down Expand Up @@ -660,24 +662,6 @@ Error olDestroyProgram_impl(ol_program_handle_t Program) {
return olDestroy(Program);
}

Error olGetKernel_impl(ol_program_handle_t Program, const char *KernelName,
ol_symbol_handle_t *Kernel) {

auto &Device = Program->Image->getDevice();
auto KernelImpl = Device.constructKernel(KernelName);
if (!KernelImpl)
return KernelImpl.takeError();

if (auto Err = KernelImpl->init(Device, *Program->Image))
return Err;

*Kernel = Program->Symbols
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
.get();

return Error::success();
}

Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
ol_symbol_handle_t Kernel, const void *ArgumentsData,
size_t ArgumentsSize,
Expand Down Expand Up @@ -726,5 +710,43 @@ Error olLaunchKernel_impl(ol_queue_handle_t Queue, ol_device_handle_t Device,
return Error::success();
}

Error olGetSymbol_impl(ol_program_handle_t Program, const char *Name,
ol_symbol_kind_t Kind, ol_symbol_handle_t *Symbol) {
auto &Device = Program->Image->getDevice();

switch (Kind) {
case OL_SYMBOL_KIND_KERNEL: {
auto KernelImpl = Device.constructKernel(Name);
if (!KernelImpl)
return KernelImpl.takeError();

if (auto Err = KernelImpl->init(Device, *Program->Image))
return Err;

*Symbol =
Program->Symbols
.emplace_back(std::make_unique<ol_symbol_impl_t>(&*KernelImpl))
.get();
return Error::success();
}
case OL_SYMBOL_KIND_GLOBAL_VARIABLE: {
GlobalTy GlobalObj{Name};
if (auto Res = Device.Plugin.getGlobalHandler().getGlobalMetadataFromDevice(
Device, *Program->Image, GlobalObj))
return Res;

*Symbol = Program->Symbols
.emplace_back(
std::make_unique<ol_symbol_impl_t>(std::move(GlobalObj)))
.get();

return Error::success();
}
default:
return createOffloadError(ErrorCode::INVALID_ENUMERATION,
"getSymbol kind enum '%i' is invalid", Kind);
}
}

} // namespace offload
} // namespace llvm
4 changes: 3 additions & 1 deletion offload/unittests/OffloadAPI/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ add_offload_unittest("init"
target_compile_definitions("init.unittests" PRIVATE DISABLE_WRAPPER)

add_offload_unittest("kernel"
kernel/olGetKernel.cpp
kernel/olLaunchKernel.cpp)

add_offload_unittest("memory"
Expand All @@ -41,3 +40,6 @@ add_offload_unittest("queue"
queue/olDestroyQueue.cpp
queue/olGetQueueInfo.cpp
queue/olGetQueueInfoSize.cpp)

add_offload_unittest("symbol"
symbol/olGetSymbol.cpp)
2 changes: 1 addition & 1 deletion offload/unittests/OffloadAPI/common/Fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ struct OffloadProgramTest : OffloadDeviceTest {
struct OffloadKernelTest : OffloadProgramTest {
void SetUp() override {
RETURN_ON_FATAL_FAILURE(OffloadProgramTest::SetUp());
ASSERT_SUCCESS(olGetKernel(Program, "foo", &Kernel));
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
}

void TearDown() override {
Expand Down
1 change: 1 addition & 0 deletions offload/unittests/OffloadAPI/device_code/global.c
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <gpuintrin.h>
#include <stdint.h>

[[gnu::visibility("default")]]
uint32_t global[64];

__gpu_kernel void write() {
Expand Down
38 changes: 0 additions & 38 deletions offload/unittests/OffloadAPI/kernel/olGetKernel.cpp

This file was deleted.

15 changes: 13 additions & 2 deletions offload/unittests/OffloadAPI/kernel/olLaunchKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ struct LaunchKernelTestBase : OffloadQueueTest {
struct LaunchSingleKernelTestBase : LaunchKernelTestBase {
void SetUpKernel(const char *kernel) {
RETURN_ON_FATAL_FAILURE(SetUpProgram(kernel));
ASSERT_SUCCESS(olGetKernel(Program, kernel, &Kernel));
ASSERT_SUCCESS(
olGetSymbol(Program, kernel, OL_SYMBOL_KIND_KERNEL, &Kernel));
}

ol_symbol_handle_t Kernel = nullptr;
Expand All @@ -67,7 +68,8 @@ struct LaunchMultipleKernelTestBase : LaunchKernelTestBase {
Kernels.resize(kernels.size());
size_t I = 0;
for (auto K : kernels)
ASSERT_SUCCESS(olGetKernel(Program, K, &Kernels[I++]));
ASSERT_SUCCESS(
olGetSymbol(Program, K, OL_SYMBOL_KIND_KERNEL, &Kernels[I++]));
}

std::vector<ol_symbol_handle_t> Kernels;
Expand Down Expand Up @@ -223,6 +225,15 @@ TEST_P(olLaunchKernelGlobalTest, Success) {
ASSERT_SUCCESS(olMemFree(Mem));
}

TEST_P(olLaunchKernelGlobalTest, InvalidNotAKernel) {
ol_symbol_handle_t Global = nullptr;
ASSERT_SUCCESS(
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
ASSERT_ERROR(
OL_ERRC_SYMBOL_KIND,
olLaunchKernel(Queue, Device, Global, nullptr, 0, &LaunchArgs, nullptr));
}

TEST_P(olLaunchKernelGlobalCtorTest, Success) {
void *Mem;
ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
Expand Down
93 changes: 93 additions & 0 deletions offload/unittests/OffloadAPI/symbol/olGetSymbol.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
//===------- Offload API tests - olGetSymbol ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <gtest/gtest.h>

using olGetSymbolKernelTest = OffloadProgramTest;
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolKernelTest);

struct olGetSymbolGlobalTest : OffloadQueueTest {
void SetUp() override {
RETURN_ON_FATAL_FAILURE(OffloadQueueTest::SetUp());
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("global", Device, DeviceBin));
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
ASSERT_SUCCESS(olCreateProgram(Device, DeviceBin->getBufferStart(),
DeviceBin->getBufferSize(), &Program));
}

void TearDown() override {
if (Program) {
olDestroyProgram(Program);
}
RETURN_ON_FATAL_FAILURE(OffloadQueueTest::TearDown());
}

std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
ol_program_handle_t Program = nullptr;
ol_kernel_launch_size_args_t LaunchArgs{};
};
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olGetSymbolGlobalTest);

TEST_P(olGetSymbolKernelTest, Success) {
ol_symbol_handle_t Kernel = nullptr;
ASSERT_SUCCESS(olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
ASSERT_NE(Kernel, nullptr);
}

TEST_P(olGetSymbolKernelTest, InvalidNullProgram) {
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
olGetSymbol(nullptr, "foo", OL_SYMBOL_KIND_KERNEL, &Kernel));
}

TEST_P(olGetSymbolKernelTest, InvalidNullKernelPointer) {
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
olGetSymbol(Program, "foo", OL_SYMBOL_KIND_KERNEL, nullptr));
}

TEST_P(olGetSymbolKernelTest, InvalidKernelName) {
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(OL_ERRC_NOT_FOUND, olGetSymbol(Program, "invalid_kernel_name",
OL_SYMBOL_KIND_KERNEL, &Kernel));
}

TEST_P(olGetSymbolKernelTest, InvalidKind) {
ol_symbol_handle_t Kernel = nullptr;
ASSERT_ERROR(
OL_ERRC_INVALID_ENUMERATION,
olGetSymbol(Program, "foo", OL_SYMBOL_KIND_FORCE_UINT32, &Kernel));
}

TEST_P(olGetSymbolGlobalTest, Success) {
ol_symbol_handle_t Global = nullptr;
ASSERT_SUCCESS(
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
ASSERT_NE(Global, nullptr);
}

TEST_P(olGetSymbolGlobalTest, InvalidNullProgram) {
ol_symbol_handle_t Global = nullptr;
ASSERT_ERROR(
OL_ERRC_INVALID_NULL_HANDLE,
olGetSymbol(nullptr, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
}

TEST_P(olGetSymbolGlobalTest, InvalidNullGlobalPointer) {
ASSERT_ERROR(
OL_ERRC_INVALID_NULL_POINTER,
olGetSymbol(Program, "global", OL_SYMBOL_KIND_GLOBAL_VARIABLE, nullptr));
}

TEST_P(olGetSymbolGlobalTest, InvalidGlobalName) {
ol_symbol_handle_t Global = nullptr;
ASSERT_ERROR(OL_ERRC_NOT_FOUND,
olGetSymbol(Program, "invalid_global",
OL_SYMBOL_KIND_GLOBAL_VARIABLE, &Global));
}
Loading