Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@ else()
set(SYCL_JIT_VIRTUAL_TOOLCHAIN_ROOT "/sycl-jit-toolchain/")
endif()

# TODO: libdevice
set(SYCL_JIT_RESOURCE_DEPS sycl-headers clang ${CMAKE_CURRENT_SOURCE_DIR}/utils/generate.py)
set(SYCL_JIT_RESOURCE_DEPS
sycl-headers # include/sycl
clang # lib/clang/N/include
libsycldevice # lib/*.bc
${CMAKE_CURRENT_SOURCE_DIR}/utils/generate.py)

if ("libclc" IN_LIST LLVM_ENABLE_PROJECTS)
# Somehow just "libclc" doesn't build "remangled-*" (and maybe whatever else).
list(APPEND SYCL_JIT_RESOURCE_DEPS libclc libspirv-builtins)
list(APPEND SYCL_JIT_RESOURCE_DEPS libclc libspirv-builtins) # lib/clc/*.bc
endif()

add_custom_command(
Expand Down
147 changes: 38 additions & 109 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,86 +64,6 @@ using namespace llvm::util;
using namespace llvm::vfs;
using namespace jit_compiler;

#ifdef _GNU_SOURCE
#include <dlfcn.h>
static char X; // Dummy symbol, used as an anchor for `dlinfo` below.
#endif

#ifdef _WIN32
#include <filesystem> // For std::filesystem::path ( C++17 only )
#include <shlwapi.h> // For PathRemoveFileSpec
#include <windows.h> // For GetModuleFileName, HMODULE, DWORD, MAX_PATH

// cribbed from sycl/source/detail/os_util.cpp
using OSModuleHandle = intptr_t;
static constexpr OSModuleHandle ExeModuleHandle = -1;
static OSModuleHandle getOSModuleHandle(const void *VirtAddr) {
HMODULE PhModule;
DWORD Flag = GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT;
auto LpModuleAddr = reinterpret_cast<LPCSTR>(VirtAddr);
if (!GetModuleHandleExA(Flag, LpModuleAddr, &PhModule)) {
// Expect the caller to check for zero and take
// necessary action
return 0;
}
if (PhModule == GetModuleHandleA(nullptr))
return ExeModuleHandle;
return reinterpret_cast<OSModuleHandle>(PhModule);
}

// cribbed from sycl/source/detail/os_util.cpp
/// Returns an absolute path where the object was found.
std::wstring getCurrentDSODir() {
wchar_t Path[MAX_PATH];
auto Handle = getOSModuleHandle(reinterpret_cast<void *>(&getCurrentDSODir));
DWORD Ret = GetModuleFileName(
reinterpret_cast<HMODULE>(ExeModuleHandle == Handle ? 0 : Handle), Path,
MAX_PATH);
assert(Ret < MAX_PATH && "Path is longer than MAX_PATH?");
assert(Ret > 0 && "GetModuleFileName failed");
(void)Ret;

BOOL RetCode = PathRemoveFileSpec(Path);
assert(RetCode && "PathRemoveFileSpec failed");
(void)RetCode;

return Path;
}
#endif // _WIN32

static constexpr auto InvalidDPCPPRoot = "<invalid>";

static const std::string &getDPCPPRoot() {
thread_local std::string DPCPPRoot;

if (!DPCPPRoot.empty()) {
return DPCPPRoot;
}
DPCPPRoot = InvalidDPCPPRoot;

#ifdef _GNU_SOURCE
static constexpr auto JITLibraryPathSuffix = "/lib/libsycl-jit.so";
Dl_info Info;
if (dladdr(&X, &Info)) {
std::string LoadedLibraryPath = Info.dli_fname;
auto Pos = LoadedLibraryPath.rfind(JITLibraryPathSuffix);
if (Pos != std::string::npos) {
DPCPPRoot = LoadedLibraryPath.substr(0, Pos);
}
}
#endif // _GNU_SOURCE

#ifdef _WIN32
DPCPPRoot = std::filesystem::path(getCurrentDSODir()).parent_path().string();
#endif // _WIN32

// TODO: Implemenent other means of determining the DPCPP root, e.g.
// evaluating the `CMPLR_ROOT` env.

return DPCPPRoot;
}

namespace {

class HashPreprocessedAction : public PreprocessorFrontendAction {
Expand Down Expand Up @@ -252,6 +172,30 @@ class SYCLToolchain {
return TI.run();
}

Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
LLVMContext &Context) {
auto FS = llvm::makeIntrusiveRefCnt<llvm::vfs::OverlayFileSystem>(
llvm::vfs::getRealFileSystem());
FS->pushOverlay(ToolchainFS);

auto MemBuf = FS->getBufferForFile(LibPath, /*FileSize*/ -1,
/*RequiresNullTerminator*/ false);
if (!MemBuf) {
return createStringError("Error opening file %s: %s", LibPath.data(),
MemBuf.getError().message().c_str());
}

SMDiagnostic Diag;
ModuleUPtr Lib = parseIR(*MemBuf->get(), Diag, Context);
if (!Lib) {
std::string DiagMsg;
raw_string_ostream SOS(DiagMsg);
Diag.print(/*ProgName=*/nullptr, SOS);
return createStringError(DiagMsg);
}
return std::move(Lib);
}

std::string_view getClangXXExe() const { return ClangXXExe; }

private:
Expand Down Expand Up @@ -516,30 +460,12 @@ static bool getDeviceLibraries(const ArgList &Args,
return FoundUnknownLib;
}

static Expected<ModuleUPtr> loadBitcodeLibrary(StringRef LibPath,
LLVMContext &Context) {
SMDiagnostic Diag;
ModuleUPtr Lib = parseIRFile(LibPath, Diag, Context);
if (!Lib) {
std::string DiagMsg;
raw_string_ostream SOS(DiagMsg);
Diag.print(/*ProgName=*/nullptr, SOS);
return createStringError(DiagMsg);
}
return std::move(Lib);
}

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList,
std::string &BuildLog,
BinaryFormat Format) {
TimeTraceScope TTS{"linkDeviceLibraries"};

const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
}

IntrusiveRefCntPtr<DiagnosticIDs> DiagID{new DiagnosticIDs};
DiagnosticOptions DiagOpts;
ClangDiagnosticWrapper Wrapper(BuildLog, &DiagOpts);
Expand Down Expand Up @@ -573,10 +499,13 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,

LLVMContext &Context = Module.getContext();
for (const std::string &LibName : LibNames) {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;
std::string LibPath =
(jit_compiler::ToolchainPrefix + "/lib/" + LibName).str();

ModuleUPtr LibModule;
if (auto Error = loadBitcodeLibrary(LibPath, Context).moveInto(LibModule)) {
if (auto Error = SYCLToolchain::instance()
.loadBitcodeLibrary(LibPath, Context)
.moveInto(LibModule)) {
return Error;
}

Expand All @@ -590,14 +519,16 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
// For GPU targets we need to link against vendor provided libdevice.
if (IsCudaHIP) {
Triple T{Module.getTargetTriple()};
Driver D{(Twine(DPCPPRoot) + "/bin/clang++").str(), T.getTriple(), Diags};
Driver D{(jit_compiler::ToolchainPrefix + "/bin/clang++").str(),
T.getTriple(), Diags};
auto [CPU, Features] =
Translator::getTargetCPUAndFeatureAttrs(&Module, "", Format);
(void)Features;
// Helper lambda to link modules.
auto LinkInLib = [&](const StringRef LibDevice) -> Error {
ModuleUPtr LibDeviceModule;
if (auto Error = loadBitcodeLibrary(LibDevice, Context)
if (auto Error = SYCLToolchain::instance()
.loadBitcodeLibrary(LibDevice, Context)
.moveInto(LibDeviceModule)) {
return Error;
}
Expand Down Expand Up @@ -831,16 +762,14 @@ jit_compiler::performPostLink(ModuleUPtr Module,
}

if (IsBF16DeviceLibUsed) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
}

auto &Ctx = Modules.front()->getContext();
auto WrapLibraryInDevImg = [&](const std::string &LibName) -> Error {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;
std::string LibPath =
(jit_compiler::ToolchainPrefix + "/lib/" + LibName).str();
ModuleUPtr LibModule;
if (auto Error = loadBitcodeLibrary(LibPath, Ctx).moveInto(LibModule)) {
if (auto Error = SYCLToolchain::instance()
.loadBitcodeLibrary(LibPath, Ctx)
.moveInto(LibModule)) {
return Error;
}

Expand Down
24 changes: 17 additions & 7 deletions sycl-jit/jit-compiler/utils/generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import argparse
import glob


def main():
Expand Down Expand Up @@ -31,12 +32,9 @@ def main():
const std::pair<std::string_view, std::string_view> ToolchainFiles[] = {"""
)

def process_dir(dir):
for root, _, files in os.walk(dir):
for file in files:
file_path = os.path.join(root, file)
out.write(
f"""
def process_file(file_path):
out.write(
f"""
{{
{{"{args.prefix}{os.path.relpath(file_path, toolchain_dir).replace(os.sep, "/")}"}} ,
[]() {{
Expand All @@ -46,12 +44,24 @@ def process_dir(dir):
return std::string_view(data, std::size(data) - 1);
}}()
}},"""
)
)

def process_dir(dir):
for root, _, files in os.walk(dir):
for file in files:
file_path = os.path.join(root, file)
process_file(file_path)

process_dir(os.path.join(args.toolchain_dir, "include/"))
process_dir(os.path.join(args.toolchain_dir, "lib/clang/"))
process_dir(os.path.join(args.toolchain_dir, "lib/clc/"))

for file in glob.iglob(
"*.bc", root_dir=os.path.join(args.toolchain_dir, "lib")
):
file_path = os.path.join(args.toolchain_dir, "lib", file)
process_file(file_path)

out.write(
f"""
}};
Expand Down