@@ -1171,8 +1171,8 @@ ProgramManager::getProgramBuildLog(const ur_program_handle_t &Program,
11711171// TODO device libraries may use scpecialization constants, manifest files, etc.
11721172// To support that they need to be delivered in a different container - so that
11731173// sycl_device_binary_struct can be created for each of them.
1174- static bool loadDeviceLib (const ContextImplPtr Context, const char *Name,
1175- ur_program_handle_t &Prog) {
1174+ static bool loadDeviceLibLegacy (const ContextImplPtr Context, const char *Name,
1175+ ur_program_handle_t &Prog) {
11761176 std::string LibSyclDir = OSUtil::getCurrentDSODir ();
11771177 std::ifstream File (LibSyclDir + OSUtil::DirSep + Name,
11781178 std::ifstream::in | std::ifstream::binary);
@@ -1192,6 +1192,14 @@ static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
11921192 return Prog != nullptr ;
11931193}
11941194
1195+ static bool loadDeviceLib (const ContextImplPtr Context,
1196+ ur_program_handle_t &Prog,
1197+ const unsigned char *DeviceLibImageBuffer,
1198+ size_t DeviceLibImageSize) {
1199+ Prog = createSpirvProgram (Context, DeviceLibImageBuffer, DeviceLibImageSize);
1200+ return Prog != nullptr ;
1201+ }
1202+
11951203// For each extension, a pair of library names. The first uses native support,
11961204// the second emulates functionality in software.
11971205static const std::map<DeviceLibExt, std::pair<const char *, const char *>>
@@ -1272,9 +1280,13 @@ static ur_result_t doCompile(const AdapterPtr &Adapter,
12721280static ur_program_handle_t
12731281loadDeviceLibFallback (const ContextImplPtr Context, DeviceLibExt Extension,
12741282 std::vector<ur_device_handle_t > &Devices,
1275- bool UseNativeLib) {
1283+ bool UseNativeLib, bool LegacyMode = true ,
1284+ const unsigned char *DeviceLibImageBuffer = nullptr ,
1285+ size_t DeviceLibImageSize = 0 ) {
12761286
1277- auto LibFileName = getDeviceLibFilename (Extension, UseNativeLib);
1287+ const char *LibFileName = nullptr ;
1288+ if (LegacyMode)
1289+ LibFileName = getDeviceLibFilename (Extension, UseNativeLib);
12781290 auto LockedCache = Context->acquireCachedLibPrograms ();
12791291 auto &CachedLibPrograms = LockedCache.get ();
12801292 // Collect list of devices to compile the library for. Library was already
@@ -1311,10 +1323,21 @@ loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
13111323 bool IsProgramCreated = !URProgram;
13121324
13131325 // Create UR program for device lib if we don't have it yet.
1314- if (!URProgram && !loadDeviceLib (Context, LibFileName, URProgram)) {
1315- EraseProgramForDevices ();
1316- throw exception (make_error_code (errc::build),
1317- std::string (" Failed to load " ) + LibFileName);
1326+ if (LegacyMode) {
1327+ if (!URProgram && !loadDeviceLibLegacy (Context, LibFileName, URProgram)) {
1328+ EraseProgramForDevices ();
1329+ throw exception (make_error_code (errc::build),
1330+ std::string (" Failed to load " ) + LibFileName);
1331+ }
1332+ } else {
1333+ if (!URProgram && !loadDeviceLib (Context, URProgram, DeviceLibImageBuffer,
1334+ DeviceLibImageSize)) {
1335+ EraseProgramForDevices ();
1336+ const char *ExtStr = getDeviceLibExtensionStr (Extension);
1337+ throw exception (
1338+ make_error_code (errc::build),
1339+ std::string (" Failed to load fallback device library for " ) + ExtStr);
1340+ }
13181341 }
13191342
13201343 // Insert URProgram into the cache for all devices that we compiled it for.
@@ -1573,9 +1596,9 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
15731596}
15741597
15751598static std::vector<ur_program_handle_t >
1576- getDeviceLibPrograms (const ContextImplPtr Context,
1577- std::vector<ur_device_handle_t > &Devices,
1578- uint32_t DeviceLibReqMask) {
1599+ getDeviceLibProgramsLegacy (const ContextImplPtr Context,
1600+ std::vector<ur_device_handle_t > &Devices,
1601+ uint32_t DeviceLibReqMask) {
15791602 std::vector<ur_program_handle_t > Programs;
15801603
15811604 std::pair<DeviceLibExt, bool > RequiredDeviceLibExt[] = {
@@ -1658,6 +1681,83 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16581681 return Programs;
16591682}
16601683
1684+ std::vector<ur_program_handle_t > ProgramManager::getDeviceLibReqPrograms (
1685+ const ContextImplPtr Context, std::vector<ur_device_handle_t > &Devices,
1686+ uint32_t DeviceLibReqMask) {
1687+
1688+ std::vector<ur_program_handle_t > Programs;
1689+
1690+ // Check whether a specified extension is supported by ALL devices.
1691+ auto checkExtForDevices = [&Context, &Devices](const char *ExtStr) -> bool {
1692+ bool ExtAvailable = true ;
1693+ for (auto SingleDevice : Devices) {
1694+ std::string DevExtList =
1695+ Context->getPlatformImpl ()
1696+ ->getDeviceImpl (SingleDevice)
1697+ ->get_device_info_string (
1698+ UrInfoCode<info::device::extensions>::value);
1699+ if (DevExtList.npos == DevExtList.find (ExtStr)) {
1700+ ExtAvailable = false ;
1701+ break ;
1702+ }
1703+ }
1704+ return ExtAvailable;
1705+ };
1706+
1707+ const bool fp64Support = checkExtForDevices (" cl_khr_fp64" );
1708+
1709+ size_t Idx = 0 ;
1710+ std::vector<DeviceLibExt> ReqDeviceLibExts;
1711+ while (DeviceLibReqMask != 0 ) {
1712+ if (DeviceLibReqMask & 1 ) {
1713+ DeviceLibExt ExtReq = static_cast <DeviceLibExt>(
1714+ static_cast <uint32_t >(DeviceLibExt::cl_intel_devicelib_assert) + Idx);
1715+ ReqDeviceLibExts.push_back (ExtReq);
1716+ }
1717+ ++Idx;
1718+ DeviceLibReqMask = DeviceLibReqMask >> 1 ;
1719+ }
1720+
1721+ std::vector<unsigned > ReqExtMetaKeys;
1722+ for (auto Ext : ReqDeviceLibExts) {
1723+ if ((Ext == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
1724+ Ext == DeviceLibExt::cl_intel_devicelib_complex_fp64 ||
1725+ Ext == DeviceLibExt::cl_intel_devicelib_imf_fp64) &&
1726+ !fp64Support) {
1727+ continue ;
1728+ }
1729+ auto ExtName = getDeviceLibExtensionStr (Ext);
1730+ bool InhibitNativeImpl = false ;
1731+ if (const char *Env = getenv (" SYCL_DEVICELIB_INHIBIT_NATIVE" )) {
1732+ InhibitNativeImpl = strstr (Env, ExtName) != nullptr ;
1733+ }
1734+ bool ExtReqAvailable = checkExtForDevices (ExtName);
1735+ unsigned ExtMetaKey = static_cast <unsigned >(Ext);
1736+ if (ExtReqAvailable && !InhibitNativeImpl) {
1737+ if (Ext == DeviceLibExt::cl_intel_devicelib_bfloat16) {
1738+ ExtMetaKey = ExtMetaKey | 0x80000000 ;
1739+ } else
1740+ continue ;
1741+ }
1742+ ReqExtMetaKeys.push_back (ExtMetaKey);
1743+ }
1744+
1745+ if (ReqExtMetaKeys.size () > 0 ) {
1746+ std::lock_guard<std::mutex> DeviceLibImagesGuard (m_DeviceLibImagesMutex);
1747+ for (auto Key : ReqExtMetaKeys) {
1748+ if (m_DeviceLibImages.find (Key) != m_DeviceLibImages.end ()) {
1749+ bool IsNative = ((Key & 0x80000000 ) > 0 );
1750+ DeviceLibExt Ext = static_cast <DeviceLibExt>(Key & 0x7FFFFFFF );
1751+ Programs.push_back (loadDeviceLibFallback (
1752+ Context, Ext, Devices, IsNative, false ,
1753+ m_DeviceLibImages[Key]->getRawData ().BinaryStart ,
1754+ m_DeviceLibImages[Key]->getSize ()));
1755+ }
1756+ }
1757+ }
1758+ return Programs;
1759+ }
1760+
16611761// Check if device image is compressed.
16621762static inline bool isDeviceImageCompressed (sycl_device_binary Bin) {
16631763
@@ -1691,7 +1791,11 @@ ProgramManager::ProgramPtr ProgramManager::build(
16911791
16921792 std::vector<ur_program_handle_t > LinkPrograms;
16931793 if (LinkDeviceLibs) {
1694- LinkPrograms = getDeviceLibPrograms (Context, Devices, DeviceLibReqMask);
1794+ LinkPrograms = getDeviceLibReqPrograms (Context, Devices, DeviceLibReqMask);
1795+ if (LinkPrograms.size () == 0 ) {
1796+ LinkPrograms =
1797+ getDeviceLibProgramsLegacy (Context, Devices, DeviceLibReqMask);
1798+ }
16951799 }
16961800
16971801 static const char *ForceLinkEnv = std::getenv (" SYCL_FORCE_LINK" );
0 commit comments