Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ add_llvm_library(sycl-jit
BitReader
Core
Support
Option
Analysis
IPO
TransformUtils
Passes
IRReader
Linker
ScalarOpts
InstCombine
Expand Down
26 changes: 17 additions & 9 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,22 +237,30 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
extern "C" KF_EXPORT_SYMBOL JITResult
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgs);
auto UserArgListOrErr = parseUserArgs(UserArgs);
if (!UserArgListOrErr) {
return errorToFusionResult(UserArgListOrErr.takeError(),
"Parsing of user arguments failed");
}
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);

auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
if (!ModuleOrErr) {
return errorToFusionResult(ModuleOrErr.takeError(),
"Device compilation failed");
}
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);

SYCLKernelInfo Kernel;
auto Error = translation::KernelTranslator::translateKernel(
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV);
std::unique_ptr<llvm::LLVMContext> Context;
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
Context.reset(&Module->getContext());

auto *LLVMCtx = &Module->getContext();
Module.reset();
delete LLVMCtx;
if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
return errorToFusionResult(std::move(Error), "Device linking failed");
}

if (Error) {
SYCLKernelInfo Kernel;
if (auto Error = translation::KernelTranslator::translateKernel(
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
}

Expand Down
258 changes: 239 additions & 19 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,28 @@

#include "DeviceCompilation.h"

#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Driver/Options.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/TextDiagnosticBuffer.h>
#include <clang/Tooling/CompilationDatabase.h>
#include <clang/Tooling/Tooling.h>

#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>

#include <array>

using namespace clang;
using namespace clang::tooling;
using namespace clang::driver;
using namespace clang::driver::options;
using namespace llvm;
using namespace llvm::opt;

#ifdef _GNU_SOURCE
#include <dlfcn.h>
static char X; // Dummy symbol, used as an anchor for `dlinfo` below.
Expand Down Expand Up @@ -96,9 +111,6 @@ static const std::string &getDPCPPRoot() {
}

namespace {
using namespace clang;
using namespace clang::tooling;
using namespace clang::driver;

struct GetLLVMModuleAction : public ToolAction {
// Code adapted from `FrontendActionFactory::runInvocation`.
Expand Down Expand Up @@ -143,23 +155,37 @@ struct GetLLVMModuleAction : public ToolAction {

} // anonymous namespace

llvm::Expected<std::unique_ptr<llvm::Module>>
Expected<std::unique_ptr<llvm::Module>>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
const InputArgList &UserArgList) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return llvm::createStringError("Could not locate DPCPP root directory");
return createStringError("Could not locate DPCPP root directory");
}

SmallVector<std::string> CommandLine = {"-fsycl-device-only"};
// TODO: Allow instrumentation again when device library linking is
// implemented.
CommandLine.push_back("-fno-sycl-instrument-device-code");
Comment on lines -156 to -158
Copy link
Contributor Author

@jopperm jopperm Oct 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB: This PR is tested by the existing E2E test, which now works with device instrumentation enabled.

Copy link
Contributor Author

@jopperm jopperm Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Device instrumentation was disabled by default recently (#14910); I added the flag to one of the E2E test's user-supplied arguments.

CommandLine.append(UserArgs.begin(), UserArgs.end());
clang::tooling::FixedCompilationDatabase DB{".", CommandLine};
DerivedArgList DAL{UserArgList};
const auto &OptTable = getDriverOptTable();
DAL.AddFlagArg(nullptr, OptTable.getOption(OPT_fsycl_device_only));
DAL.AddJoinedArg(
nullptr, OptTable.getOption(OPT_resource_dir_EQ),
(DPCPPRoot + "/lib/clang/" + Twine(CLANG_VERSION_MAJOR)).str());
for (auto *Arg : UserArgList) {
DAL.append(Arg);
}
// Remove args that will trigger an unused command line argument warning for
// the FrontendAction invocation, but are handled later (e.g. during device
// linking).
DAL.eraseArg(OPT_fsycl_device_lib_EQ);
DAL.eraseArg(OPT_fno_sycl_device_lib_EQ);

SmallVector<std::string> CommandLine;
for (auto *Arg : DAL) {
CommandLine.emplace_back(Arg->getAsString(DAL));
}

clang::tooling::ClangTool Tool{DB, {SourceFile.Path}};
FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};

// Set up in-memory filesystem.
Tool.mapVirtualFile(SourceFile.Path, SourceFile.Contents);
Expand All @@ -170,17 +196,14 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
// default by this API.
Tool.clearArgumentsAdjusters();
// Then, modify argv[0] and set the resource directory so that the driver
// picks up the correct SYCL environment.
// Then, modify argv[0] so that the driver picks up the correct SYCL
// environment. We've already set the resource directory above.
Tool.appendArgumentsAdjuster(
[&DPCPPRoot](const CommandLineArguments &Args,
StringRef Filename) -> CommandLineArguments {
(void)Filename;
CommandLineArguments NewArgs = Args;
NewArgs[0] = (Twine(DPCPPRoot) + "/bin/clang++").str();
NewArgs.push_back((Twine("-resource-dir=") + DPCPPRoot + "/lib/clang/" +
Twine(CLANG_VERSION_MAJOR))
.str());
return NewArgs;
});

Expand All @@ -190,5 +213,202 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
}

// TODO: Capture compiler errors from the ClangTool.
return llvm::createStringError("Unable to obtain LLVM module");
return createStringError("Unable to obtain LLVM module");
}

