Skip to content

Commit 1351401

Browse files
committed
Create DynRTDeviceBinaryImage for bf16 devicelib
Signed-off-by: jinge90 <[email protected]>
1 parent aa60bc5 commit 1351401

File tree

3 files changed

+106
-48
lines changed

3 files changed

+106
-48
lines changed

sycl/source/detail/program_manager/program_manager.cpp

Lines changed: 96 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -614,22 +614,24 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
614614
}
615615

616616
// Quick check to see whether BinImage is a compiler-generated device image.
617-
static bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
617+
bool ProgramManager::isSpecialDeviceImage(RTDeviceBinaryImage *BinImage) {
618618
// SYCL devicelib image.
619-
if (BinImage->getDeviceLibMetadata().isAvailable())
619+
if ((m_Bfloat16DeviceLibImages[0].get() == BinImage) ||
620+
m_Bfloat16DeviceLibImages[1].get() == BinImage)
620621
return true;
621622

622623
return false;
623624
}
624625

625-
static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
626-
const device &Dev) {
626+
bool ProgramManager::isSpecialDeviceImageShouldBeUsed(
627+
RTDeviceBinaryImage *BinImage, const device &Dev) {
627628
// Decide whether a devicelib image should be used.
628-
if (BinImage->getDeviceLibMetadata().isAvailable()) {
629-
const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
630-
BinImage->getDeviceLibMetadata();
631-
uint32_t DeviceLibMeta =
632-
DeviceBinaryProperty(*(DeviceLibMetaProp.begin())).asUint32();
629+
int Bfloat16DeviceLibVersion = -1;
630+
if (m_Bfloat16DeviceLibImages[0].get() == BinImage)
631+
Bfloat16DeviceLibVersion = 0;
632+
if (m_Bfloat16DeviceLibImages[1].get() == BinImage)
633+
Bfloat16DeviceLibVersion = 1;
634+
if (Bfloat16DeviceLibVersion != -1) {
633635
// Currently, only bfloat conversion devicelib are supported, so the prop
634636
// DeviceLibMeta are only used to represent fallback or native version.
635637
// For bfloat16 conversion devicelib, we have fallback and native version.
@@ -643,7 +645,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
643645
detail::getSyclObjImpl(Dev);
644646
std::string NativeBF16ExtName = "cl_intel_bfloat16_conversions";
645647
bool NativeBF16Supported = (DeviceImpl->has_extension(NativeBF16ExtName));
646-
return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
648+
return NativeBF16Supported ==
649+
(Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
647650
}
648651

649652
return false;
@@ -1837,17 +1840,53 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18371840
return {};
18381841
}
18391842

1840-
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
1841-
// For bfloat16 device library image, we should keep it although it doesn't
1842-
// include any kernel.
1843+
static bool isBfloat16DeviceLibImage(sycl_device_binary RawImg,
1844+
uint32_t *LibVersion = nullptr) {
18431845
sycl_device_binary_property_set ImgPS;
18441846
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
18451847
++ImgPS) {
18461848
if (ImgPS->Name &&
1847-
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name))
1848-
return false;
1849+
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
1850+
if (!LibVersion)
1851+
return true;
1852+
1853+
*LibVersion = 0;
1854+
sycl_device_binary_property ImgP;
1855+
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
1856+
++ImgP) {
1857+
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
1858+
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
1859+
break;
1860+
}
1861+
if (ImgP != ImgPS->PropertiesEnd)
1862+
*LibVersion = DeviceBinaryProperty(ImgP).asUint32();
1863+
return true;
1864+
}
1865+
}
1866+
1867+
return false;
1868+
}
1869+
1870+
// Looks up the property set named __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS
// in RawImg. Returns the first matching set, or nullptr when the image has
// none, so the result must be checked before it is dereferenced.
static sycl_device_binary_property_set
getExportedSymbolPS(sycl_device_binary RawImg) {
  sycl_device_binary_property_set Candidate = RawImg->PropertySetsBegin;
  while (Candidate != RawImg->PropertySetsEnd) {
    const bool IsExportedSymbols =
        Candidate->Name &&
        strcmp(__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, Candidate->Name) == 0;
    if (IsExportedSymbols)
      return Candidate;
    ++Candidate;
  }

  return nullptr;
}
1882+
1883+
// Returns false for images that must be kept even though they contain no
// kernel, true otherwise.
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
  // A bfloat16 device library image doesn't include any kernel but still has
  // to be kept. It is the only exception for now; the logic may be extended
  // to other kinds of images later.
  const bool MustKeep = isBfloat16DeviceLibImage(RawImg);
  return !MustKeep;
}
18531892

