diff --git a/offload/liboffload/API/Program.td b/offload/liboffload/API/Program.td index 1f48f650cab70..7e11b3d8e331e 100644 --- a/offload/liboffload/API/Program.td +++ b/offload/liboffload/API/Program.td @@ -24,6 +24,18 @@ def olCreateProgram : Function { let returns = []; } +def olIsValidBinary : Function { + let desc = "Validate if the binary image pointed to by `ProgData` is compatible with the device."; + let details = ["The provided `ProgData` will not be loaded onto the device"]; + let params = [ + Param<"ol_device_handle_t", "Device", "handle of the device", PARAM_IN>, + Param<"const void*", "ProgData", "pointer to the program binary data", PARAM_IN>, + Param<"size_t", "ProgDataSize", "size of the program binary in bytes", PARAM_IN>, + Param<"bool*", "Valid", "output is true if the image is compatible", PARAM_OUT> + ]; + let returns = []; +} + def olDestroyProgram : Function { let desc = "Destroy the program and free all underlying resources."; let details = []; diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index b5b9b0e83b975..c5d083db7522e 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -887,7 +887,6 @@ Error olMemFill_impl(ol_queue_handle_t Queue, void *Ptr, size_t PatternSize, Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData, size_t ProgDataSize, ol_program_handle_t *Program) { - // Make a copy of the program binary in case it is released by the caller. StringRef Buffer(reinterpret_cast(ProgData), ProgDataSize); Expected Res = Device->Device->loadBinary(Device->Device->Plugin, Buffer); @@ -899,6 +898,14 @@ Error olCreateProgram_impl(ol_device_handle_t Device, const void *ProgData, return Error::success(); } +Error olIsValidBinary_impl(ol_device_handle_t Device, const void *ProgData, + size_t ProgDataSize, bool *IsValid) { + StringRef Buffer(reinterpret_cast(ProgData), ProgDataSize); + *IsValid = Device->Device->Plugin.isDeviceCompatible( + Device->Device->getDeviceId(), Buffer); + return Error::success(); +} + Error olDestroyProgram_impl(ol_program_handle_t Program) { auto &Device = Program->Image->getDevice(); if (auto Err = Device.unloadBinary(Program->Image)) diff --git a/offload/libomptarget/PluginManager.cpp b/offload/libomptarget/PluginManager.cpp index b57a2f815cba6..c8d6b42114d0f 100644 --- a/offload/libomptarget/PluginManager.cpp +++ b/offload/libomptarget/PluginManager.cpp @@ -219,7 +219,10 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) { // Scan the RTLs that have associated images until we find one that supports // the current image. for (auto &R : plugins()) { - if (!R.is_plugin_compatible(Img)) + StringRef Buffer(reinterpret_cast(Img->ImageStart), + utils::getPtrDiff(Img->ImageEnd, Img->ImageStart)); + + if (!R.isPluginCompatible(Buffer)) continue; if (!initializePlugin(R)) @@ -242,7 +245,7 @@ void PluginManager::registerLib(__tgt_bin_desc *Desc) { continue; } - if (!R.is_device_compatible(DeviceId, Img)) + if (!R.isDeviceCompatible(DeviceId, Buffer)) continue; DP("Image " DPxMOD " is compatible with RTL %s device %d!\n", diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index ce66d277d6187..9d5651a3d7b4e 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -1378,10 +1378,10 @@ struct GenericPluginTy { /// Returns non-zero if the \p Image is compatible with the plugin. This /// function does not require the plugin to be initialized before use. - int32_t is_plugin_compatible(__tgt_device_image *Image); + int32_t isPluginCompatible(StringRef Image); /// Returns non-zero if the \p Image is compatible with the device. - int32_t is_device_compatible(int32_t DeviceId, __tgt_device_image *Image); + int32_t isDeviceCompatible(int32_t DeviceId, StringRef Image); /// Returns non-zero if the plugin device has been initialized. int32_t is_device_initialized(int32_t DeviceId) const; diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 9f830874d5dad..b1955b53b80e5 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1713,28 +1713,25 @@ Expected GenericPluginTy::checkBitcodeImage(StringRef Image) const { int32_t GenericPluginTy::is_initialized() const { return Initialized; } -int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) { - StringRef Buffer(reinterpret_cast(Image->ImageStart), - utils::getPtrDiff(Image->ImageEnd, Image->ImageStart)); - +int32_t GenericPluginTy::isPluginCompatible(StringRef Image) { auto HandleError = [&](Error Err) -> bool { [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str()); return false; }; - switch (identify_magic(Buffer)) { + switch (identify_magic(Image)) { case file_magic::elf: case file_magic::elf_relocatable: case file_magic::elf_executable: case file_magic::elf_shared_object: case file_magic::elf_core: { - auto MatchOrErr = checkELFImage(Buffer); + auto MatchOrErr = checkELFImage(Image); if (Error Err = MatchOrErr.takeError()) return HandleError(std::move(Err)); return *MatchOrErr; } case file_magic::bitcode: { - auto MatchOrErr = checkBitcodeImage(Buffer); + auto MatchOrErr = checkBitcodeImage(Image); if (Error Err = MatchOrErr.takeError()) return HandleError(std::move(Err)); return *MatchOrErr; @@ -1744,36 +1741,32 @@ int32_t GenericPluginTy::is_plugin_compatible(__tgt_device_image *Image) { } } -int32_t GenericPluginTy::is_device_compatible(int32_t DeviceId, - __tgt_device_image *Image) { - StringRef Buffer(reinterpret_cast(Image->ImageStart), - utils::getPtrDiff(Image->ImageEnd, Image->ImageStart)); - +int32_t GenericPluginTy::isDeviceCompatible(int32_t DeviceId, StringRef Image) { auto HandleError = [&](Error Err) -> bool { [[maybe_unused]] std::string ErrStr = toString(std::move(Err)); DP("Failure to check validity of image %p: %s", Image, ErrStr.c_str()); return false; }; - switch (identify_magic(Buffer)) { + switch (identify_magic(Image)) { case file_magic::elf: case file_magic::elf_relocatable: case file_magic::elf_executable: case file_magic::elf_shared_object: case file_magic::elf_core: { - auto MatchOrErr = checkELFImage(Buffer); + auto MatchOrErr = checkELFImage(Image); if (Error Err = MatchOrErr.takeError()) return HandleError(std::move(Err)); if (!*MatchOrErr) return false; // Perform plugin-dependent checks for the specific architecture if needed. - auto CompatibleOrErr = isELFCompatible(DeviceId, Buffer); + auto CompatibleOrErr = isELFCompatible(DeviceId, Image); if (Error Err = CompatibleOrErr.takeError()) return HandleError(std::move(Err)); return *CompatibleOrErr; } case file_magic::bitcode: { - auto MatchOrErr = checkBitcodeImage(Buffer); + auto MatchOrErr = checkBitcodeImage(Image); if (Error Err = MatchOrErr.takeError()) return HandleError(std::move(Err)); return *MatchOrErr; diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt index b2d514423a6ee..ba35c1ee87aac 100644 --- a/offload/unittests/OffloadAPI/CMakeLists.txt +++ b/offload/unittests/OffloadAPI/CMakeLists.txt @@ -35,6 +35,7 @@ add_offload_unittest("platform" add_offload_unittest("program" program/olCreateProgram.cpp + program/olIsValidBinary.cpp program/olDestroyProgram.cpp) add_offload_unittest("queue" diff --git a/offload/unittests/OffloadAPI/program/olIsValidBinary.cpp b/offload/unittests/OffloadAPI/program/olIsValidBinary.cpp new file mode 100644 index 0000000000000..02e805dd1135b --- /dev/null +++ b/offload/unittests/OffloadAPI/program/olIsValidBinary.cpp @@ -0,0 +1,49 @@ +//===------- Offload API tests - olIsValidBinary --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include + +using olIsValidBinaryTest = OffloadDeviceTest; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olIsValidBinaryTest); + +TEST_P(olIsValidBinaryTest, Success) { + + std::unique_ptr DeviceBin; + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_GE(DeviceBin->getBufferSize(), 0lu); + + bool IsValid = false; + ASSERT_SUCCESS(olIsValidBinary(Device, DeviceBin->getBufferStart(), + DeviceBin->getBufferSize(), &IsValid)); + ASSERT_TRUE(IsValid); + + ASSERT_SUCCESS( + olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid)); + ASSERT_FALSE(IsValid); +} + +TEST_P(olIsValidBinaryTest, Invalid) { + + std::unique_ptr DeviceBin; + ASSERT_TRUE(TestEnvironment::loadDeviceBinary("foo", Device, DeviceBin)); + ASSERT_GE(DeviceBin->getBufferSize(), 0lu); + + bool IsValid = false; + ASSERT_SUCCESS( + olIsValidBinary(Device, DeviceBin->getBufferStart(), 0, &IsValid)); + ASSERT_FALSE(IsValid); +} + +TEST_P(olIsValidBinaryTest, NullPointer) { + bool IsValid = false; + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olIsValidBinary(Device, nullptr, 42, &IsValid)); + ASSERT_FALSE(IsValid); +}