Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions sycl-jit/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,13 @@ class RTCResult {
explicit RTCResult(const char *ErrorMessage)
: Failed{true}, BundleInfo{}, ErrorMessage{ErrorMessage} {}

explicit RTCResult(RTCBundleInfo &&BundleInfo)
: Failed{false}, BundleInfo{std::move(BundleInfo)}, ErrorMessage{} {}
RTCResult(RTCBundleInfo &&BundleInfo, const char *BuildLog)
: Failed{false}, BundleInfo{std::move(BundleInfo)},
ErrorMessage{BuildLog} {}

bool failed() const { return Failed; }

const char *getErrorMessage() const {
assert(failed() && "No error message present");
return ErrorMessage.c_str();
}
const char *getErrorMessage() const { return ErrorMessage.c_str(); }

const RTCBundleInfo &getBundleInfo() const {
assert(!failed() && "No bundle info");
Expand Down
9 changes: 6 additions & 3 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,10 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
}
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);

auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
std::string BuildLog;

auto ModuleOrErr =
compileDeviceCode(SourceFile, IncludeFiles, UserArgList, BuildLog);
if (!ModuleOrErr) {
return errorTo<RTCResult>(ModuleOrErr.takeError(),
"Device compilation failed");
Expand All @@ -254,7 +257,7 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
Context.reset(&Module->getContext());

if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
if (auto Error = linkDeviceLibraries(*Module, UserArgList, BuildLog)) {
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

Expand All @@ -274,7 +277,7 @@ compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
}
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);

return RTCResult{std::move(BundleInfo)};
return RTCResult{std::move(BundleInfo), BuildLog.c_str()};
}

extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {
Expand Down
117 changes: 87 additions & 30 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,15 @@
#include <clang/CodeGen/CodeGenAction.h>
#include <clang/Driver/Compilation.h>
#include <clang/Driver/Options.h>
#include <clang/Frontend/ChainedDiagnosticConsumer.h>
#include <clang/Frontend/CompilerInstance.h>
#include <clang/Frontend/TextDiagnosticBuffer.h>
#include <clang/Frontend/TextDiagnosticPrinter.h>
#include <clang/Tooling/CompilationDatabase.h>
#include <clang/Tooling/Tooling.h>

#include <llvm/IR/DiagnosticInfo.h>
#include <llvm/IR/DiagnosticPrinter.h>
#include <llvm/IR/PassInstrumentation.h>
#include <llvm/IR/PassManager.h>
#include <llvm/IRReader/IRReader.h>
Expand All @@ -27,6 +31,10 @@
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
#include <llvm/Support/PropertySetIO.h>

#include <algorithm>
#include <array>
#include <sstream>

using namespace clang;
using namespace clang::tooling;
using namespace clang::driver;
Expand Down Expand Up @@ -132,6 +140,9 @@ struct GetLLVMModuleAction : public ToolAction {
CompilerInstance Compiler(std::move(PCHContainerOps));
Compiler.setInvocation(std::move(Invocation));
Compiler.setFileManager(Files);
// Suppress summary with number of warnings and errors being printed to
// stdout.
Compiler.setVerboseOutputStream(std::make_unique<llvm::raw_null_ostream>());

// Create the compiler's actual diagnostics engine.
Compiler.createDiagnostics(DiagConsumer, /*ShouldOwnClient=*/false);
Expand Down Expand Up @@ -161,12 +172,55 @@ struct GetLLVMModuleAction : public ToolAction {
std::unique_ptr<llvm::Module> Module;
};

class ClangDiagnosticWrapper {

llvm::raw_string_ostream LogStream;

std::unique_ptr<clang::TextDiagnosticPrinter> LogPrinter;

public:
ClangDiagnosticWrapper(std::string &LogString, DiagnosticOptions *DiagOpts)
: LogStream(LogString),
LogPrinter(
std::make_unique<TextDiagnosticPrinter>(LogStream, DiagOpts)) {}

clang::TextDiagnosticPrinter *consumer() { return LogPrinter.get(); }

llvm::raw_ostream &stream() { return LogStream; }
};

class LLVMDiagnosticWrapper : public llvm::DiagnosticHandler {
llvm::raw_string_ostream LogStream;

DiagnosticPrinterRawOStream LogPrinter;

public:
LLVMDiagnosticWrapper(std::string &BuildLog)
: LogStream(BuildLog), LogPrinter(LogStream) {}

bool handleDiagnostics(const DiagnosticInfo &DI) override {
auto Prefix = [](DiagnosticSeverity Severity) -> llvm::StringLiteral {
switch (Severity) {
case llvm::DiagnosticSeverity::DS_Error:
return "ERROR";
case llvm::DiagnosticSeverity::DS_Warning:
return "WARNING";
default:
return "NOTE:";
}
}(DI.getSeverity());
LogPrinter << Prefix;
DI.print(LogPrinter);
LogPrinter << "\n";
return true;
}
};

} // anonymous namespace

Expected<std::unique_ptr<llvm::Module>>
jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList) {
Expected<std::unique_ptr<llvm::Module>> jit_compiler::compileDeviceCode(
InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const InputArgList &UserArgList, std::string &BuildLog) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
Expand Down Expand Up @@ -197,12 +251,19 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
FixedCompilationDatabase DB{".", CommandLine};
ClangTool Tool{DB, {SourceFile.Path}};

IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts{new DiagnosticOptions};
ClangDiagnosticWrapper Wrapper(BuildLog, DiagOpts.get());
Tool.setDiagnosticConsumer(Wrapper.consumer());

// Set up in-memory filesystem.
Tool.mapVirtualFile(SourceFile.Path, SourceFile.Contents);
for (const auto &IF : IncludeFiles) {
Tool.mapVirtualFile(IF.Path, IF.Contents);
}

// Suppress message "Error while processing" being printed to stdout.
Tool.setPrintErrorMessage(false);

// Reset argument adjusters to drop the `-fsyntax-only` flag which is added by
// default by this API.
Tool.clearArgumentsAdjusters();
Expand All @@ -222,15 +283,15 @@ jit_compiler::compileDeviceCode(InMemoryFile SourceFile,
return std::move(Action.Module);
}

// TODO: Capture compiler errors from the ClangTool.
return createStringError("Unable to obtain LLVM module");
return createStringError(BuildLog);
}