// This function is a simplified copy of the device library selection process in
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
static SmallVector<std::string, 8>
getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
struct DeviceLibOptInfo {
StringRef DeviceLibName;
StringRef DeviceLibOption;
};

// Currently, all SYCL device libraries will be linked by default.
llvm::StringMap<bool> DeviceLibLinkInfo = {
{"libc", true}, {"libm-fp32", true}, {"libm-fp64", true},
{"libimf-fp32", true}, {"libimf-fp64", true}, {"libimf-bf16", true},
{"libm-bfloat16", true}, {"internal", true}};

// If -fno-sycl-device-lib is specified, its values will be used to exclude
// linkage of libraries specified by DeviceLibLinkInfo. Linkage of "internal"
// libraries cannot be affected via -fno-sycl-device-lib.
bool ExcludeDeviceLibs = false;

if (Arg *A = Args.getLastArg(OPT_fsycl_device_lib_EQ,
OPT_fno_sycl_device_lib_EQ)) {
if (A->getValues().size() == 0) {
Diags.Report(diag::warn_drv_empty_joined_argument)
<< A->getAsString(Args);
} else {
if (A->getOption().matches(OPT_fno_sycl_device_lib_EQ)) {
ExcludeDeviceLibs = true;
}

for (StringRef Val : A->getValues()) {
if (Val == "all") {
for (const auto &K : DeviceLibLinkInfo.keys()) {
DeviceLibLinkInfo[K] = (K == "internal") || !ExcludeDeviceLibs;
}
break;
}
auto LinkInfoIter = DeviceLibLinkInfo.find(Val);
if (LinkInfoIter == DeviceLibLinkInfo.end() || Val == "internal") {
Diags.Report(diag::err_drv_unsupported_option_argument)
<< A->getSpelling() << Val;
}
DeviceLibLinkInfo[Val] = !ExcludeDeviceLibs;
}
}
}

using SYCLDeviceLibsList = SmallVector<DeviceLibOptInfo, 5>;

const SYCLDeviceLibsList SYCLDeviceWrapperLibs = {
{"libsycl-crt", "libc"},
{"libsycl-complex", "libm-fp32"},
{"libsycl-complex-fp64", "libm-fp64"},
{"libsycl-cmath", "libm-fp32"},
{"libsycl-cmath-fp64", "libm-fp64"},
{"libsycl-imf", "libimf-fp32"},
{"libsycl-imf-fp64", "libimf-fp64"},
{"libsycl-imf-bf16", "libimf-bf16"}};
// ITT annotation libraries are linked in separately whenever the device
// code instrumentation is enabled.
const SYCLDeviceLibsList SYCLDeviceAnnotationLibs = {
{"libsycl-itt-user-wrappers", "internal"},
{"libsycl-itt-compiler-wrappers", "internal"},
{"libsycl-itt-stubs", "internal"}};

SmallVector<std::string, 8> LibraryList;
StringRef LibSuffix = ".bc";
auto AddLibraries = [&](const SYCLDeviceLibsList &LibsList) {
for (const DeviceLibOptInfo &Lib : LibsList) {
if (!DeviceLibLinkInfo[Lib.DeviceLibOption]) {
continue;
}
SmallString<128> LibName(Lib.DeviceLibName);
llvm::sys::path::replace_extension(LibName, LibSuffix);
LibraryList.push_back(Args.MakeArgString(LibName));
}
};

AddLibraries(SYCLDeviceWrapperLibs);

if (Args.hasFlag(OPT_fsycl_instrument_device_code,
OPT_fno_sycl_instrument_device_code, false)) {
AddLibraries(SYCLDeviceAnnotationLibs);
}

return LibraryList;
}

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
}

