Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 61 additions & 71 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <detail/compiler.hpp>
#include <detail/config.hpp>
#include <detail/context_impl.hpp>
Expand Down Expand Up @@ -1837,84 +1836,27 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const {
return {};
}

static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) {
// For bfloat16 device library image, we should keep it. However, in some
// scenario, __sycl_register_lib can be called multiple times and the same
// bfloat16 device library image may be handled multiple times which is not
// needed. 2 static bool variables are created to record whether native or
// fallback bfloat16 device library image has been handled, if yes, we just
// need to skip it.
// We cannot prevent redundant loads of device library images if they are part
// of a runtime-compiled device binary, as these will be freed when the
// corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely
// on the presence of RTC device library images.
static bool shouldSkipEmptyImage(sycl_device_binary RawImg) {
// For bfloat16 device library image, we should keep it although it doesn't
// include any kernel.
sycl_device_binary_property_set ImgPS;
static bool IsNativeBF16DeviceLibHandled = false;
static bool IsFallbackBF16DeviceLibHandled = false;
for (ImgPS = RawImg->PropertySetsBegin; ImgPS != RawImg->PropertySetsEnd;
++ImgPS) {
if (ImgPS->Name &&
!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name)) {
sycl_device_binary_property ImgP;
for (ImgP = ImgPS->PropertiesBegin; ImgP != ImgPS->PropertiesEnd;
++ImgP) {
if (ImgP->Name && !strcmp("bfloat16", ImgP->Name) &&
(ImgP->Type == SYCL_PROPERTY_TYPE_UINT32))
break;
}
if (ImgP == ImgPS->PropertiesEnd)
return true;

// A valid bfloat16 device library image is found here.
// If it originated from RTC, we cannot skip it, but do not mark it as
// being present.
if (IsRTC)
return false;

// Otherwise, we need to check whether it has been handled already.
uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32();
if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) ||
((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled))
return true;

if (BF16NativeVal == 0)
IsFallbackBF16DeviceLibHandled = true;
else
IsNativeBF16DeviceLibHandled = true;

!strcmp(__SYCL_PROPERTY_SET_DEVICELIB_METADATA, ImgPS->Name))
return false;
}
}
return true;
}

static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) {
// Check whether the first device binary contains a legacy format offload
// entry with a `$` in its name.
if (DeviceBinary->NumDeviceBinaries > 0) {
sycl_device_binary Binary = DeviceBinary->DeviceBinaries;
if (Binary->EntriesBegin != Binary->EntriesEnd) {
sycl_offload_entry Entry = Binary->EntriesBegin;
if (!Entry->IsNewOffloadEntryType() &&
std::string_view{Entry->name}.find('$') != std::string_view::npos) {
return true;
}
}
}
return false;
return true;
}

void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile;
const bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// If the image does not contain kernels, skip it unless it is one of the
// bfloat16 device libraries, and it wasn't loaded before or resulted from
// runtime compilation.
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC))
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
continue;

std::unique_ptr<RTDeviceBinaryImage> Img;
Expand Down Expand Up @@ -1946,6 +1888,32 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
// Fill maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);

// For bfloat16 device library image, it doesn't include any kernel, device
// global, virtual function, so just skip adding it to any related maps.
// We only need to: 1). add exported symbols to m_ExportedSymbolImages. 2).
// add the device image to m_DeviceImages used for future clean up when
// removeImage is called. RefCount is used to keep how many user device
// images are depending on native/fallback bfloat16 device library image,
// the corresponding image will be added to m_ExportedSymbolImages and
// m_DeviceImages only when RefCount is 0. These RefCount are used when
// KernelIDsGuard is acquired by current thread.
{
auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata();
if (Bfloat16DeviceLibProp.isAvailable()) {
uint32_t IsNative =
DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32();
if (!m_Bfloat16DeviceLibRefCount[IsNative]) {
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
m_ExportedSymbolImages.insert({ESProp->Name, Img.get()});
}
m_DeviceImages.insert({RawImg, std::move(Img)});
}
m_Bfloat16DeviceLibRefCount[IsNative] += 1;
continue;
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Making the special handling explicit (and only doing the nececssary things) is a good idea 👍


// Register all exported symbols
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
Expand Down Expand Up @@ -2110,19 +2078,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {
}

void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
bool IsRTC = isCompiledAtRuntime(DeviceBinary);
for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) {
sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]);
auto DevImgIt = m_DeviceImages.find(RawImg);
if (DevImgIt == m_DeviceImages.end())
continue;
const sycl_offload_entry EntriesB = RawImg->EntriesBegin;
const sycl_offload_entry EntriesE = RawImg->EntriesEnd;
// Skip clean up if there are no offload entries, unless `DeviceBinary`
// resulted from runtime compilation: Then, this is one of the `bfloat16`
// device libraries, so we want to make sure that the image and its exported
// symbols are removed from the program manager's maps.
if (EntriesB == EntriesE && !IsRTC)
if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg))
continue;

