@@ -614,22 +614,24 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
614614}
615615
616616// Quick check to see whether BinImage is a compiler-generated device image.
617- static bool isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
617+ bool ProgramManager:: isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
618618 // SYCL devicelib image.
619- if (BinImage->getDeviceLibMetadata ().isAvailable ())
619+ if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
620+ m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
620621 return true ;
621622
622623 return false ;
623624}
624625
625- static bool isSpecialDeviceImageShouldBeUsed (RTDeviceBinaryImage *BinImage,
626- const device &Dev) {
626+ bool ProgramManager:: isSpecialDeviceImageShouldBeUsed (
627+ RTDeviceBinaryImage *BinImage, const device &Dev) {
627628 // Decide whether a devicelib image should be used.
628- if (BinImage->getDeviceLibMetadata ().isAvailable ()) {
629- const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
630- BinImage->getDeviceLibMetadata ();
631- uint32_t DeviceLibMeta =
632- DeviceBinaryProperty (*(DeviceLibMetaProp.begin ())).asUint32 ();
629+ int Bfloat16DeviceLibVersion = -1 ;
630+ if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
631+ Bfloat16DeviceLibVersion = 0 ;
632+ if (m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
633+ Bfloat16DeviceLibVersion = 1 ;
634+ if (Bfloat16DeviceLibVersion != -1 ) {
633635 // Currently, only bfloat conversion devicelib are supported, so the prop
634636 // DeviceLibMeta are only used to represent fallback or native version.
635637 // For bfloat16 conversion devicelib, we have fallback and native version.
@@ -643,7 +645,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
643645 detail::getSyclObjImpl (Dev);
644646 std::string NativeBF16ExtName = " cl_intel_bfloat16_conversions" ;
645647 bool NativeBF16Supported = (DeviceImpl->has_extension (NativeBF16ExtName));
646- return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
648+ return NativeBF16Supported ==
649+ (Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
647650 }
648651
649652 return false ;
@@ -1837,17 +1840,53 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18371840 return {};
18381841}
18391842
1840- static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1841- // For bfloat16 device library image, we should keep it although it doesn't
1842- // include any kernel.
1843+ static bool isBfloat16DeviceLibImage (sycl_device_binary RawImg,
1844+ uint32_t *LibVersion = nullptr ) {
18431845 sycl_device_binary_property_set ImgPS;
18441846 for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
18451847 ++ImgPS) {
18461848 if (ImgPS->Name &&
1847- !strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name ))
1848- return false ;
1849+ !strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name )) {
1850+ if (!LibVersion)
1851+ return true ;
1852+
1853+ *LibVersion = 0 ;
1854+ sycl_device_binary_property ImgP;
1855+ for (ImgP = ImgPS->PropertiesBegin ; ImgP != ImgPS->PropertiesEnd ;
1856+ ++ImgP) {
1857+ if (ImgP->Name && !strcmp (" bfloat16" , ImgP->Name ) &&
1858+ (ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
1859+ break ;
1860+ }
1861+ if (ImgP != ImgPS->PropertiesEnd )
1862+ *LibVersion = DeviceBinaryProperty (ImgP).asUint32 ();
1863+ return true ;
1864+ }
1865+ }
1866+
1867+ return false ;
1868+ }
1869+
1870+ static sycl_device_binary_property_set
1871+ getExportedSymbolPS (sycl_device_binary RawImg) {
1872+ sycl_device_binary_property_set ImgPS;
1873+ for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1874+ ++ImgPS) {
1875+ if (ImgPS->Name &&
1876+ !strcmp (__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name ))
1877+ return ImgPS;
18491878 }
18501879
1880+ return nullptr ;
1881+ }
1882+
1883+ static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1884+ // For bfloat16 device library image, we should keep it although it doesn't
1885+ // include any kernel.
1886+ if (isBfloat16DeviceLibImage (RawImg))
1887+ return false ;
1888+
1889+ // We may extend the logic here other than bfloat16 device library image.
18511890 return true ;
18521891}
18531892
@@ -1861,6 +1900,8 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
18611900 continue ;
18621901
18631902 std::unique_ptr<RTDeviceBinaryImage> Img;
1903+ bool IsBfloat16DeviceLib = false ;
1904+ uint32_t Bfloat16DeviceLibVersion = 0 ;
18641905 if (isDeviceImageCompressed (RawImg))
18651906#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
18661907 Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1870,40 +1911,57 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
18701911 " SYCL RT was built without ZSTD support."
18711912 " Aborting. " );
18721913#endif
1873- else
1874- Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1914+ else {
1915+ IsBfloat16DeviceLib =
1916+ isBfloat16DeviceLibImage (RawImg, &Bfloat16DeviceLibVersion);
1917+ if (!IsBfloat16DeviceLib)
1918+ Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1919+ }
18751920
18761921 static uint32_t SequenceID = 0 ;
18771922
1878- // Fill the kernel argument mask map
1879- const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1880- Img->getKernelParamOptInfo ();
1881- if (KPOIRange.isAvailable ()) {
1882- KernelNameToArgMaskMap &ArgMaskMap =
1883- m_EliminatedKernelArgMasks[Img.get ()];
1884- for (const auto &Info : KPOIRange)
1885- ArgMaskMap[Info->Name ] =
1886- createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1923+ // Fill the kernel argument mask map, no need to do this for bfloat16
1924+ // device library image since it doesn't include any kernel.
1925+ if (!IsBfloat16DeviceLib) {
1926+ const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1927+ Img->getKernelParamOptInfo ();
1928+ if (KPOIRange.isAvailable ()) {
1929+ KernelNameToArgMaskMap &ArgMaskMap =
1930+ m_EliminatedKernelArgMasks[Img.get ()];
1931+ for (const auto &Info : KPOIRange)
1932+ ArgMaskMap[Info->Name ] =
1933+ createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1934+ }
18871935 }
18881936
18891937 // Fill maps for kernel bundles
18901938 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
18911939
18921940 // For bfloat16 device library image, it doesn't include any kernel, device
1893- // global, virtual function, so just skip adding it to any related maps. We
1894- // only need to 1) add exported symbols to m_ExportedSymbolImages, and 2)
1895- // add the device image to m_Bfloat16DeviceLibImages.
1941+ // global, virtual function, so just skip adding it to any related maps.
1942+ // The bfloat16 device library are provided by compiler and may be used by
1943+ // different sycl device images, program manager will own single copy for
1944+ // native and fallback version bfloat16 device library, these device
1945+ // library images will not be erased unless program manager is destroyed.
18961946 {
1897- auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata ();
1898- if (Bfloat16DeviceLibProp.isAvailable ()) {
1899- uint32_t LibVersion = DeviceBinaryProperty (*(Bfloat16DeviceLibProp.begin ())).asUint32 ();
1900- if (m_Bfloat16DeviceLibImages.count (LibVersion) > 0 )
1947+ if (IsBfloat16DeviceLib) {
1948+ if (m_Bfloat16DeviceLibImages.count (Bfloat16DeviceLibVersion) > 0 )
19011949 continue ;
1902- for (const sycl_device_binary_property &ESProp :
1903- Img->getExportedSymbols ()) {
1904- m_ExportedSymbolImages.insert ({ESProp->Name , Img.get ()});
1950+ size_t ImgSize =
1951+ static_cast <size_t >(RawImg->BinaryEnd - RawImg->BinaryStart );
1952+ std::unique_ptr<char []> Data (new char [ImgSize]);
1953+ std::memcpy (Data.get (), RawImg->BinaryStart , ImgSize);
1954+ auto DynBfloat16DeviceLibImg =
1955+ std::make_unique<DynRTDeviceBinaryImage>(std::move (Data), ImgSize);
1956+ auto ESPropSet = getExportedSymbolPS (RawImg);
1957+ sycl_device_binary_property ESProp;
1958+ for (ESProp = ESPropSet->PropertiesBegin ;
1959+ ESProp != ESPropSet->PropertiesEnd ; ++ESProp) {
1960+ m_ExportedSymbolImages.insert (
1961+ {ESProp->Name , DynBfloat16DeviceLibImg.get ()});
19051962 }
1906- m_Bfloat16DeviceLibImages.insert ({LibVersion, std::move (Img)});
1963+ m_Bfloat16DeviceLibImages.insert (
1964+ {Bfloat16DeviceLibVersion, std::move (DynBfloat16DeviceLibImg)});
19071965 continue ;
19081966 }
19091967 }
@@ -2824,12 +2882,8 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28242882 for (const device_image_plain &Img : Imgs) {
28252883 const std::shared_ptr<device_image_impl> &DeviceImageImpl =
28262884 getSyclObjImpl (Img);
2827- auto BinImgRef = DeviceImageImpl->get_bin_image_ref ();
2828- // For bfloat16 deice library image, no kernels, spec const are included,
2829- // so we just skip merging data.
2830- if (BinImgRef && BinImgRef->getDeviceLibMetadata ().isAvailable ())
2831- continue ;
28322885 // Duplicates are not expected here, otherwise urProgramLink should fail
2886+ if (DeviceImageImpl->get_kernel_ids_ptr ())
28332887 KernelIDs.insert (KernelIDs.end (),
28342888 DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
28352889 DeviceImageImpl->get_kernel_ids_ptr ()->end ());
0 commit comments