Skip to content

Commit 331644e

Browse files
committed
extract required spv binary from exe
Signed-off-by: jinge90 <[email protected]>
1 parent 2976cde commit 331644e

File tree

1 file changed

+102
-10
lines changed

1 file changed

+102
-10
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 102 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include <cstdlib>
4040
#include <cstring>
4141
#include <fstream>
42+
#include <map>
4243
#include <memory>
4344
#include <mutex>
4445
#include <sstream>
@@ -1112,8 +1113,8 @@ ProgramManager::getProgramBuildLog(const ur_program_handle_t &Program,
11121113
// TODO device libraries may use scpecialization constants, manifest files, etc.
11131114
// To support that they need to be delivered in a different container - so that
11141115
// sycl_device_binary_struct can be created for each of them.
1115-
static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
1116-
ur_program_handle_t &Prog) {
1116+
static bool loadDeviceLibLegacy(const ContextImplPtr Context, const char *Name,
1117+
ur_program_handle_t &Prog) {
11171118
std::string LibSyclDir = OSUtil::getCurrentDSODir();
11181119
std::ifstream File(LibSyclDir + OSUtil::DirSep + Name,
11191120
std::ifstream::in | std::ifstream::binary);
@@ -1133,6 +1134,13 @@ static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
11331134
return Prog != nullptr;
11341135
}
11351136

1137+
static bool loadDeviceLib(const ContextImplPtr Context,
1138+
ur_program_handle_t &Prog,
1139+
const unsigned char *SPVBuffer, size_t SPVSize) {
1140+
Prog = createSpirvProgram(Context, SPVBuffer, SPVSize);
1141+
return Prog != nullptr;
1142+
}
1143+
11361144
// For each extension, a pair of library names. The first uses native support,
11371145
// the second emulates functionality in software.
11381146
static const std::map<DeviceLibExt, std::pair<const char *, const char *>>
@@ -1213,9 +1221,13 @@ static ur_result_t doCompile(const AdapterPtr &Adapter,
12131221
static ur_program_handle_t
12141222
loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
12151223
std::vector<ur_device_handle_t> &Devices,
1216-
bool UseNativeLib) {
1224+
bool UseNativeLib, bool LegacyMode = true,
1225+
const unsigned char *SPVBuffer = nullptr,
1226+
size_t SPVSize = 0) {
12171227

1218-
auto LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
1228+
const char *LibFileName = nullptr;
1229+
if (LegacyMode)
1230+
LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
12191231
auto LockedCache = Context->acquireCachedLibPrograms();
12201232
auto &CachedLibPrograms = LockedCache.get();
12211233
// Collect list of devices to compile the library for. Library was already
@@ -1252,10 +1264,20 @@ loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
12521264
bool IsProgramCreated = !URProgram;
12531265

12541266
// Create UR program for device lib if we don't have it yet.
1255-
if (!URProgram && !loadDeviceLib(Context, LibFileName, URProgram)) {
1256-
EraseProgramForDevices();
1257-
throw exception(make_error_code(errc::build),
1258-
std::string("Failed to load ") + LibFileName);
1267+
if (LegacyMode) {
1268+
if (!URProgram && !loadDeviceLibLegacy(Context, LibFileName, URProgram)) {
1269+
EraseProgramForDevices();
1270+
throw exception(make_error_code(errc::build),
1271+
std::string("Failed to load ") + LibFileName);
1272+
}
1273+
} else {
1274+
if (!URProgram && !loadDeviceLib(Context, URProgram, SPVBuffer, SPVSize)) {
1275+
EraseProgramForDevices();
1276+
const char *ExtStr = getDeviceLibExtensionStr(Extension);
1277+
throw exception(
1278+
make_error_code(errc::build),
1279+
std::string("Failed to load fallback device library for ") + ExtStr);
1280+
}
12591281
}
12601282

12611283
// Insert URProgram into the cache for all devices that we compiled it for.
@@ -1513,6 +1535,8 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
15131535
return ((DeviceLibReqMask & Mask) == Mask);
15141536
}
15151537

1538+
// TODO: Clear legacy getDeviceLibPrograms when developers upgrade to
1539+
// latest version compiler.
15161540
static std::vector<ur_program_handle_t>
15171541
getDeviceLibProgramsLegacy(const ContextImplPtr Context,
15181542
std::vector<ur_device_handle_t> &Devices,
@@ -1604,6 +1628,38 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16041628
std::vector<ur_device_handle_t> &Devices,
16051629
const std::vector<const RTDeviceBinaryImage *> &Images) {
16061630
std::vector<ur_program_handle_t> Programs;
1631+
std::map<DeviceLibExt, bool> DeviceLibExtLoaded = {
1632+
{DeviceLibExt::cl_intel_devicelib_assert,
1633+
/* is fallback loaded? */ false},
1634+
{DeviceLibExt::cl_intel_devicelib_math, false},
1635+
{DeviceLibExt::cl_intel_devicelib_math_fp64, false},
1636+
{DeviceLibExt::cl_intel_devicelib_complex, false},
1637+
{DeviceLibExt::cl_intel_devicelib_complex_fp64, false},
1638+
{DeviceLibExt::cl_intel_devicelib_cstring, false},
1639+
{DeviceLibExt::cl_intel_devicelib_imf, false},
1640+
{DeviceLibExt::cl_intel_devicelib_imf_fp64, false},
1641+
{DeviceLibExt::cl_intel_devicelib_imf_bf16, false},
1642+
{DeviceLibExt::cl_intel_devicelib_bfloat16, false}};
1643+
1644+
// Check whether a specified extension is supported by ALL devices.
1645+
auto checkExtForDevices = [&Context, &Devices](const char *ExtStr) -> bool {
1646+
bool ExtAvailable = true;
1647+
for (auto SingleDevice : Devices) {
1648+
std::string DevExtList =
1649+
Context->getPlatformImpl()
1650+
->getDeviceImpl(SingleDevice)
1651+
->get_device_info_string(
1652+
UrInfoCode<info::device::extensions>::value);
1653+
if (DevExtList.npos == DevExtList.find(ExtStr)) {
1654+
ExtAvailable = false;
1655+
break;
1656+
}
1657+
}
1658+
return ExtAvailable;
1659+
};
1660+
1661+
const bool fp64Support = checkExtForDevices("cl_khr_fp64");
1662+
16071663
for (auto Img : Images) {
16081664
if (!Img)
16091665
continue;
@@ -1616,11 +1672,47 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16161672
auto DeviceLibByteArray =
16171673
DeviceBinaryProperty(DeviceLibBinProp).asByteArray();
16181674
DeviceLibByteArray.dropBytes(8);
1619-
uint32_t DeviceLibExtReq =
1675+
DeviceLibExt DeviceLibExtReq = static_cast<DeviceLibExt>(
16201676
(static_cast<uint32_t>(DeviceLibByteArray[3]) << 24) |
16211677
(static_cast<uint32_t>(DeviceLibByteArray[2]) << 16) |
16221678
(static_cast<uint32_t>(DeviceLibByteArray[1]) << 8) |
1623-
DeviceLibByteArray[0];
1679+
DeviceLibByteArray[0]);
1680+
if (DeviceLibExtLoaded.count(DeviceLibExtReq) != 1) {
1681+
if constexpr (DbgProgMgr > 0) {
1682+
std::cerr << "Unknown DeviceLib extension("
1683+
<< static_cast<uint32_t>(DeviceLibExtReq) << ")!"
1684+
<< std::endl;
1685+
}
1686+
continue;
1687+
}
1688+
1689+
if (DeviceLibExtLoaded[DeviceLibExtReq])
1690+
continue;
1691+
1692+
if ((DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
1693+
DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_complex_fp64 ||
1694+
DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_imf_fp64) &&
1695+
!fp64Support)
1696+
continue;
1697+
1698+
auto DeviceLibExtReqName = getDeviceLibExtensionStr(DeviceLibExtReq);
1699+
bool InhibitNativeImpl = false;
1700+
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
1701+
InhibitNativeImpl = strstr(Env, DeviceLibExtReqName) != nullptr;
1702+
}
1703+
1704+
bool ExtReqAvailable = checkExtForDevices(DeviceLibExtReqName);
1705+
1706+
// Load fallback device library only when 1) or 2) is met:
1707+
// 1. underlying device doesn't support the extension
1708+
// 2. user explicitly ask to inhibit usage of native support
1709+
if (!ExtReqAvailable || InhibitNativeImpl) {
1710+
DeviceLibByteArray.dropBytes(4);
1711+
Programs.push_back(loadDeviceLibFallback(
1712+
Context, DeviceLibExtReq, Devices,
1713+
/*UseNativeLib=*/false, false, DeviceLibByteArray.begin(),
1714+
DeviceLibByteArray.size()));
1715+
}
16241716
}
16251717
}
16261718
return Programs;

0 commit comments

Comments
 (0)