-
Notifications
You must be signed in to change notification settings - Fork 796
[SYCL] Keep multiple copies for bf16 device library image #17461
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
4ee0932
57b5523
57d14fd
26f8895
34f396f
b09663d
fcbbe1c
c794cb3
db0281d
aa60bc5
1351401
640f679
24fa45a
472d1be
a33f24d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,7 +5,6 @@ | |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
| // | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| #include <detail/compiler.hpp> | ||
| #include <detail/config.hpp> | ||
| #include <detail/context_impl.hpp> | ||
|
|
@@ -1837,84 +1836,27 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const { | |
| return {}; | ||
| } | ||
|
|
||
| static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) { | ||
| // For bfloat16 device library image, we should keep it. However, in some | ||
| // scenario, __sycl_register_lib can be called multiple times and the same | ||
| // bfloat16 device library image may be handled multiple times which is not | ||
| // needed. 2 static bool variables are created to record whether native or | ||
| // fallback bfloat16 device library image has been handled, if yes, we just | ||
| // need to skip it. | ||
| // We cannot prevent redundant loads of device library images if they are part | ||
| // of a runtime-compiled device binary, as these will be freed when the | ||
| // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely | ||
| // on the presence of RTC device library images. | ||
| static bool shouldSkipEmptyImage(sycl_device_binary RawImg) { | ||
| // For bfloat16 device library image, we should keep it although it doesn't | ||
| // include any kernel. | ||
| sycl_device_binary_property_set ImgPS; | ||
| static bool IsNativeBF16DeviceLibHandled = false; | ||
| static bool IsFallbackBF16DeviceLibHandled = false; | ||
| for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd; | ||
| ++ImgPS) { | ||
| if (ImgPS->Name && | ||
| !strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) { | ||
| sycl_device_binary_property ImgP; | ||
| for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd; | ||
| ++ImgP) { | ||
| if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) && | ||
| (ImgP->Type == SYCL_PROPERTY_TYPE_UINT32)) | ||
| break; | ||
| } | ||
| if (ImgP == ImgPS->PropertiesEnd) | ||
| return true; | ||
|
|
||
| // A valid bfloat16 device library image is found here. | ||
| // If it originated from RTC, we cannot skip it, but do not mark it as | ||
| // being present. | ||
| if (IsRTC) | ||
| return false; | ||
|
|
||
| // Otherwise, we need to check whether it has been handled already. | ||
| uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32(); | ||
| if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) || | ||
| ((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled)) | ||
| return true; | ||
|
|
||
| if (BF16NativeVal == 0) | ||
| IsFallbackBF16DeviceLibHandled = true; | ||
| else | ||
| IsNativeBF16DeviceLibHandled = true; | ||
|
|
||
| !strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) { | ||
| // Check whether the first device binary contains a legacy format offload | ||
| // entry with a `$` in its name. | ||
| if (DeviceBinary->NumDeviceBinaries > 0) { | ||
| sycl_device_binary Binary = DeviceBinary->DeviceBinaries; | ||
| if (Binary->EntriesBegin != Binary->EntriesEnd) { | ||
| sycl_offload_entry Entry = Binary->EntriesBegin; | ||
| if (!Entry->IsNewOffloadEntryType() && | ||
| std::string_view{Entry->name}.find('$') != std::string_view::npos) { | ||
| return true; | ||
| } | ||
| } | ||
| } | ||
| return false; | ||
| return true; | ||
| } | ||
|
|
||
| void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { | ||
| const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile; | ||
| const bool IsRTC = isCompiledAtRuntime(DeviceBinary); | ||
| for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) { | ||
| sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]); | ||
| const sycl_offload_entry EntriesB = RawImg->EntriesBegin; | ||
| const sycl_offload_entry EntriesE = RawImg->EntriesEnd; | ||
| // If the image does not contain kernels, skip it unless it is one of the | ||
| // bfloat16 device libraries, and it wasn't loaded before or resulted from | ||
| // runtime compilation. | ||
| if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC)) | ||
| if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg)) | ||
| continue; | ||
|
|
||
| std::unique_ptr<RTDeviceBinaryImage> Img; | ||
|
|
@@ -1946,6 +1888,32 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { | |
| // Fill maps for kernel bundles | ||
| std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex); | ||
|
|
||
| // For bfloat16 device library image, it doesn't include any kernel, device | ||
| // global, virtual function, so just skip adding it to any related maps. | ||
| // We only need to: 1). add exported symbols to m_ExportedSymbolImages. 2). | ||
| // add the device image to m_DeviceImages used for future clean up when | ||
| // removeImage is called. RefCount is used to keep how many user device | ||
| // images are depending on native/fallback bfloat16 device library image, | ||
| // the corresponding image will be added to m_ExportedSymbolImages and | ||
| // m_DeviceImages only when RefCount is 0. These RefCount are used when | ||
| // KernelIDsGuard is acquired by current thread. | ||
| { | ||
| auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata(); | ||
| if (Bfloat16DeviceLibProp.isAvailable()) { | ||
| uint32_t IsNative = | ||
| DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32(); | ||
| if (!m_Bfloat16DeviceLibRefCount[IsNative]) { | ||
| for (const sycl_device_binary_property &ESProp : | ||
| Img->getExportedSymbols()) { | ||
| m_ExportedSymbolImages.insert({ESProp->Name, Img.get()}); | ||
| } | ||
| m_DeviceImages.insert({RawImg, std::move(Img)}); | ||
| } | ||
| m_Bfloat16DeviceLibRefCount[IsNative] += 1; | ||
| continue; | ||
| } | ||
| } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Making the special handling explicit (and only doing the necessary things) is a good idea 👍 |
||
|
|
||
| // Register all exported symbols | ||
| for (const sycl_device_binary_property &ESProp : | ||
| Img->getExportedSymbols()) { | ||
|
|
@@ -2110,19 +2078,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { | |
| } | ||
|
|
||
| void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) { | ||
| bool IsRTC = isCompiledAtRuntime(DeviceBinary); | ||
| for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) { | ||
| sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]); | ||
| auto DevImgIt = m_DeviceImages.find(RawImg); | ||
| if (DevImgIt == m_DeviceImages.end()) | ||
| continue; | ||
| const sycl_offload_entry EntriesB = RawImg->EntriesBegin; | ||
| const sycl_offload_entry EntriesE = RawImg->EntriesEnd; | ||
| // Skip clean up if there are no offload entries, unless `DeviceBinary` | ||
| // resulted from runtime compilation: Then, this is one of the `bfloat16` | ||
| // device libraries, so we want to make sure that the image and its exported | ||
| // symbols are removed from the program manager's maps. | ||
| if (EntriesB == EntriesE && !IsRTC) | ||
| if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg)) | ||
| continue; | ||
|
|
||
| RTDeviceBinaryImage *Img = DevImgIt->second.get(); | ||
|
|
@@ -2133,6 +2096,28 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) { | |
| // Acquire lock to modify maps for kernel bundles | ||
| std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex); | ||
|
|
||
| { | ||
| // Clean up Bfloat16 device library image, unregister exported symbols | ||
| // and remove the device image only when RefCount is 0. | ||
| auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata(); | ||
| if (Bfloat16DeviceLibProp.isAvailable()) { | ||
| uint32_t IsNative = | ||
| DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32(); | ||
| if (m_Bfloat16DeviceLibRefCount[IsNative] != 0) | ||
| m_Bfloat16DeviceLibRefCount[IsNative] -= 1; | ||
| if (!m_Bfloat16DeviceLibRefCount[IsNative]) { | ||
| for (const sycl_device_binary_property &ESProp : | ||
| Img->getExportedSymbols()) { | ||
| m_ExportedSymbolImages.erase(ESProp->Name); | ||
| } | ||
|
|
||
| m_DeviceImages.erase(DevImgIt); | ||
|
||
| } | ||
|
|
||
| continue; | ||
| } | ||
| } | ||
|
|
||
| // Unmap the unique kernel IDs for the offload entries | ||
| for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE; | ||
| EntriesIt = EntriesIt->Increment()) { | ||
|
|
@@ -2650,7 +2635,10 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState( | |
| std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs; | ||
| { | ||
| std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex); | ||
| DepKernelIDs = m_BinImg2KernelIDs[Dep]; | ||
| // For device library images, they are not in m_BinImg2KernelIDs since | ||
| // no kernel is included. | ||
| if (m_BinImg2KernelIDs.find(Dep) != m_BinImg2KernelIDs.end()) | ||
| DepKernelIDs = m_BinImg2KernelIDs[Dep]; | ||
jinge90 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| assert(ImgInfoPair.second.State == getBinImageState(Dep) && | ||
|
|
@@ -2862,6 +2850,8 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs, | |
| device_image_impl::SpecConstMapT &NewSpecConstMap) { | ||
| for (const device_image_plain &Img : Imgs) { | ||
| std::shared_ptr<device_image_impl> DeviceImageImpl = getSyclObjImpl(Img); | ||
| if (!DeviceImageImpl->get_kernel_ids_ptr()) | ||
| continue; | ||
|
||
| // Duplicates are not expected here, otherwise urProgramLink should fail | ||
| KernelIDs.insert(KernelIDs.end(), | ||
| DeviceImageImpl->get_kernel_ids_ptr()->begin(), | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.