diff --git a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp index a47e2800ec668..23bb4affd8f92 100644 --- a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp +++ b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp @@ -531,6 +531,19 @@ static bool getDeviceLibraries(const ArgList &Args, return FoundUnknownLib; } +static Expected> +loadBitcodeLibrary(StringRef LibPath, LLVMContext &Context) { + SMDiagnostic Diag; + std::unique_ptr Lib = parseIRFile(LibPath, Diag, Context); + if (!Lib) { + std::string DiagMsg; + raw_string_ostream SOS(DiagMsg); + Diag.print(/*ProgName=*/nullptr, SOS); + return createStringError(DiagMsg); + } + return std::move(Lib); +} + Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, const InputArgList &UserArgList, std::string &BuildLog) { @@ -558,16 +571,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, for (const std::string &LibName : LibNames) { std::string LibPath = DPCPPRoot + "/lib/" + LibName; - SMDiagnostic Diag; - std::unique_ptr Lib = parseIRFile(LibPath, Diag, Context); - if (!Lib) { - std::string DiagMsg; - raw_string_ostream SOS(DiagMsg); - Diag.print(/*ProgName=*/nullptr, SOS); - return createStringError(DiagMsg); + auto LibOrErr = loadBitcodeLibrary(LibPath, Context); + if (!LibOrErr) { + return LibOrErr.takeError(); } - if (Linker::linkModules(Module, std::move(Lib), Linker::LinkOnlyNeeded)) { + if (Linker::linkModules(Module, std::move(*LibOrErr), + Linker::LinkOnlyNeeded)) { return createStringError("Unable to link device library %s: %s", LibPath.c_str(), BuildLog.c_str()); } @@ -607,6 +617,31 @@ static IRSplitMode getDeviceCodeSplitMode(const InputArgList &UserArgList) { return SPLIT_AUTO; } +static void encodeProperties(PropertySetRegistry &Properties, + RTCDevImgInfo &DevImgInfo) { + const auto &PropertySets = Properties.getPropSets(); + + DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()}; + for (auto [KV, FrozenPropSet] : + zip_equal(PropertySets, DevImgInfo.Properties)) { + const auto &PropertySetName = KV.first; + const auto &PropertySet = KV.second; + FrozenPropSet = + FrozenPropertySet{PropertySetName.str(), PropertySet.size()}; + for (auto [KV2, FrozenProp] : + zip_equal(PropertySet, FrozenPropSet.Values)) { + const auto &PropertyName = KV2.first; + const auto &PropertyValue = KV2.second; + FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32 + ? FrozenPropertyValue{PropertyName.str(), + PropertyValue.asUint32()} + : FrozenPropertyValue{ + PropertyName.str(), PropertyValue.asRawByteArray(), + PropertyValue.getRawByteArraySize()}; + } + }; +} + Expected jit_compiler::performPostLink(std::unique_ptr Module, const InputArgList &UserArgList) { @@ -637,9 +672,9 @@ jit_compiler::performPostLink(std::unique_ptr Module, // Otherwise: Port over the `removeSYCLKernelsConstRefArray` and // `removeDeviceGlobalFromCompilerUsed` methods. - assert(!isModuleUsingAsan(*Module)); - // Otherwise: Need to instrument each image scope device globals if the module - // has been instrumented by sanitizer pass. + assert(!(isModuleUsingAsan(*Module) || isModuleUsingMsan(*Module) || + isModuleUsingTsan(*Module))); + // Otherwise: Run `SanitizerKernelMetadataPass`. // Transform Joint Matrix builtin calls to align them with SPIR-V friendly // LLVM IR specification. @@ -668,6 +703,7 @@ jit_compiler::performPostLink(std::unique_ptr Module, // `-fno-sycl-device-code-split-esimd` as a prerequisite for compiling // `invoke_simd` code. + bool IsBF16DeviceLibUsed = false; while (Splitter->hasMoreSplits()) { ModuleDesc MDesc = Splitter->nextSplit(); @@ -701,35 +737,58 @@ jit_compiler::performPostLink(std::unique_ptr Module, /*DeviceGlobals=*/false}; PropertySetRegistry Properties = computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq); + + // When the split mode is none, the required work group size will be added + // to the whole module, which will make the runtime unable to launch the + // other kernels in the module that have different required work group + // sizes or no required work group sizes. So we need to remove the + // required work group size metadata in this case. + if (SplitMode == module_split::SPLIT_NONE) { + Properties.remove(PropSetRegTy::SYCL_DEVICE_REQUIREMENTS, + PropSetRegTy::PROPERTY_REQD_WORK_GROUP_SIZE); + } + // TODO: Manually add `compile_target` property as in // `saveModuleProperties`? - const auto &PropertySets = Properties.getPropSets(); - - DevImgInfo.Properties = FrozenPropertyRegistry{PropertySets.size()}; - for (auto [KV, FrozenPropSet] : - zip_equal(PropertySets, DevImgInfo.Properties)) { - const auto &PropertySetName = KV.first; - const auto &PropertySet = KV.second; - FrozenPropSet = - FrozenPropertySet{PropertySetName.str(), PropertySet.size()}; - for (auto [KV2, FrozenProp] : - zip_equal(PropertySet, FrozenPropSet.Values)) { - const auto &PropertyName = KV2.first; - const auto &PropertyValue = KV2.second; - FrozenProp = - PropertyValue.getType() == PropertyValue::Type::UINT32 - ? FrozenPropertyValue{PropertyName.str(), - PropertyValue.asUint32()} - : FrozenPropertyValue{PropertyName.str(), - PropertyValue.asRawByteArray(), - PropertyValue.getRawByteArraySize()}; - } - }; + encodeProperties(Properties, DevImgInfo); + + IsBF16DeviceLibUsed |= isSYCLDeviceLibBF16Used(MDesc.getModule()); Modules.push_back(MDesc.releaseModulePtr()); } } + if (IsBF16DeviceLibUsed) { + const std::string &DPCPPRoot = getDPCPPRoot(); + if (DPCPPRoot == InvalidDPCPPRoot) { + return createStringError("Could not locate DPCPP root directory"); + } + + auto &Ctx = Modules.front()->getContext(); + auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error { + std::string LibPath = DPCPPRoot + "/lib/" + LibName; + auto LibOrErr = loadBitcodeLibrary(LibPath, Ctx); + if (!LibOrErr) { + return LibOrErr.takeError(); + } + + std::unique_ptr LibModule = std::move(*LibOrErr); + PropertySetRegistry Properties = + computeDeviceLibProperties(*LibModule, LibName); + encodeProperties(Properties, DevImgInfoVec.emplace_back()); + Modules.push_back(std::move(LibModule)); + + return Error::success(); + }; + + if (auto Err = WrapLibraryInDevImg("libsycl-fallback-bfloat16.bc")) { + return std::move(Err); + } + if (auto Err = WrapLibraryInDevImg("libsycl-native-bfloat16.bc")) { + return std::move(Err); + } + } + assert(DevImgInfoVec.size() == Modules.size()); RTCBundleInfo BundleInfo; BundleInfo.DevImgInfos = DynArray{DevImgInfoVec.size()}; diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 96b7f3bca3f53..c6e16e96650ff 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -1837,13 +1837,17 @@ ProgramManager::kernelImplicitLocalArgPos(const std::string &KernelName) const { return {}; } -static bool shouldSkipEmptyImage(sycl_device_binary RawImg) { +static bool shouldSkipEmptyImage(sycl_device_binary RawImg, bool IsRTC) { // For bfloat16 device library image, we should keep it. However, in some // scenario, __sycl_register_lib can be called multiple times and the same // bfloat16 device library image may be handled multiple times which is not // needed. 2 static bool variables are created to record whether native or // fallback bfloat16 device library image has been handled, if yes, we just // need to skip it. + // We cannot prevent redundant loads of device library images if they are part + // of a runtime-compiled device binary, as these will be freed when the + // corresponding kernel bundle is destroyed. Hence, normal kernels cannot rely + // on the presence of RTC device library images. sycl_device_binary_property_set ImgPS; static bool IsNativeBF16DeviceLibHandled = false; static bool IsFallbackBF16DeviceLibHandled = false; @@ -1861,8 +1865,13 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) { if (ImgP == ImgPS->PropertiesEnd) return true; - // A valid bfloat16 device library image is found here, need to check - // wheter it has been handled already. + // A valid bfloat16 device library image is found here. + // If it originated from RTC, we cannot skip it, but do not mark it as + // being present. + if (IsRTC) + return false; + + // Otherwise, we need to check whether it has been handled already. uint32_t BF16NativeVal = DeviceBinaryProperty(ImgP).asUint32(); if (((BF16NativeVal == 0) && IsFallbackBF16DeviceLibHandled) || ((BF16NativeVal == 1) && IsNativeBF16DeviceLibHandled)) @@ -1879,14 +1888,33 @@ static bool shouldSkipEmptyImage(sycl_device_binary RawImg) { return true; } +static bool isCompiledAtRuntime(sycl_device_binaries DeviceBinary) { + // Check whether the first device binary contains a legacy format offload + // entry with a `$` in its name. + if (DeviceBinary->NumDeviceBinaries > 0) { + sycl_device_binary Binary = DeviceBinary->DeviceBinaries; + if (Binary->EntriesBegin != Binary->EntriesEnd) { + sycl_offload_entry Entry = Binary->EntriesBegin; + if (!Entry->IsNewOffloadEntryType() && + std::string_view{Entry->name}.find('$') != std::string_view::npos) { + return true; + } + } + } + return false; +} + void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { const bool DumpImages = std::getenv("SYCL_DUMP_IMAGES") && !m_UseSpvFile; + const bool IsRTC = isCompiledAtRuntime(DeviceBinary); for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) { sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]); const sycl_offload_entry EntriesB = RawImg->EntriesBegin; const sycl_offload_entry EntriesE = RawImg->EntriesEnd; - // Treat the image as empty one - if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg)) + // If the image does not contain kernels, skip it unless it is one of the + // bfloat16 device libraries, and it wasn't loaded before or resulted from + // runtime compilation. + if ((EntriesB == EntriesE) && shouldSkipEmptyImage(RawImg, IsRTC)) continue; std::unique_ptr Img; @@ -2081,6 +2109,7 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { } void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) { + bool IsRTC = isCompiledAtRuntime(DeviceBinary); for (int I = 0; I < DeviceBinary->NumDeviceBinaries; I++) { sycl_device_binary RawImg = &(DeviceBinary->DeviceBinaries[I]); auto DevImgIt = m_DeviceImages.find(RawImg); @@ -2088,8 +2117,11 @@ void ProgramManager::removeImages(sycl_device_binaries DeviceBinary) { continue; const sycl_offload_entry EntriesB = RawImg->EntriesBegin; const sycl_offload_entry EntriesE = RawImg->EntriesEnd; - // Treat the image as empty one - if (EntriesB == EntriesE) + // Skip clean up if there are no offload entries, unless `DeviceBinary` + // resulted from runtime compilation: Then, this is one of the `bfloat16` + // device libraries, so we want to make sure that the image and its exported + // symbols are removed from the program manager's maps. + if (EntriesB == EntriesE && !IsRTC) continue; RTDeviceBinaryImage *Img = DevImgIt->second.get(); diff --git a/sycl/test-e2e/KernelCompiler/sycl.cpp b/sycl/test-e2e/KernelCompiler/sycl.cpp index b8ba0ffb65fac..c3d2d43eba5af 100644 --- a/sycl/test-e2e/KernelCompiler/sycl.cpp +++ b/sycl/test-e2e/KernelCompiler/sycl.cpp @@ -134,6 +134,12 @@ void device_libs_kernel(float *ptr) { // cl_intel_devicelib_imf ptr[3] = sycl::ext::intel::math::sqrt(ptr[3] * 2); + + // cl_intel_devicelib_imf_bf16 + ptr[4] = sycl::ext::intel::math::float2bfloat16(ptr[4] * 0.5f); + + // cl_intel_devicelib_bfloat16 + ptr[5] = sycl::ext::oneapi::bfloat16{ptr[5] / 0.25f}; } )==="; @@ -435,7 +441,7 @@ int test_device_libraries() { exe_kb kbExe = syclex::build(kbSrc); sycl::kernel k = kbExe.ext_oneapi_get_kernel("device_libs_kernel"); - constexpr size_t nElem = 4; + constexpr size_t nElem = 6; float *ptr = sycl::malloc_shared(nElem, q); for (int i = 0; i < nElem; ++i) ptr[i] = 1.0f; @@ -446,8 +452,8 @@ int test_device_libraries() { }); q.wait_and_throw(); - // Check that the kernel was executed. Given the {1.0, 1.0, 1.0, 1.0} input, - // the expected result is approximately {0.84, 1.41, 0.0, 1.41}. + // Check that the kernel was executed. Given the {1.0, ..., 1.0} input, + // the expected result is approximately {0.84, 1.41, 0.0, 1.41, 0.5, 4.0}. for (unsigned i = 0; i < nElem; ++i) { std::cout << ptr[i] << ' '; assert(ptr[i] != 1.0f);