// TODO: Seems a bit excessive to set up this machinery for one warning and
// one error. Rethink when implementing the build log/error reporting as
// mandated by the extension.
IntrusiveRefCntPtr<DiagnosticIDs> DiagID{new DiagnosticIDs};
IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts{new DiagnosticOptions};
TextDiagnosticBuffer *DiagBuffer = new TextDiagnosticBuffer;
DiagnosticsEngine Diags(DiagID, DiagOpts, DiagBuffer);

auto LibNames = getDeviceLibraries(UserArgList, Diags);
if (std::distance(DiagBuffer->err_begin(), DiagBuffer->err_end()) > 0) {
std::string DiagMsg;
raw_string_ostream SOS{DiagMsg};
interleave(
DiagBuffer->err_begin(), DiagBuffer->err_end(),
[&](const auto &D) { SOS << D.second; }, [&]() { SOS << '\n'; });
return createStringError("Could not determine list of device libraries: %s",
DiagMsg.c_str());
}
// TODO: Add warnings to build log.

LLVMContext &Context = Module.getContext();
for (const std::string &LibName : LibNames) {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;

SMDiagnostic Diag;
std::unique_ptr<llvm::Module> Lib = parseIRFile(LibPath, Diag, Context);
if (!Lib) {
std::string DiagMsg;
raw_string_ostream SOS(DiagMsg);
Diag.print(/*ProgName=*/nullptr, SOS);
return createStringError(DiagMsg);
}

if (Linker::linkModules(Module, std::move(Lib), Linker::LinkOnlyNeeded)) {
// TODO: Obtain detailed error message from the context's diagnostics
// handler.
return createStringError("Unable to link device library: %s",
LibPath.c_str());
}
}

return Error::success();
}

Expected<InputArgList>
jit_compiler::parseUserArgs(View<const char *> UserArgs) {
unsigned MissingArgIndex, MissingArgCount;
auto UserArgsRef = UserArgs.to<ArrayRef>();
auto AL = getDriverOptTable().ParseArgs(UserArgsRef, MissingArgIndex,
MissingArgCount);
if (MissingArgCount) {
return createStringError(
"User option '%s' at index %d is missing an argument",
UserArgsRef[MissingArgIndex], MissingArgIndex);
}

// Check for unsupported options.
// TODO: There are probably more, e.g. requesting non-SPIR-V targets.
{
// -fsanitize=address
bool IsDeviceAsanEnabled = false;
if (Arg *A = AL.getLastArg(OPT_fsanitize_EQ, OPT_fno_sanitize_EQ)) {
if (A->getOption().matches(OPT_fsanitize_EQ) &&
A->getValues().size() == 1) {
std::string SanitizeVal = A->getValue();
IsDeviceAsanEnabled = SanitizeVal == "address";
}
} else {
// User can pass -fsanitize=address to device compiler via
// -Xsycl-target-frontend.
auto SyclFEArg = AL.getAllArgValues(OPT_Xsycl_frontend);
IsDeviceAsanEnabled = (std::count(SyclFEArg.begin(), SyclFEArg.end(),
"-fsanitize=address") > 0);
if (!IsDeviceAsanEnabled) {
auto SyclFEArgEq = AL.getAllArgValues(OPT_Xsycl_frontend_EQ);
IsDeviceAsanEnabled =
(std::count(SyclFEArgEq.begin(), SyclFEArgEq.end(),
"-fsanitize=address") > 0);
}

// User can also enable asan for SYCL device via -Xarch_device option.
if (!IsDeviceAsanEnabled) {
auto DeviceArchVals = AL.getAllArgValues(OPT_Xarch_device);
for (auto DArchVal : DeviceArchVals) {
if (DArchVal.find("-fsanitize=address") != std::string::npos) {
IsDeviceAsanEnabled = true;
break;
}
}
}
}

if (IsDeviceAsanEnabled) {
return createStringError(
"Device ASAN is not supported for runtime compilation");
}
}

return Expected<InputArgList>{std::move(AL)};
}
9 changes: 8 additions & 1 deletion sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "View.h"

#include <llvm/IR/Module.h>
#include <llvm/Option/ArgList.h>
#include <llvm/Support/Error.h>

#include <memory>
Expand All @@ -21,7 +22,13 @@ namespace jit_compiler {

llvm::Expected<std::unique_ptr<llvm::Module>>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);
const llvm::opt::InputArgList &UserArgList);

llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<llvm::opt::InputArgList>
parseUserArgs(View<const char *> UserArgs);

} // namespace jit_compiler

Expand Down
Loading
Loading