RTDeviceBinaryImage *Img = DevImgIt->second.get();
Expand All @@ -2133,6 +2096,28 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) {
// Acquire lock to modify maps for kernel bundles
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);

{
// Clean up Bfloat16 device library image, unregister exported symbols
// and remove the device image only when RefCount is 0.
auto Bfloat16DeviceLibProp = Img->getDeviceLibMetadata();
if (Bfloat16DeviceLibProp.isAvailable()) {
uint32_t IsNative =
DeviceBinaryProperty(*(Bfloat16DeviceLibProp.begin())).asUint32();
if (m_Bfloat16DeviceLibRefCount[IsNative] != 0)
m_Bfloat16DeviceLibRefCount[IsNative] -= 1;
if (!m_Bfloat16DeviceLibRefCount[IsNative]) {
for (const sycl_device_binary_property &ESProp :
Img->getExportedSymbols()) {
m_ExportedSymbolImages.erase(ESProp->Name);
}

m_DeviceImages.erase(DevImgIt);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not erasing the device image here when the refcount > 0 won't keep the underlying device binaries alive. Consider the following situation:

  • bundle A loaded, contributes bfloat dev lib into runtime -> refcount = 1
  • bundle B loaded, uses bfloat dev lib -> refcount = 2
  • bundle A is freed -> refcount = 1
  • bundle C loaded, uses bfloat dev lib -> refcount = 2, and crash when PM tries to link the kernels because m_ExportedSymbols points to image coming from bundle A, which has been destroyed

}

continue;
}
}

// Unmap the unique kernel IDs for the offload entries
for (sycl_offload_entry EntriesIt = EntriesB; EntriesIt != EntriesE;
EntriesIt = EntriesIt->Increment()) {
Expand Down Expand Up @@ -2650,7 +2635,10 @@ ProgramManager::getSYCLDeviceImagesWithCompatibleState(
std::shared_ptr<std::vector<sycl::kernel_id>> DepKernelIDs;
{
std::lock_guard<std::mutex> KernelIDsGuard(m_KernelIDsMutex);
DepKernelIDs = m_BinImg2KernelIDs[Dep];
// For device library images, they are not in m_BinImg2KernelIDs since
// no kernel is included.
if (m_BinImg2KernelIDs.find(Dep) != m_BinImg2KernelIDs.end())
DepKernelIDs = m_BinImg2KernelIDs[Dep];
}

assert(ImgInfoPair.second.State == getBinImageState(Dep) &&
Expand Down Expand Up @@ -2862,6 +2850,8 @@ static void mergeImageData(const std::vector<device_image_plain> &Imgs,
device_image_impl::SpecConstMapT &NewSpecConstMap) {
for (const device_image_plain &Img : Imgs) {
std::shared_ptr<device_image_impl> DeviceImageImpl = getSyclObjImpl(Img);
if (!DeviceImageImpl->get_kernel_ids_ptr())
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other uses of device_image_impl::get_kernel_ids_ptr() which don't seem to expect it being uninitialised. Maybe it's safer to set up an empty vector in m_BinImg2KernelIDs as part of the device lib special handling?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, @jopperm
I filtered bf16 devicelib image in mergeData, since all devicelib never use spec const or include any kernels, we don't need to do the extra "merge data" work for them.

// Duplicates are not expected here, otherwise urProgramLink should fail
KernelIDs.insert(KernelIDs.end(),
DeviceImageImpl->get_kernel_ids_ptr()->begin(),
Expand Down
1 change: 1 addition & 0 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ class ProgramManager {
std::map<std::vector<unsigned char>, ur_kernel_handle_t>;
std::unordered_map<std::string, MaterializedEntries> m_MaterializedKernels;

size_t m_Bfloat16DeviceLibRefCount[2] = {0, 0};
friend class ::ProgramManagerTest;
};
} // namespace detail
Expand Down