Skip to content

Commit f5ed55c

Browse files
jhuber6ronlieb
authored andcommitted
[Offload] Implement 'olIsValidBinary' in offload and clean up (llvm#159658)
Summary: This exposes the 'isDeviceCompatible' routine for checking if a binary *can* be loaded. This is useful if people don't want to consume errors everywhere when figuring out which image to put to what device. I don't know if this is a good name, I was thining like `olIsCompatible` or whatever. Let me know what you think. Long term I'd like to be able to do something similar to what OpenMP does where we can conditionally only initialize devices if we need them. That's going to be support needed if we want this to be more generic.
1 parent 731ea1d commit f5ed55c

File tree

7 files changed

+134
-81
lines changed

7 files changed

+134
-81
lines changed

offload/liboffload/API/Program.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,18 @@ def olCreateProgram : Function {
2424
let returns = [];
2525
}
2626

27+
def olIsValidBinary : Function {
28+
let desc = "Validate if the binary image pointed to by `ProgData` is compatible with the device.";
29+
let details = ["The provided `ProgData` will not be loaded onto the device"];
30+
let params = [
31+
Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>,
32+
Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>,
33+
Param<"size_t", "ProgDataSize", "size of the program binary in bytes", PARAM_IN>,
34+
Param<"bool*", "Valid", "output is true if the image is compatible", PARAM_OUT>
35+
];
36+
let returns = [];
37+
}
38+
2739
def olDestroyProgram : Function {
2840
let desc = "Destroy the program and free all underlying resources.";
2941
let details = [];

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,7 +990,6 @@ Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize,
990990

991991
Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
992992
size_t ProgDataSize, ol_program_handle_t *Program) {
993-
// Make a copy of the program binary in case it is released by the caller.
994993
StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
995994
Expected<plugin::DeviceImageTy *> Res =
996995
Device->Device->loadBinary(Device->Device->Plugin, Buffer);
@@ -1002,6 +1001,14 @@ Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData,
10021001
return Error::success();
10031002
}
10041003

1004+
Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData,
1005+
size_t ProgDataSize, bool *IsValid) {
1006+
StringRef Buffer(reinterpret_cast<const char *>(ProgData), ProgDataSize);
1007+
*IsValid = Device->Device->Plugin.isDeviceCompatible(
1008+
Device->Device->getDeviceId(), Buffer);
1009+
return Error::success();
1010+
}
1011+
10051012
Error olDestroyProgram_impl(ol_program_handle_t Program) {
10061013
auto &Device = Program->Image->getDevice();
10071014
if (auto Err = Device.unloadBinary(Program->Image))

offload/libomptarget/PluginManager.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,10 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
243243
// Scan the RTLs that have associated images until we find one that supports
244244
// the current image.
245245
for (auto &R : plugins()) {
246-
if (!R.is_plugin_compatible(Img))
246+
StringRef Buffer(reinterpret_cast<const char *>(Img->ImageStart),
247+
utils::getPtrDiff(Img->ImageEnd, Img->ImageStart));
248+
249+
if (!R.isPluginCompatible(Buffer))
247250
continue;
248251

249252
if (!initializePlugin(R))
@@ -266,7 +269,7 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) {
266269
continue;
267270
}
268271

269-
if (!R.is_device_compatible(DeviceId, Img))
272+
if (!R.isDeviceCompatible(DeviceId, Buffer))
270273
continue;
271274

272275
DP("Image " DPxMOD " is compatible with RTL %s device %d!\n",

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1714,10 +1714,10 @@ struct GenericPluginTy {
17141714

17151715
/// Returns non-zero if the \p Image is compatible with the plugin. This
17161716
/// function does not require the plugin to be initialized before use.
1717-
int32_t is_plugin_compatible(__tgt_device_image *Image);
1717+
int32_t isPluginCompatible(StringRef Image);
17181718

17191719
/// Returns non-zero if the \p Image is compatible with the device.
1720-
int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image);
1720+
int32_t isDeviceCompatible(int32_t DeviceId, StringRef Image);
17211721

17221722
/// Returns non-zero if the plugin device has been initialized.
17231723
int32_t is_device_initialized(int32_t DeviceId) const;

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 57 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1981,85 +1981,66 @@ int32_t GenericPluginTy::supports_empty_images() {
19811981
return supportsEmptyImages();
19821982
}
19831983

1984-
int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) {
1985-
auto T = logger::log<int32_t>(__func__, Image);
1986-
auto R = [&]() {
1987-
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
1988-
utils::getPtrDiff(Image->ImageEnd, Image->ImageStart));
1989-
1990-
auto HandleError = [&](Error Err) -> bool {
1991-
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
1992-
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
1993-
return false;
1994-
};
1995-
switch (identify_magic(Buffer)) {
1996-
case file_magic::elf:
1997-
case file_magic::elf_relocatable:
1998-
case file_magic::elf_executable:
1999-
case file_magic::elf_shared_object:
2000-
case file_magic::elf_core: {
2001-
auto MatchOrErr = checkELFImage(Buffer);
2002-
if (Error Err = MatchOrErr.takeError())
2003-
return HandleError(std::move(Err));
2004-
return *MatchOrErr;
2005-
}
2006-
case file_magic::bitcode: {
2007-
auto MatchOrErr = checkBitcodeImage(Buffer);
2008-
if (Error Err = MatchOrErr.takeError())
2009-
return HandleError(std::move(Err));
2010-
return *MatchOrErr;
2011-
}
2012-
default:
2013-
return false;
2014-
}
2015-
}();
2016-
T.res(R);
2017-
return R;
1984+
int32_t GenericPluginTy::isPluginCompatible(StringRef Image) {
1985+
auto HandleError = [&](Error Err) -> bool {
1986+
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
1987+
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
1988+
return false;
1989+
};
1990+
switch (identify_magic(Image)) {
1991+
case file_magic::elf:
1992+
case file_magic::elf_relocatable:
1993+
case file_magic::elf_executable:
1994+
case file_magic::elf_shared_object:
1995+
case file_magic::elf_core: {
1996+
auto MatchOrErr = checkELFImage(Image);
1997+
if (Error Err = MatchOrErr.takeError())
1998+
return HandleError(std::move(Err));
1999+
return *MatchOrErr;
2000+
}
2001+
case file_magic::bitcode: {
2002+
auto MatchOrErr = checkBitcodeImage(Image);
2003+
if (Error Err = MatchOrErr.takeError())
2004+
return HandleError(std::move(Err));
2005+
return *MatchOrErr;
2006+
}
2007+
default:
2008+
return false;
2009+
}
20182010
}
20192011

2020-
int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId,
2021-
__tgt_device_image *Image) {
2022-
auto T = logger::log<int32_t>(__func__, DeviceId, Image);
2023-
auto R = [&]() {
2024-
StringRef Buffer(reinterpret_cast<const char *>(Image->ImageStart),
2025-
utils::getPtrDiff(Image->ImageEnd, Image->ImageStart));
2026-
2027-
auto HandleError = [&](Error Err) -> bool {
2028-
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
2029-
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
2030-
return false;
2031-
};
2032-
switch (identify_magic(Buffer)) {
2033-
case file_magic::elf:
2034-
case file_magic::elf_relocatable:
2035-
case file_magic::elf_executable:
2036-
case file_magic::elf_shared_object:
2037-
case file_magic::elf_core: {
2038-
auto MatchOrErr = checkELFImage(Buffer);
2039-
if (Error Err = MatchOrErr.takeError())
2040-
return HandleError(std::move(Err));
2041-
if (!*MatchOrErr)
2042-
return false;
2043-
2044-
// Perform plugin-dependent checks for the specific architecture if
2045-
// needed.
2046-
auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer);
2047-
if (Error Err = CompatibleOrErr.takeError())
2048-
return HandleError(std::move(Err));
2049-
return *CompatibleOrErr;
2050-
}
2051-
case file_magic::bitcode: {
2052-
auto MatchOrErr = checkBitcodeImage(Buffer);
2053-
if (Error Err = MatchOrErr.takeError())
2054-
return HandleError(std::move(Err));
2055-
return *MatchOrErr;
2056-
}
2057-
default:
2012+
int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) {
2013+
auto HandleError = [&](Error Err) -> bool {
2014+
[[maybe_unused]] std::string ErrStr = toString(std::move(Err));
2015+
DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str());
2016+
return false;
2017+
};
2018+
switch (identify_magic(Image)) {
2019+
case file_magic::elf:
2020+
case file_magic::elf_relocatable:
2021+
case file_magic::elf_executable:
2022+
case file_magic::elf_shared_object:
2023+
case file_magic::elf_core: {
2024+
auto MatchOrErr = checkELFImage(Image);
2025+
if (Error Err = MatchOrErr.takeError())
2026+
return HandleError(std::move(Err));
2027+
if (!*MatchOrErr)
20582028
return false;
2059-
}
2060-
}();
2061-
T.res(R);
2062-
return R;
2029+
// Perform plugin-dependent checks for the specific architecture if needed.
2030+
auto CompatibleOrErr = isELFCompatible(DeviceId, Image);
2031+
if (Error Err = CompatibleOrErr.takeError())
2032+
return HandleError(std::move(Err));
2033+
return *CompatibleOrErr;
2034+
}
2035+
case file_magic::bitcode: {
2036+
auto MatchOrErr = checkBitcodeImage(Image);
2037+
if (Error Err = MatchOrErr.takeError())
2038+
return HandleError(std::move(Err));
2039+
return *MatchOrErr;
2040+
}
2041+
default:
2042+
return false;
2043+
}
20632044
}
20642045

