@@ -614,22 +614,25 @@ static bool compatibleWithDevice(RTDeviceBinaryImage *BinImage,
614614}
615615
616616// Quick check to see whether BinImage is a compiler-generated device image.
617- static bool isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
617+ bool ProgramManager:: isSpecialDeviceImage (RTDeviceBinaryImage *BinImage) {
618618 // SYCL devicelib image.
619- if (BinImage->getDeviceLibMetadata ().isAvailable ())
619+ if ((m_Bfloat16DeviceLibImages[0 ].get () == BinImage) ||
620+ m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
620621 return true ;
621622
622623 return false ;
623624}
624625
625- static bool isSpecialDeviceImageShouldBeUsed (RTDeviceBinaryImage *BinImage,
626- const device &Dev) {
626+ bool ProgramManager:: isSpecialDeviceImageShouldBeUsed (
627+ RTDeviceBinaryImage *BinImage, const device &Dev) {
627628 // Decide whether a devicelib image should be used.
628- if (BinImage->getDeviceLibMetadata ().isAvailable ()) {
629- const RTDeviceBinaryImage::PropertyRange &DeviceLibMetaProp =
630- BinImage->getDeviceLibMetadata ();
631- uint32_t DeviceLibMeta =
632- DeviceBinaryProperty (*(DeviceLibMetaProp.begin ())).asUint32 ();
629+ int Bfloat16DeviceLibVersion = -1 ;
630+ if (m_Bfloat16DeviceLibImages[0 ].get () == BinImage)
631+ Bfloat16DeviceLibVersion = 0 ;
632+ else if (m_Bfloat16DeviceLibImages[1 ].get () == BinImage)
633+ Bfloat16DeviceLibVersion = 1 ;
634+
635+ if (Bfloat16DeviceLibVersion != -1 ) {
633636 // Currently, only bfloat16 conversion devicelibs are supported, so the
634637 // image's slot index only distinguishes the fallback from the native version.
635638 // For bfloat16 conversion devicelib, we have fallback and native version.
@@ -643,7 +646,8 @@ static bool isSpecialDeviceImageShouldBeUsed(RTDeviceBinaryImage *BinImage,
643646 detail::getSyclObjImpl (Dev);
644647 std::string NativeBF16ExtName = " cl_intel_bfloat16_conversions" ;
645648 bool NativeBF16Supported = (DeviceImpl->has_extension (NativeBF16ExtName));
646- return NativeBF16Supported == (DeviceLibMeta == DEVICELIB_NATIVE);
649+ return NativeBF16Supported ==
650+ (Bfloat16DeviceLibVersion == DEVICELIB_NATIVE);
647651 }
648652
649653 return false ;
@@ -1837,87 +1841,69 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
18371841 return {};
18381842}
18391843
1840- static bool shouldSkipEmptyImage (sycl_device_binary RawImg, bool IsRTC) {
1841- // For bfloat16 device library image, we should keep it. However, in some
1842- // scenario, __sycl_register_lib can be called multiple times and the same
1843- // bfloat16 device library image may be handled multiple times which is not
1844- // needed. 2 static bool variables are created to record whether native or
1845- // fallback bfloat16 device library image has been handled, if yes, we just
1846- // need to skip it.
1847- // We cannot prevent redundant loads of device library images if they are part
1848- // of a runtime-compiled device binary, as these will be freed when the
1849- // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
1850- // on the presence of RTC device library images.
1844+ static bool isBfloat16DeviceLibImage (sycl_device_binary RawImg,
1845+ uint32_t *LibVersion = nullptr ) {
18511846 sycl_device_binary_property_set ImgPS;
1852- static bool IsNativeBF16DeviceLibHandled = false ;
1853- static bool IsFallbackBF16DeviceLibHandled = false ;
18541847 for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
18551848 ++ImgPS) {
18561849 if (ImgPS->Name &&
18571850 !strcmp (__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name )) {
1851+ if (!LibVersion)
1852+ return true ;
1853+
1854+ // Valid version for bfloat16 device library is 0(fallback), 1(native).
1855+ *LibVersion = 2 ;
18581856 sycl_device_binary_property ImgP;
18591857 for (ImgP = ImgPS->PropertiesBegin ; ImgP != ImgPS->PropertiesEnd ;
18601858 ++ImgP) {
18611859 if (ImgP->Name && !strcmp (" bfloat16" , ImgP->Name ) &&
18621860 (ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
18631861 break ;
18641862 }
1865- if (ImgP == ImgPS->PropertiesEnd )
1866- return true ;
1867-
1868- // A valid bfloat16 device library image is found here.
1869- // If it originated from RTC, we cannot skip it, but do not mark it as
1870- // being present.
1871- if (IsRTC)
1872- return false ;
1873-
1874- // Otherwise, we need to check whether it has been handled already.
1875- uint32_t BF16NativeVal = DeviceBinaryProperty (ImgP).asUint32 ();
1876- if (((BF16NativeVal == 0 ) && IsFallbackBF16DeviceLibHandled) ||
1877- ((BF16NativeVal == 1 ) && IsNativeBF16DeviceLibHandled))
1878- return true ;
1879-
1880- if (BF16NativeVal == 0 )
1881- IsFallbackBF16DeviceLibHandled = true ;
1882- else
1883- IsNativeBF16DeviceLibHandled = true ;
1884-
1885- return false ;
1863+ if (ImgP != ImgPS->PropertiesEnd )
1864+ *LibVersion = DeviceBinaryProperty (ImgP).asUint32 ();
1865+ return true ;
18861866 }
18871867 }
1888- return true ;
1868+
1869+ return false ;
18891870}
18901871
1891- static bool isCompiledAtRuntime (sycl_device_binaries DeviceBinary) {
1892- // Check whether the first device binary contains a legacy format offload
1893- // entry with a `$` in its name.
1894- if (DeviceBinary->NumDeviceBinaries > 0 ) {
1895- sycl_device_binary Binary = DeviceBinary->DeviceBinaries ;
1896- if (Binary->EntriesBegin != Binary->EntriesEnd ) {
1897- sycl_offload_entry Entry = Binary->EntriesBegin ;
1898- if (!Entry->IsNewOffloadEntryType () &&
1899- std::string_view{Entry->name }.find (' $' ) != std::string_view::npos) {
1900- return true ;
1901- }
1902- }
1872+ static sycl_device_binary_property_set
1873+ getExportedSymbolPS (sycl_device_binary RawImg) {
1874+ sycl_device_binary_property_set ImgPS;
1875+ for (ImgPS = RawImg->PropertySetsBegin ; ImgPS != RawImg->PropertySetsEnd ;
1876+ ++ImgPS) {
1877+ if (ImgPS->Name &&
1878+ !strcmp (__SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS, ImgPS->Name ))
1879+ return ImgPS;
19031880 }
1904- return false ;
1881+
1882+ return nullptr ;
1883+ }
1884+
1885+ static bool shouldSkipEmptyImage (sycl_device_binary RawImg) {
1886+ // For bfloat16 device library image, we should keep it although it doesn't
1887+ // include any kernel.
1888+ if (isBfloat16DeviceLibImage (RawImg))
1889+ return false ;
1890+
1891+ // We may extend the logic here other than bfloat16 device library image.
1892+ return true ;
19051893}
19061894
19071895void ProgramManager::addImages (sycl_device_binaries DeviceBinary) {
19081896 const bool DumpImages = std::getenv (" SYCL_DUMP_IMAGES" ) && !m_UseSpvFile;
1909- const bool IsRTC = isCompiledAtRuntime (DeviceBinary);
19101897 for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
19111898 sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
19121899 const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
19131900 const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
1914- // If the image does not contain kernels, skip it unless it is one of the
1915- // bfloat16 device libraries, and it wasn't loaded before or resulted from
1916- // runtime compilation.
1917- if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg, IsRTC))
1901+ if ((EntriesB == EntriesE) && shouldSkipEmptyImage (RawImg))
19181902 continue ;
19191903
19201904 std::unique_ptr<RTDeviceBinaryImage> Img;
1905+ bool IsBfloat16DeviceLib = false ;
1906+ uint32_t Bfloat16DeviceLibVersion = 0 ;
19211907 if (isDeviceImageCompressed (RawImg))
19221908#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
19231909 Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
@@ -1927,25 +1913,63 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
19271913 " SYCL RT was built without ZSTD support."
19281914 " Aborting. " );
19291915#endif
1930- else
1931- Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1916+ else {
1917+ IsBfloat16DeviceLib =
1918+ isBfloat16DeviceLibImage (RawImg, &Bfloat16DeviceLibVersion);
1919+ if (!IsBfloat16DeviceLib)
1920+ Img = std::make_unique<RTDeviceBinaryImage>(RawImg);
1921+ }
19321922
19331923 static uint32_t SequenceID = 0 ;
19341924
1935- // Fill the kernel argument mask map
1936- const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1937- Img->getKernelParamOptInfo ();
1938- if (KPOIRange.isAvailable ()) {
1939- KernelNameToArgMaskMap &ArgMaskMap =
1940- m_EliminatedKernelArgMasks[Img.get ()];
1941- for (const auto &Info : KPOIRange)
1942- ArgMaskMap[Info->Name ] =
1943- createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1925+ // Fill the kernel argument mask map, no need to do this for bfloat16
1926+ // device library image since it doesn't include any kernel.
1927+ if (!IsBfloat16DeviceLib) {
1928+ const RTDeviceBinaryImage::PropertyRange &KPOIRange =
1929+ Img->getKernelParamOptInfo ();
1930+ if (KPOIRange.isAvailable ()) {
1931+ KernelNameToArgMaskMap &ArgMaskMap =
1932+ m_EliminatedKernelArgMasks[Img.get ()];
1933+ for (const auto &Info : KPOIRange)
1934+ ArgMaskMap[Info->Name ] =
1935+ createKernelArgMask (DeviceBinaryProperty (Info).asByteArray ());
1936+ }
19441937 }
19451938
19461939 // Fill maps for kernel bundles
19471940 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
19481941
1942+ // For bfloat16 device library image, it doesn't include any kernel, device
1943+ // global, virtual function, so just skip adding it to any related maps.
1944+ // The bfloat16 device library are provided by compiler and may be used by
1945+ // different sycl device images, program manager will own single copy for
1946+ // native and fallback version bfloat16 device library, these device
1947+ // library images will not be erased unless program manager is destroyed.
1948+ {
1949+ if (IsBfloat16DeviceLib) {
1950+ assert ((Bfloat16DeviceLibVersion < 2 ) &&
1951+ " Invalid Bfloat16 Device Library Index." );
1952+ if (m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion].get ())
1953+ continue ;
1954+ size_t ImgSize =
1955+ static_cast <size_t >(RawImg->BinaryEnd - RawImg->BinaryStart );
1956+ std::unique_ptr<char []> Data (new char [ImgSize]);
1957+ std::memcpy (Data.get (), RawImg->BinaryStart , ImgSize);
1958+ auto DynBfloat16DeviceLibImg =
1959+ std::make_unique<DynRTDeviceBinaryImage>(std::move (Data), ImgSize);
1960+ auto ESPropSet = getExportedSymbolPS (RawImg);
1961+ sycl_device_binary_property ESProp;
1962+ for (ESProp = ESPropSet->PropertiesBegin ;
1963+ ESProp != ESPropSet->PropertiesEnd ; ++ESProp) {
1964+ m_ExportedSymbolImages.insert (
1965+ {ESProp->Name , DynBfloat16DeviceLibImg.get ()});
1966+ }
1967+ m_Bfloat16DeviceLibImages[Bfloat16DeviceLibVersion] =
1968+ std::move (DynBfloat16DeviceLibImg);
1969+ continue ;
1970+ }
1971+ }
1972+
19491973 // Register all exported symbols
19501974 for (const sycl_device_binary_property &ESProp :
19511975 Img->getExportedSymbols ()) {
@@ -2110,19 +2134,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
21102134}
21112135
21122136void ProgramManager::removeImages (sycl_device_binaries DeviceBinary) {
2113- bool IsRTC = isCompiledAtRuntime (DeviceBinary);
21142137 for (int I = 0 ; I < DeviceBinary->NumDeviceBinaries ; I++) {
21152138 sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries [I]);
21162139 auto DevImgIt = m_DeviceImages.find (RawImg);
21172140 if (DevImgIt == m_DeviceImages.end ())
21182141 continue ;
21192142 const sycl_offload_entry EntriesB = RawImg->EntriesBegin ;
21202143 const sycl_offload_entry EntriesE = RawImg->EntriesEnd ;
2121- // Skip clean up if there are no offload entries, unless `DeviceBinary`
2122- // resulted from runtime compilation: Then, this is one of the `bfloat16`
2123- // device libraries, so we want to make sure that the image and its exported
2124- // symbols are removed from the program manager's maps.
2125- if (EntriesB == EntriesE && !IsRTC)
2144+ if (EntriesB == EntriesE)
21262145 continue ;
21272146
21282147 RTDeviceBinaryImage *Img = DevImgIt->second .get ();
@@ -2650,7 +2669,11 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
26502669 std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
26512670 {
26522671 std::lock_guard<std::mutex> KernelIDsGuard (m_KernelIDsMutex);
2653- DepKernelIDs = m_BinImg2KernelIDs[Dep];
2672+ // Device library images contain no kernels, so they have no entry in
2673+ // m_BinImg2KernelIDs; DepKernelIDs is left null for them.
2674+ auto DepIt = m_BinImg2KernelIDs.find (Dep);
2675+ if (DepIt != m_BinImg2KernelIDs.end ())
2676+ DepKernelIDs = DepIt->second ;
26542677 }
26552678
26562679 assert (ImgInfoPair.second .State == getBinImageState (Dep) &&
@@ -2863,9 +2886,10 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
28632886 for (const device_image_plain &Img : Imgs) {
28642887 std::shared_ptr<device_image_impl> DeviceImageImpl = getSyclObjImpl (Img);
28652888 // Duplicates are not expected here, otherwise urProgramLink should fail
2866- KernelIDs.insert (KernelIDs.end (),
2867- DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2868- DeviceImageImpl->get_kernel_ids_ptr ()->end ());
2889+ if (DeviceImageImpl->get_kernel_ids_ptr ())
2890+ KernelIDs.insert (KernelIDs.end (),
2891+ DeviceImageImpl->get_kernel_ids_ptr ()->begin (),
2892+ DeviceImageImpl->get_kernel_ids_ptr ()->end ());
28692893 // To be able to answer queries about specialization constants, the new
28702894 // device image should have the specialization constants from all the linked
28712895 // images.
0 commit comments