Skip to content

Commit 57fb543

Browse files
committed
link device lib image when required
Signed-off-by: jinge90 <[email protected]>
1 parent 42f4db5 commit 57fb543

File tree

2 files changed

+121
-12
lines changed

2 files changed

+121
-12
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 116 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1171,8 +1171,8 @@ ProgramManager::getProgramBuildLog(const ur_program_handle_t &Program,
11711171
// TODO device libraries may use scpecialization constants, manifest files, etc.
11721172
// To support that they need to be delivered in a different container - so that
11731173
// sycl_device_binary_struct can be created for each of them.
1174-
static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
1175-
ur_program_handle_t &Prog) {
1174+
static bool loadDeviceLibLegacy(const ContextImplPtr Context, const char *Name,
1175+
ur_program_handle_t &Prog) {
11761176
std::string LibSyclDir = OSUtil::getCurrentDSODir();
11771177
std::ifstream File(LibSyclDir + OSUtil::DirSep + Name,
11781178
std::ifstream::in | std::ifstream::binary);
@@ -1192,6 +1192,14 @@ static bool loadDeviceLib(const ContextImplPtr Context, const char *Name,
11921192
return Prog != nullptr;
11931193
}
11941194

1195+
static bool loadDeviceLib(const ContextImplPtr Context,
1196+
ur_program_handle_t &Prog,
1197+
const unsigned char *DeviceLibImageBuffer,
1198+
size_t DeviceLibImageSize) {
1199+
Prog = createSpirvProgram(Context, DeviceLibImageBuffer, DeviceLibImageSize);
1200+
return Prog != nullptr;
1201+
}
1202+
11951203
// For each extension, a pair of library names. The first uses native support,
11961204
// the second emulates functionality in software.
11971205
static const std::map<DeviceLibExt, std::pair<const char *, const char *>>
@@ -1272,9 +1280,13 @@ static ur_result_t doCompile(const AdapterPtr &Adapter,
12721280
static ur_program_handle_t
12731281
loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
12741282
std::vector<ur_device_handle_t> &Devices,
1275-
bool UseNativeLib) {
1283+
bool UseNativeLib, bool LegacyMode = true,
1284+
const unsigned char *DeviceLibImageBuffer = nullptr,
1285+
size_t DeviceLibImageSize = 0) {
12761286

1277-
auto LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
1287+
const char *LibFileName = nullptr;
1288+
if (LegacyMode)
1289+
LibFileName = getDeviceLibFilename(Extension, UseNativeLib);
12781290
auto LockedCache = Context->acquireCachedLibPrograms();
12791291
auto &CachedLibPrograms = LockedCache.get();
12801292
// Collect list of devices to compile the library for. Library was already
@@ -1311,10 +1323,21 @@ loadDeviceLibFallback(const ContextImplPtr Context, DeviceLibExt Extension,
13111323
bool IsProgramCreated = !URProgram;
13121324

13131325
// Create UR program for device lib if we don't have it yet.
1314-
if (!URProgram && !loadDeviceLib(Context, LibFileName, URProgram)) {
1315-
EraseProgramForDevices();
1316-
throw exception(make_error_code(errc::build),
1317-
std::string("Failed to load ") + LibFileName);
1326+
if (LegacyMode) {
1327+
if (!URProgram && !loadDeviceLibLegacy(Context, LibFileName, URProgram)) {
1328+
EraseProgramForDevices();
1329+
throw exception(make_error_code(errc::build),
1330+
std::string("Failed to load ") + LibFileName);
1331+
}
1332+
} else {
1333+
if (!URProgram && !loadDeviceLib(Context, URProgram, DeviceLibImageBuffer,
1334+
DeviceLibImageSize)) {
1335+
EraseProgramForDevices();
1336+
const char *ExtStr = getDeviceLibExtensionStr(Extension);
1337+
throw exception(
1338+
make_error_code(errc::build),
1339+
std::string("Failed to load fallback device library for ") + ExtStr);
1340+
}
13181341
}
13191342

13201343
// Insert URProgram into the cache for all devices that we compiled it for.
@@ -1573,9 +1596,9 @@ static bool isDeviceLibRequired(DeviceLibExt Ext, uint32_t DeviceLibReqMask) {
15731596
}
15741597

15751598
static std::vector<ur_program_handle_t>
1576-
getDeviceLibPrograms(const ContextImplPtr Context,
1577-
std::vector<ur_device_handle_t> &Devices,
1578-
uint32_t DeviceLibReqMask) {
1599+
getDeviceLibProgramsLegacy(const ContextImplPtr Context,
1600+
std::vector<ur_device_handle_t> &Devices,
1601+
uint32_t DeviceLibReqMask) {
15791602
std::vector<ur_program_handle_t> Programs;
15801603

15811604
std::pair<DeviceLibExt, bool> RequiredDeviceLibExt[] = {
@@ -1658,6 +1681,83 @@ getDeviceLibPrograms(const ContextImplPtr Context,
16581681
return Programs;
16591682
}
16601683

1684+
std::vector<ur_program_handle_t> ProgramManager::getDeviceLibReqPrograms(
1685+
const ContextImplPtr Context, std::vector<ur_device_handle_t> &Devices,
1686+
uint32_t DeviceLibReqMask) {
1687+
1688+
std::vector<ur_program_handle_t> Programs;
1689+
1690+
// Check whether a specified extension is supported by ALL devices.
1691+
auto checkExtForDevices = [&Context, &Devices](const char *ExtStr) -> bool {
1692+
bool ExtAvailable = true;
1693+
for (auto SingleDevice : Devices) {
1694+
std::string DevExtList =
1695+
Context->getPlatformImpl()
1696+
->getDeviceImpl(SingleDevice)
1697+
->get_device_info_string(
1698+
UrInfoCode<info::device::extensions>::value);
1699+
if (DevExtList.npos == DevExtList.find(ExtStr)) {
1700+
ExtAvailable = false;
1701+
break;
1702+
}
1703+
}
1704+
return ExtAvailable;
1705+
};
1706+
1707+
const bool fp64Support = checkExtForDevices("cl_khr_fp64");
1708+
1709+
size_t Idx = 0;
1710+
std::vector<DeviceLibExt> ReqDeviceLibExts;
1711+
while (DeviceLibReqMask != 0) {
1712+
if (DeviceLibReqMask & 1) {
1713+
DeviceLibExt ExtReq = static_cast<DeviceLibExt>(
1714+
static_cast<uint32_t>(DeviceLibExt::cl_intel_devicelib_assert) + Idx);
1715+
ReqDeviceLibExts.push_back(ExtReq);
1716+
}
1717+
++Idx;
1718+
DeviceLibReqMask = DeviceLibReqMask >> 1;
1719+
}
1720+
1721+
std::vector<unsigned> ReqExtMetaKeys;
1722+
for (auto Ext : ReqDeviceLibExts) {
1723+
if ((Ext == DeviceLibExt::cl_intel_devicelib_math_fp64 ||
1724+
Ext == DeviceLibExt::cl_intel_devicelib_complex_fp64 ||
1725+
Ext == DeviceLibExt::cl_intel_devicelib_imf_fp64) &&
1726+
!fp64Support) {
1727+
continue;
1728+
}
1729+
auto ExtName = getDeviceLibExtensionStr(Ext);
1730+
bool InhibitNativeImpl = false;
1731+
if (const char *Env = getenv("SYCL_DEVICELIB_INHIBIT_NATIVE")) {
1732+
InhibitNativeImpl = strstr(Env, ExtName) != nullptr;
1733+
}
1734+
bool ExtReqAvailable = checkExtForDevices(ExtName);
1735+
unsigned ExtMetaKey = static_cast<unsigned>(Ext);
1736+
if (ExtReqAvailable && !InhibitNativeImpl) {
1737+
if (Ext == DeviceLibExt::cl_intel_devicelib_bfloat16) {
1738+
ExtMetaKey = ExtMetaKey | 0x80000000;
1739+
} else
1740+
continue;
1741+
}
1742+
ReqExtMetaKeys.push_back(ExtMetaKey);
1743+
}
1744+
1745+
if (ReqExtMetaKeys.size() > 0) {
1746+
std::lock_guard<std::mutex> DeviceLibImagesGuard(m_DeviceLibImagesMutex);
1747+
for (auto Key : ReqExtMetaKeys) {
1748+
if (m_DeviceLibImages.find(Key) != m_DeviceLibImages.end()) {
1749+
bool IsNative = ((Key & 0x80000000) > 0);
1750+
DeviceLibExt Ext = static_cast<DeviceLibExt>(Key & 0x7FFFFFFF);
1751+
Programs.push_back(loadDeviceLibFallback(
1752+
Context, Ext, Devices, IsNative, false,
1753+
m_DeviceLibImages[Key]->getRawData().BinaryStart,
1754+
m_DeviceLibImages[Key]->getSize()));
1755+
}
1756+
}
1757+
}
1758+
return Programs;
1759+
}
1760+
16611761
// Check if device image is compressed.
16621762
static inline bool isDeviceImageCompressed(sycl_device_binary Bin) {
16631763

@@ -1691,7 +1791,11 @@ ProgramManager::ProgramPtr ProgramManager::build(
16911791

16921792
std::vector<ur_program_handle_t> LinkPrograms;
16931793
if (LinkDeviceLibs) {
1694-
LinkPrograms = getDeviceLibPrograms(Context, Devices, DeviceLibReqMask);
1794+
LinkPrograms = getDeviceLibReqPrograms(Context, Devices, DeviceLibReqMask);
1795+
if (LinkPrograms.size() == 0) {
1796+
LinkPrograms =
1797+
getDeviceLibProgramsLegacy(Context, Devices, DeviceLibReqMask);
1798+
}
16951799
}
16961800

16971801
static const char *ForceLinkEnv = std::getenv("SYCL_FORCE_LINK");

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ class ProgramManager {
217217

218218
uint32_t getDeviceLibReqMask(const RTDeviceBinaryImage &Img);
219219

220+
std::vector<ur_program_handle_t>
221+
getDeviceLibReqPrograms(const ContextImplPtr Context,
222+
std::vector<ur_device_handle_t> &Devices,
223+
uint32_t DeviceLibReqMask);
224+
220225
/// Returns the mask for eliminated kernel arguments for the requested kernel
221226
/// within the native program.
222227
/// \param NativePrg the UR program associated with the kernel.

0 commit comments

Comments
 (0)