3939#include < cstdlib>
4040#include < cstring>
4141#include < fstream>
42+ #include < map>
4243#include < memory>
4344#include < mutex>
4445#include < sstream>
@@ -1112,8 +1113,8 @@ ProgramManager::getProgramBuildLog(const ur_program_handle_t &Program,
11121113// TODO device libraries may use scpecialization constants, manifest files, etc.
11131114// To support that they need to be delivered in a different container - so that
11141115// sycl_device_binary_struct can be created for each of them.
1115- static bool loadDeviceLib (const ContextImplPtr Context, const char *Name,
1116- ur_program_handle_t &Prog) {
1116+ static bool loadDeviceLibLegacy (const ContextImplPtr Context, const char *Name,
1117+ ur_program_handle_t &Prog) {
11171118 std::string LibSyclDir = OSUtil::getCurrentDSODir ();
11181119 std::ifstream File (LibSyclDir + OSUtil::DirSep + Name,
11191120 std::ifstream::in | std::ifstream::binary);
@@ -1133,6 +1134,13 @@ static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
11331134 return Prog != nullptr ;
11341135}
11351136
1137+ static bool loadDeviceLib (const ContextImplPtr Context,
1138+ ur_program_handle_t &Prog,
1139+ const unsigned char *SPVBuffer, size_t SPVSize) {
1140+ Prog = createSpirvProgram (Context, SPVBuffer, SPVSize);
1141+ return Prog != nullptr ;
1142+ }
1143+
11361144// For each extension, a pair of library names. The first uses native support,
11371145// the second emulates functionality in software.
11381146static const std::map<DeviceLibExt, std::pair<const char *, const char *>>
@@ -1213,9 +1221,13 @@ static ur_result_t doCompile(const AdapterPtr &Adapter,
12131221static ur_program_handle_t
12141222loadDeviceLibFallback (const ContextImplPtr Context, DeviceLibExt Extension,
12151223 std::vector<ur_device_handle_t > &Devices,
1216- bool UseNativeLib) {
1224+ bool UseNativeLib, bool LegacyMode = true ,
1225+ const unsigned char *SPVBuffer = nullptr ,
1226+ size_t SPVSize = 0 ) {
12171227
1218- auto LibFileName = getDeviceLibFilename (Extension, UseNativeLib);
1228+ const char *LibFileName = nullptr ;
1229+ if (LegacyMode)
1230+ LibFileName = getDeviceLibFilename (Extension, UseNativeLib);
12191231 auto LockedCache = Context->acquireCachedLibPrograms ();
12201232 auto &CachedLibPrograms = LockedCache.get ();
12211233 // Collect list of devices to compile the library for. Library was already
@@ -1252,10 +1264,20 @@ loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
12521264 bool IsProgramCreated = !URProgram;
12531265
12541266 // Create UR program for device lib if we don't have it yet.
1255- if (!URProgram && !loadDeviceLib (Context, LibFileName, URProgram)) {
1256- EraseProgramForDevices ();
1257- throw exception (make_error_code (errc::build),
1258- std::string (" Failed to load " ) + LibFileName);
1267+ if (LegacyMode) {
1268+ if (!URProgram && !loadDeviceLibLegacy (Context, LibFileName, URProgram)) {
1269+ EraseProgramForDevices ();
1270+ throw exception (make_error_code (errc::build),
1271+ std::string (" Failed to load " ) + LibFileName);
1272+ }
1273+ } else {
1274+ if (!URProgram && !loadDeviceLib (Context, URProgram, SPVBuffer, SPVSize)) {
1275+ EraseProgramForDevices ();
1276+ const char *ExtStr = getDeviceLibExtensionStr (Extension);
1277+ throw exception (
1278+ make_error_code (errc::build),
1279+ std::string (" Failed to load fallback device library for " ) + ExtStr);
1280+ }
12591281 }
12601282
12611283 // Insert URProgram into the cache for all devices that we compiled it for.
@@ -1513,6 +1535,8 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
15131535 return ((DeviceLibReqMask & Mask) == Mask);
15141536}
15151537
1538+ // TODO: Clear legacy getDeviceLibPrograms when developers upgrade to
1539+ // latest version compiler.
15161540static std::vector<ur_program_handle_t >
15171541getDeviceLibProgramsLegacy (const ContextImplPtr Context,
15181542 std::vector<ur_device_handle_t > &Devices,
@@ -1604,6 +1628,38 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16041628 std::vector<ur_device_handle_t > &Devices,
16051629 const std::vector<const RTDeviceBinaryImage *> &Images) {
16061630 std::vector<ur_program_handle_t > Programs;
1631+ std::map<DeviceLibExt, bool > DeviceLibExtLoaded = {
1632+ {DeviceLibExt::cl_intel_devicelib_assert,
1633+ /* is fallback loaded? */ false },
1634+ {DeviceLibExt::cl_intel_devicelib_math, false },
1635+ {DeviceLibExt::cl_intel_devicelib_math_fp64, false },
1636+ {DeviceLibExt::cl_intel_devicelib_complex, false },
1637+ {DeviceLibExt::cl_intel_devicelib_complex_fp64, false },
1638+ {DeviceLibExt::cl_intel_devicelib_cstring, false },
1639+ {DeviceLibExt::cl_intel_devicelib_imf, false },
1640+ {DeviceLibExt::cl_intel_devicelib_imf_fp64, false },
1641+ {DeviceLibExt::cl_intel_devicelib_imf_bf16, false },
1642+ {DeviceLibExt::cl_intel_devicelib_bfloat16, false }};
1643+
1644+ // Check whether a specified extension is supported by ALL devices.
1645+ auto checkExtForDevices = [&Context, &Devices](const char *ExtStr) -> bool {
1646+ bool ExtAvailable = true ;
1647+ for (auto SingleDevice : Devices) {
1648+ std::string DevExtList =
1649+ Context->getPlatformImpl ()
1650+ ->getDeviceImpl (SingleDevice)
1651+ ->get_device_info_string (
1652+ UrInfoCode<info::device::extensions>::value);
1653+ if (DevExtList.npos == DevExtList.find (ExtStr)) {
1654+ ExtAvailable = false ;
1655+ break ;
1656+ }
1657+ }
1658+ return ExtAvailable;
1659+ };
1660+
1661+ const bool fp64Support = checkExtForDevices (" cl_khr_fp64" );
1662+
16071663 for (auto Img : Images) {
16081664 if (!Img)
16091665 continue ;
@@ -1616,11 +1672,47 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16161672 auto DeviceLibByteArray =
16171673 DeviceBinaryProperty (DeviceLibBinProp).asByteArray ();
16181674 DeviceLibByteArray.dropBytes (8 );
1619- uint32_t DeviceLibExtReq =
1675+ DeviceLibExt DeviceLibExtReq = static_cast <DeviceLibExt>(
16201676 (static_cast <uint32_t >(DeviceLibByteArray[3 ]) << 24 ) |
16211677 (static_cast <uint32_t >(DeviceLibByteArray[2 ]) << 16 ) |
16221678 (static_cast <uint32_t >(DeviceLibByteArray[1 ]) << 8 ) |
1623- DeviceLibByteArray[0 ];
1679+ DeviceLibByteArray[0 ]);
1680+ if (DeviceLibExtLoaded.count (DeviceLibExtReq) != 1 ) {
1681+ if constexpr (DbgProgMgr > 0 ) {
1682+ std::cerr << " Unknown DeviceLib extension("
1683+ << static_cast <uint32_t >(DeviceLibExtReq) << " )!"
1684+ << std::endl;
1685+ }
1686+ continue ;
1687+ }
1688+
1689+ if (DeviceLibExtLoaded[DeviceLibExtReq])
1690+ continue ;
1691+
1692+ if ((DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
1693+ DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_complex_fp64 ||
1694+ DeviceLibExtReq == DeviceLibExt::cl_intel_devicelib_imf_fp64) &&
1695+ !fp64Support)
1696+ continue ;
1697+
1698+ auto DeviceLibExtReqName = getDeviceLibExtensionStr (DeviceLibExtReq);
1699+ bool InhibitNativeImpl = false ;
1700+ if (const char *Env = getenv (" SYCL_DEVICELIB_INHIBIT_NATIVE" )) {
1701+ InhibitNativeImpl = strstr (Env, DeviceLibExtReqName) != nullptr ;
1702+ }
1703+
1704+ bool ExtReqAvailable = checkExtForDevices (DeviceLibExtReqName);
1705+
1706+ // Load fallback device library only when 1) or 2) is met:
1707+ // 1. underlying device doesn't support the extension
1708+ // 2. user explicitly ask to inhibit usage of native support
1709+ if (!ExtReqAvailable || InhibitNativeImpl) {
1710+ DeviceLibByteArray.dropBytes (4 );
1711+ Programs.push_back (loadDeviceLibFallback (
1712+ Context, DeviceLibExtReq, Devices,
1713+ /* UseNativeLib=*/ false , false , DeviceLibByteArray.begin (),
1714+ DeviceLibByteArray.size ()));
1715+ }
16241716 }
16251717 }
16261718 return Programs;
0 commit comments