diff --git a/sycl-jit/jit-compiler/CMakeLists.txt b/sycl-jit/jit-compiler/CMakeLists.txt index b01964f493f10..31b90e104c6b7 100644 --- a/sycl-jit/jit-compiler/CMakeLists.txt +++ b/sycl-jit/jit-compiler/CMakeLists.txt @@ -9,12 +9,15 @@ else() set(SYCL_JIT_VIRTUAL_TOOLCHAIN_ROOT "/sycl-jit-toolchain/") endif() -# TODO: libdevice -set(SYCL_JIT_RESOURCE_DEPS sycl-headers clang ${CMAKE_CURRENT_SOURCE_DIR}/utils/generate.py) +set(SYCL_JIT_RESOURCE_DEPS + sycl-headers # include/sycl + clang # lib/clang/N/include + libsycldevice # lib/*.bc + ${CMAKE_CURRENT_SOURCE_DIR}/utils/generate.py) if ("libclc" IN_LIST LLVM_ENABLE_PROJECTS) # Somehow just "libclc" doesn't build "remangled-*" (and maybe whatever else). - list(APPEND SYCL_JIT_RESOURCE_DEPS libclc libspirv-builtins) + list(APPEND SYCL_JIT_RESOURCE_DEPS libclc libspirv-builtins) # lib/clc/*.bc endif() add_custom_command( diff --git a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp index 59eda6f548faa..42c6ca45288d3 100644 --- a/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp +++ b/sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp @@ -64,86 +64,6 @@ using namespace llvm::util; using namespace llvm::vfs; using namespace jit_compiler; -#ifdef _GNU_SOURCE -#include -static char X; // Dummy symbol, used as an anchor for `dlinfo` below. -#endif - -#ifdef _WIN32 -#include // For std::filesystem::path ( C++17 only ) -#include // For PathRemoveFileSpec -#include // For GetModuleFileName, HMODULE, DWORD, MAX_PATH - -// cribbed from sycl/source/detail/os_util.cpp -using OSModuleHandle = intptr_t; -static constexpr OSModuleHandle ExeModuleHandle = -1; -static OSModuleHandle getOSModuleHandle(const void *VirtAddr) { - HMODULE PhModule; - DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | - GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT; - auto LpModuleAddr = reinterpret_cast(VirtAddr); - if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) { - // Expect the caller to check for zero and take - // necessary action - return 0; - } - if (PhModule == GetModuleHandleA(nullptr)) - return ExeModuleHandle; - return reinterpret_cast(PhModule); -} - -// cribbed from sycl/source/detail/os_util.cpp -/// Returns an absolute path where the object was found. -std::wstring getCurrentDSODir() { - wchar_t Path[MAX_PATH]; - auto Handle = getOSModuleHandle(reinterpret_cast(&getCurrentDSODir)); - DWORD Ret = GetModuleFileName( - reinterpret_cast(ExeModuleHandle == Handle ? 0 : Handle), Path, - MAX_PATH); - assert(Ret < MAX_PATH && "Path is longer than MAX_PATH?"); - assert(Ret > 0 && "GetModuleFileName failed"); - (void)Ret; - - BOOL RetCode = PathRemoveFileSpec(Path); - assert(RetCode && "PathRemoveFileSpec failed"); - (void)RetCode; - - return Path; -} -#endif // _WIN32 - -static constexpr auto InvalidDPCPPRoot = ""; - -static const std::string &getDPCPPRoot() { - thread_local std::string DPCPPRoot; - - if (!DPCPPRoot.empty()) { - return DPCPPRoot; - } - DPCPPRoot = InvalidDPCPPRoot; - -#ifdef _GNU_SOURCE - static constexpr auto JITLibraryPathSuffix = "/lib/libsycl-jit.so"; - Dl_info Info; - if (dladdr(&X, &Info)) { - std::string LoadedLibraryPath = Info.dli_fname; - auto Pos = LoadedLibraryPath.rfind(JITLibraryPathSuffix); - if (Pos != std::string::npos) { - DPCPPRoot = LoadedLibraryPath.substr(0, Pos); - } - } -#endif // _GNU_SOURCE - -#ifdef _WIN32 - DPCPPRoot = std::filesystem::path(getCurrentDSODir()).parent_path().string(); -#endif // _WIN32 - - // TODO: Implemenent other means of determining the DPCPP root, e.g. - // evaluating the `CMPLR_ROOT` env. - - return DPCPPRoot; -} - namespace { class HashPreprocessedAction : public PreprocessorFrontendAction { @@ -252,6 +172,30 @@ class SYCLToolchain { return TI.run(); } + Expected loadBitcodeLibrary(StringRef LibPath, + LLVMContext &Context) { + auto FS = llvm::makeIntrusiveRefCnt( + llvm::vfs::getRealFileSystem()); + FS->pushOverlay(ToolchainFS); + + auto MemBuf = FS->getBufferForFile(LibPath, /*FileSize*/ -1, + /*RequiresNullTerminator*/ false); + if (!MemBuf) { + return createStringError("Error opening file %s: %s", LibPath.data(), + MemBuf.getError().message().c_str()); + } + + SMDiagnostic Diag; + ModuleUPtr Lib = parseIR(*MemBuf->get(), Diag, Context); + if (!Lib) { + std::string DiagMsg; + raw_string_ostream SOS(DiagMsg); + Diag.print(/*ProgName=*/nullptr, SOS); + return createStringError(DiagMsg); + } + return std::move(Lib); + } + std::string_view getClangXXExe() const { return ClangXXExe; } private: @@ -516,30 +460,12 @@ static bool getDeviceLibraries(const ArgList &Args, return FoundUnknownLib; } -static Expected loadBitcodeLibrary(StringRef LibPath, - LLVMContext &Context) { - SMDiagnostic Diag; - ModuleUPtr Lib = parseIRFile(LibPath, Diag, Context); - if (!Lib) { - std::string DiagMsg; - raw_string_ostream SOS(DiagMsg); - Diag.print(/*ProgName=*/nullptr, SOS); - return createStringError(DiagMsg); - } - return std::move(Lib); -} - Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, const InputArgList &UserArgList, std::string &BuildLog, BinaryFormat Format) { TimeTraceScope TTS{"linkDeviceLibraries"}; - const std::string &DPCPPRoot = getDPCPPRoot(); - if (DPCPPRoot == InvalidDPCPPRoot) { - return createStringError("Could not locate DPCPP root directory"); - } - IntrusiveRefCntPtr DiagID{new DiagnosticIDs}; DiagnosticOptions DiagOpts; ClangDiagnosticWrapper Wrapper(BuildLog, &DiagOpts); @@ -573,10 +499,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, LLVMContext &Context = Module.getContext(); for (const std::string &LibName : LibNames) { - std::string LibPath = DPCPPRoot + "/lib/" + LibName; + std::string LibPath = + (jit_compiler::ToolchainPrefix + "/lib/" + LibName).str(); ModuleUPtr LibModule; - if (auto Error = loadBitcodeLibrary(LibPath, Context).moveInto(LibModule)) { + if (auto Error = SYCLToolchain::instance() + .loadBitcodeLibrary(LibPath, Context) + .moveInto(LibModule)) { return Error; } @@ -590,14 +519,16 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module, // For GPU targets we need to link against vendor provided libdevice. if (IsCudaHIP) { Triple T{Module.getTargetTriple()}; - Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags}; + Driver D{(jit_compiler::ToolchainPrefix + "/bin/clang++").str(), + T.getTriple(), Diags}; auto [CPU, Features] = Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format); (void)Features; // Helper lambda to link modules. auto LinkInLib = [&](const StringRef LibDevice) -> Error { ModuleUPtr LibDeviceModule; - if (auto Error = loadBitcodeLibrary(LibDevice, Context) + if (auto Error = SYCLToolchain::instance() + .loadBitcodeLibrary(LibDevice, Context) .moveInto(LibDeviceModule)) { return Error; } @@ -831,16 +762,14 @@ jit_compiler::performPostLink(ModuleUPtr Module, } if (IsBF16DeviceLibUsed) { - const std::string &DPCPPRoot = getDPCPPRoot(); - if (DPCPPRoot == InvalidDPCPPRoot) { - return createStringError("Could not locate DPCPP root directory"); - } - auto &Ctx = Modules.front()->getContext(); auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error { - std::string LibPath = DPCPPRoot + "/lib/" + LibName; + std::string LibPath = + (jit_compiler::ToolchainPrefix + "/lib/" + LibName).str(); ModuleUPtr LibModule; - if (auto Error = loadBitcodeLibrary(LibPath, Ctx).moveInto(LibModule)) { + if (auto Error = SYCLToolchain::instance() + .loadBitcodeLibrary(LibPath, Ctx) + .moveInto(LibModule)) { return Error; } diff --git a/sycl-jit/jit-compiler/utils/generate.py b/sycl-jit/jit-compiler/utils/generate.py index b4bcf6871c60e..d2cf53b432b35 100644 --- a/sycl-jit/jit-compiler/utils/generate.py +++ b/sycl-jit/jit-compiler/utils/generate.py @@ -1,5 +1,6 @@ import os import argparse +import glob def main(): @@ -31,12 +32,9 @@ def main(): const std::pair ToolchainFiles[] = {""" ) - def process_dir(dir): - for root, _, files in os.walk(dir): - for file in files: - file_path = os.path.join(root, file) - out.write( - f""" + def process_file(file_path): + out.write( + f""" {{ {{"{args.prefix}{os.path.relpath(file_path, toolchain_dir).replace(os.sep, "/")}"}} , []() {{ @@ -46,12 +44,24 @@ def process_dir(dir): return std::string_view(data, std::size(data) - 1); }}() }},""" - ) + ) + + def process_dir(dir): + for root, _, files in os.walk(dir): + for file in files: + file_path = os.path.join(root, file) + process_file(file_path) process_dir(os.path.join(args.toolchain_dir, "include/")) process_dir(os.path.join(args.toolchain_dir, "lib/clang/")) process_dir(os.path.join(args.toolchain_dir, "lib/clc/")) + for file in glob.iglob( + "*.bc", root_dir=os.path.join(args.toolchain_dir, "lib") + ): + file_path = os.path.join(args.toolchain_dir, "lib", file) + process_file(file_path) + out.write( f""" }};