// This function is a simplified copy of the device library selection process in
// `clang::driver::tools::SYCL::getDeviceLibraries`, assuming a SPIR-V target
// (no AoT, no third-party GPUs, no native CPU). Keep in sync!
static SmallVector<std::string, 8>
getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
static bool getDeviceLibraries(const ArgList &Args,
SmallVectorImpl<std::string> &LibraryList,
DiagnosticsEngine &Diags) {
struct DeviceLibOptInfo {
StringRef DeviceLibName;
StringRef DeviceLibOption;
Expand All @@ -247,6 +308,8 @@ getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
// libraries cannot be affected via -fno-sycl-device-lib.
bool ExcludeDeviceLibs = false;

bool FoundUnknownLib = false;

if (Arg *A = Args.getLastArg(OPT_fsycl_device_lib_EQ,
OPT_fno_sycl_device_lib_EQ)) {
if (A->getValues().size() == 0) {
Expand All @@ -268,6 +331,7 @@ getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
if (LinkInfoIter == DeviceLibLinkInfo.end() || Val == "internal") {
Diags.Report(diag::err_drv_unsupported_option_argument)
<< A->getSpelling() << Val;
FoundUnknownLib = true;
}
DeviceLibLinkInfo[Val] = !ExcludeDeviceLibs;
}
Expand All @@ -292,7 +356,6 @@ getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
{"libsycl-itt-compiler-wrappers", "internal"},
{"libsycl-itt-stubs", "internal"}};

SmallVector<std::string, 8> LibraryList;
StringRef LibSuffix = ".bc";
auto AddLibraries = [&](const SYCLDeviceLibsList &LibsList) {
for (const DeviceLibOptInfo &Lib : LibsList) {
Expand All @@ -312,37 +375,33 @@ getDeviceLibraries(const ArgList &Args, DiagnosticsEngine &Diags) {
AddLibraries(SYCLDeviceAnnotationLibs);
}

return LibraryList;
return FoundUnknownLib;
}

Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
const InputArgList &UserArgList) {
const InputArgList &UserArgList,
std::string &BuildLog) {
const std::string &DPCPPRoot = getDPCPPRoot();
if (DPCPPRoot == InvalidDPCPPRoot) {
return createStringError("Could not locate DPCPP root directory");
}

// TODO: Seems a bit excessive to set up this machinery for one warning and
// one error. Rethink when implementing the build log/error reporting as
// mandated by the extension.
IntrusiveRefCntPtr<DiagnosticIDs> DiagID{new DiagnosticIDs};
IntrusiveRefCntPtr<DiagnosticOptions> DiagOpts{new DiagnosticOptions};
TextDiagnosticBuffer *DiagBuffer = new TextDiagnosticBuffer;
DiagnosticsEngine Diags(DiagID, DiagOpts, DiagBuffer);

auto LibNames = getDeviceLibraries(UserArgList, Diags);
if (std::distance(DiagBuffer->err_begin(), DiagBuffer->err_end()) > 0) {
std::string DiagMsg;
raw_string_ostream SOS{DiagMsg};
interleave(
DiagBuffer->err_begin(), DiagBuffer->err_end(),
[&](const auto &D) { SOS << D.second; }, [&]() { SOS << '\n'; });
ClangDiagnosticWrapper Wrapper(BuildLog, DiagOpts.get());
DiagnosticsEngine Diags(DiagID, DiagOpts, Wrapper.consumer(),
/* ShouldOwnClient=*/false);

SmallVector<std::string> LibNames;
bool FoundUnknownLib = getDeviceLibraries(UserArgList, LibNames, Diags);
if (FoundUnknownLib) {
return createStringError("Could not determine list of device libraries: %s",
DiagMsg.c_str());
BuildLog.c_str());
}
// TODO: Add warnings to build log.

LLVMContext &Context = Module.getContext();
Context.setDiagnosticHandler(
std::make_unique<LLVMDiagnosticWrapper>(BuildLog));
for (const std::string &LibName : LibNames) {
std::string LibPath = DPCPPRoot + "/lib/" + LibName;

Expand All @@ -356,10 +415,8 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
}

if (Linker::linkModules(Module, std::move(Lib), Linker::LinkOnlyNeeded)) {
// TODO: Obtain detailed error message from the context's diagnostics
// handler.
return createStringError("Unable to link device library: %s",
LibPath.c_str());
return createStringError("Unable to link device library %s: %s",
LibPath.c_str(), BuildLog.c_str());
}
}

Expand Down
7 changes: 5 additions & 2 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,18 @@
#include <llvm/Support/Error.h>

#include <memory>
#include <string>

namespace jit_compiler {

llvm::Expected<std::unique_ptr<llvm::Module>>
compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
const llvm::opt::InputArgList &UserArgList);
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog);

llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList);
const llvm::opt::InputArgList &UserArgList,
std::string &BuildLog);

llvm::Expected<RTCBundleInfo>
performPostLink(llvm::Module &Module,
Expand Down
7 changes: 4 additions & 3 deletions sycl/source/detail/jit_compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1242,13 +1242,14 @@ sycl_device_binaries jit_compiler::compileSYCL(

auto Result = CompileSYCLHandle(SourceFile, IncludeFilesView, UserArgsView);

if (LogPtr) {
LogPtr->append(Result.getErrorMessage());
}

if (Result.failed()) {
throw sycl::exception(sycl::errc::build, Result.getErrorMessage());
}

// TODO: We currently don't have a meaningful build log.
(void)LogPtr;

return createDeviceBinaryImage(Result.getBundleInfo());
}

Expand Down
Loading
Loading