20652046
int32_t GenericPluginTy::is_device_initialized(int32_t DeviceId) const {

offload/unittests/OffloadAPI/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_offload_unittest("platform"
3737

3838
add_offload_unittest("program"
3939
program/olCreateProgram.cpp
40+
program/olIsValidBinary.cpp
4041
program/olDestroyProgram.cpp)
4142

4243
add_offload_unittest("queue"
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===------- Offload API tests - olIsValidBinary --------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "../common/Fixtures.hpp"
10+
#include <OffloadAPI.h>
11+
#include <gtest/gtest.h>
12+
13+
using olIsValidBinaryTest = OffloadDeviceTest;
14+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olIsValidBinaryTest);
15+
16+
TEST_P(olIsValidBinaryTest, Success) {
17+
18+
std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
19+
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
20+
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
21+
22+
bool IsValid = false;
23+
ASSERT_SUCCESS(olIsValidBinary(Device, DeviceBin->getBufferStart(),
24+
DeviceBin->getBufferSize(), &IsValid));
25+
ASSERT_TRUE(IsValid);
26+
27+
ASSERT_SUCCESS(
28+
olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid));
29+
ASSERT_FALSE(IsValid);
30+
}
31+
32+
TEST_P(olIsValidBinaryTest, Invalid) {
33+
34+
std::unique_ptr<llvm::MemoryBuffer> DeviceBin;
35+
ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin));
36+
ASSERT_GE(DeviceBin->getBufferSize(), 0lu);
37+
38+
bool IsValid = false;
39+
ASSERT_SUCCESS(
40+
olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid));
41+
ASSERT_FALSE(IsValid);
42+
}
43+
44+
TEST_P(olIsValidBinaryTest, NullPointer) {
45+
bool IsValid = false;
46+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
47+
olIsValidBinary(Device, nullptr, 42, &IsValid));
48+
ASSERT_FALSE(IsValid);
49+
}

0 commit comments

Comments
 (0)