@@ -1861,6 +1900,8 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
18611900
continue;
18621901

18631902
std::unique_ptr<RTDeviceBinaryImage> Img;
1903+
bool IsBfloat16DeviceLib = false;
1904+
uint32_t Bfloat16DeviceLibVersion = 0;
18641905
if (isDeviceImageCompressed(RawImg))
18651906
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
18661907
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1870,40 +1911,57 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
18701911
"SYCL RT was built without ZSTD support."
18711912
"Aborting. ");
18721913
#endif
1873-
else
1874-
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1914+
else {
1915+
IsBfloat16DeviceLib =
1916+
isBfloat16DeviceLibImage(RawImg, &Bfloat16DeviceLibVersion);
1917+
if (!IsBfloat16DeviceLib)
1918+
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1919+
}
18751920

18761921
static uint32_t SequenceID = 0;
18771922

1878-
// Fill the kernel argument mask map
1879-
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1880-
Img->getKernelParamOptInfo();
1881-
if (KPOIRange.isAvailable()) {
1882-
KernelNameToArgMaskMap &ArgMaskMap =
1883-
m_EliminatedKernelArgMasks[Img.get()];
1884-
for (const auto &Info : KPOIRange)
1885-
ArgMaskMap[Info->Name] =
1886-
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1923+
// Fill the kernel argument mask map, no need to do this for bfloat16
1924+
// device library image since it doesn't include any kernel.
1925+
if (!IsBfloat16DeviceLib) {
1926+
const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1927+
Img->getKernelParamOptInfo();
1928+
if (KPOIRange.isAvailable()) {
1929+
KernelNameToArgMaskMap &ArgMaskMap =
1930+
m_EliminatedKernelArgMasks[Img.get()];
1931+
for (const auto &Info : KPOIRange)
1932+
ArgMaskMap[Info->Name] =
1933+
createKernelArgMask(DeviceBinaryProperty(Info).asByteArray());
1934+
}
18871935
}
18881936

18891937
// Fill maps for kernel bundles
18901938
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
18911939

18921940
// For bfloat16 device library image, it doesn't include any kernel, device
1893-
// global, virtual function, so just skip adding it to any related maps. We
1894-
// only need to 1) add exported symbols to m_ExportedSymbolImages, and 2)
1895-
// add the device image to m_Bfloat16DeviceLibImages.
1941+
// global, virtual function, so just skip adding it to any related maps.
1942+
// The bfloat16 device library are provided by compiler and may be used by
1943+
// different sycl device images, program manager will own single copy for
1944+
// native and fallback version bfloat16 device library, these device
1945+
// library images will not be erased unless program manager is destroyed.
18961946
{
1897-
auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata();
1898-
if (Bfloat16DeviceLibProp.isAvailable()) {
1899-
uint32_t LibVersion = DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32();
1900-
if (m_Bfloat16DeviceLibImages.count(LibVersion) > 0)
1947+
if (IsBfloat16DeviceLib) {
1948+
if (m_Bfloat16DeviceLibImages.count(Bfloat16DeviceLibVersion) > 0)
19011949
continue;
1902-
for (const sycl_device_binary_property &ESProp :
1903-
Img->getExportedSymbols()) {
1904-
m_ExportedSymbolImages.insert({ESProp->Name, Img.get()});
1950+
size_t ImgSize =
1951+
static_cast<size_t>(RawImg->BinaryEnd - RawImg->BinaryStart);
1952+
std::unique_ptr<char[]> Data(new char[ImgSize]);
1953+
std::memcpy(Data.get(), RawImg->BinaryStart, ImgSize);
1954+
auto DynBfloat16DeviceLibImg =
1955+
std::make_unique<DynRTDeviceBinaryImage>(std::move(Data), ImgSize);
1956+
auto ESPropSet = getExportedSymbolPS(RawImg);
1957+
sycl_device_binary_property ESProp;
1958+
for (ESProp = ESPropSet->PropertiesBegin;
1959+
ESProp != ESPropSet->PropertiesEnd; ++ESProp) {
1960+
m_ExportedSymbolImages.insert(
1961+
{ESProp->Name, DynBfloat16DeviceLibImg.get()});
19051962
}
1906-
m_Bfloat16DeviceLibImages.insert({LibVersion, std::move(Img)});
1963+
m_Bfloat16DeviceLibImages.insert(
1964+
{Bfloat16DeviceLibVersion, std::move(DynBfloat16DeviceLibImg)});
19071965
continue;
19081966
}
19091967
}
@@ -2824,12 +2882,8 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28242882
for (const device_image_plain &Img : Imgs) {
28252883
const std::shared_ptr<device_image_impl> &DeviceImageImpl =
28262884
getSyclObjImpl(Img);
2827-
auto BinImgRef = DeviceImageImpl->get_bin_image_ref();
2828-
// For bfloat16 deice library image, no kernels, spec const are included,
2829-
// so we just skip merging data.
2830-
if (BinImgRef && BinImgRef->getDeviceLibMetadata().isAvailable())
2831-
continue;
28322885
// Duplicates are not expected here, otherwise urProgramLink should fail
2886+
if (DeviceImageImpl->get_kernel_ids_ptr())
28332887
KernelIDs.insert(KernelIDs.end(),
28342888
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
28352889
DeviceImageImpl->get_kernel_ids_ptr()->end());

sycl/source/detail/program_manager/program_manager.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -376,11 +376,15 @@ class ProgramManager {
376376
collectDependentDeviceImagesForVirtualFunctions(
377377
const RTDeviceBinaryImage &Img, const device &Dev);
378378

379+
bool isSpecialDeviceImage(RTDeviceBinaryImage *BinImage);
380+
bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
381+
const device &Dev);
382+
379383
protected:
380384
/// The three maps below are used during kernel resolution. Any kernel is
381385
/// identified by its name.
382386
using RTDeviceBinaryImageUPtr = std::unique_ptr<RTDeviceBinaryImage>;
383-
387+
using DynRTDeviceBinaryImageUPtr = std::unique_ptr<DynRTDeviceBinaryImage>;
384388
/// Maps names of kernels to their unique kernel IDs.
385389
/// TODO: Use std::unordered_set with transparent hash and equality functions
386390
/// when C++20 is enabled for the runtime library.
@@ -502,7 +506,7 @@ class ProgramManager {
502506
// and 1 for native version. These bfloat16 device library images are
503507
// provided by compiler long time ago, we expect no further update, so
504508
// keeping 1 copy should be OK.
505-
std::unordered_map<uint32_t, RTDeviceBinaryImageUPtr>
509+
std::unordered_map<uint32_t, DynRTDeviceBinaryImageUPtr>
506510
m_Bfloat16DeviceLibImages;
507511

508512
friend class ::ProgramManagerTest;

sycl/test-e2e/KernelCompiler/sycl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -533,11 +533,11 @@ int main() {
533533
if (!ok) {
534534
return -1;
535535
}
536-
536+
// Run test_device_libraries twice to verify bfloat16 device library.
537537
return test_build_and_run(q) || test_device_code_split(q) ||
538-
test_device_libraries(q) || test_esimd(q) ||
539-
test_unsupported_options(q) || test_error(q) ||
540-
test_no_visible_ids(q) || test_warning(q);
538+
test_device_libraries(q) || test_device_libraries(q) ||
539+
test_device_libraries(q) || test_unsupported_options(q) ||
540+
test_error(q) || test_no_visible_ids(q) || test_warning(q);
541541
#else
542542
static_assert(false, "Kernel Compiler feature test macro undefined");
543543
#endif

0 commit comments

Comments
 (0)