Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdint>
#include <cstring>
#include <functional>
#include <string_view>
#include <type_traits>

namespace jit_compiler {
Expand Down Expand Up @@ -350,11 +351,60 @@ struct SYCLKernelInfo {
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
};

// RTC-related datastructures
// TODO: Consider moving into separate header.

struct InMemoryFile {
const char *Path;
const char *Contents;
};

using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
using FrozenSymbolTable = DynArray<sycl::detail::string>;

// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
// `std::string_view` arguments instead of `const char *` because they will be
// created from `llvm::SmallString`s, which don't contain the trailing '\0'
// byte. Hence obtaining a C-string would cause an additional copy.

struct FrozenPropertyValue {
sycl::detail::string Name;
bool IsUIntValue;
uint32_t UIntValue;
DynArray<uint8_t> Bytes;

FrozenPropertyValue() = default;
FrozenPropertyValue(FrozenPropertyValue &&) = default;
FrozenPropertyValue &operator=(FrozenPropertyValue &&) = default;

FrozenPropertyValue(std::string_view Name, uint32_t Value)
: Name{Name}, IsUIntValue{true}, UIntValue{Value}, Bytes{0} {}
FrozenPropertyValue(std::string_view Name, const uint8_t *Ptr, size_t Size)
: Name{Name}, IsUIntValue{false}, Bytes{Size} {
std::memcpy(Bytes.begin(), Ptr, Size);
}
};

struct FrozenPropertySet {
sycl::detail::string Name;
DynArray<FrozenPropertyValue> Values;

FrozenPropertySet() = default;
FrozenPropertySet(FrozenPropertySet &&) = default;
FrozenPropertySet &operator=(FrozenPropertySet &&) = default;

FrozenPropertySet(std::string_view Name, size_t Size)
: Name{Name}, Values{Size} {}
};

using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;

struct RTCBundleInfo {
RTCBundleBinaryInfo BinaryInfo;
FrozenSymbolTable SymbolTable;
FrozenPropertyRegistry Properties;
};

} // namespace jit_compiler

#endif // SYCL_FUSION_COMMON_KERNEL_H
2 changes: 2 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/rtc/DeviceCompilation.cpp
lib/rtc/PostLinkActions.cpp
lib/helper/ConfigHelper.cpp

SHARED
Expand All @@ -31,6 +32,7 @@ add_llvm_library(sycl-jit
Target
TargetParser
MC
SYCLLowerIR
${LLVM_TARGETS_TO_BUILD}

LINK_LIBS
Expand Down
28 changes: 27 additions & 1 deletion sycl-jit/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,32 @@ class JITResult {
sycl::detail::string ErrorMessage;
};

class RTCResult {
public:
explicit RTCResult(const char *ErrorMessage)
: Failed{true}, BundleInfo{}, ErrorMessage{ErrorMessage} {}

explicit RTCResult(RTCBundleInfo &&BundleInfo)
: Failed{false}, BundleInfo{std::move(BundleInfo)}, ErrorMessage{} {}

bool failed() const { return Failed; }

const char *getErrorMessage() const {
assert(failed() && "No error message present");
return ErrorMessage.c_str();
}

const RTCBundleInfo &getBundleInfo() const {
assert(!failed() && "No bundle info");
return BundleInfo;
}

private:
bool Failed;
RTCBundleInfo BundleInfo;
sycl::detail::string ErrorMessage;
};

extern "C" {

#ifdef __clang__
Expand All @@ -77,7 +103,7 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob);

KF_EXPORT_SYMBOL JITResult compileSYCL(InMemoryFile SourceFile,
KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);

Expand Down
47 changes: 33 additions & 14 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ using namespace jit_compiler;
using FusedFunction = helper::FusionHelper::FusedFunction;
using FusedFunctionList = std::vector<FusedFunction>;

static JITResult errorToFusionResult(llvm::Error &&Err,
const std::string &Msg) {
template <typename ResultType>
static ResultType wrapError(llvm::Error &&Err, const std::string &Msg) {
std::stringstream ErrMsg;
ErrMsg << Msg << "\nDetailed information:\n";
llvm::handleAllErrors(std::move(Err),
Expand All @@ -35,7 +35,16 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
// compiled without exception support.
ErrMsg << "\t" << StrErr.getMessage() << "\n";
});
return JITResult{ErrMsg.str().c_str()};
return ResultType{ErrMsg.str().c_str()};
}

static JITResult errorToFusionResult(llvm::Error &&Err,
const std::string &Msg) {
return wrapError<JITResult>(std::move(Err), Msg);
}

static RTCResult errorToRTCResult(llvm::Error &&Err, const std::string &Msg) {
return wrapError<RTCResult>(std::move(Err), Msg);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we use wrapError directly (maybe with a different name)?

So, instead of errorToFusionResult or errorToRTCResult, we would call errorToResult<JITResult> and errorToResult<RTCResult>?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

errorToResult<T> makes sense, I initially decided against that to keep the diff a bit cleaner.


static std::vector<jit_compiler::NDRange>
Expand Down Expand Up @@ -234,37 +243,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
return JITResult{FusedKernelInfo};
}

extern "C" KF_EXPORT_SYMBOL JITResult
extern "C" KF_EXPORT_SYMBOL RTCResult
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
auto UserArgListOrErr = parseUserArgs(UserArgs);
if (!UserArgListOrErr) {
return errorToFusionResult(UserArgListOrErr.takeError(),
"Parsing of user arguments failed");
return errorToRTCResult(UserArgListOrErr.takeError(),
"Parsing of user arguments failed");
}
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);

auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
if (!ModuleOrErr) {
return errorToFusionResult(ModuleOrErr.takeError(),
"Device compilation failed");
return errorToRTCResult(ModuleOrErr.takeError(),
"Device compilation failed");
}

std::unique_ptr<llvm::LLVMContext> Context;
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
Context.reset(&Module->getContext());

if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
return errorToFusionResult(std::move(Error), "Device linking failed");
return errorToRTCResult(std::move(Error), "Device linking failed");
}

SYCLKernelInfo Kernel;
if (auto Error = translation::KernelTranslator::translateKernel(
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
if (!BundleInfoOrError) {
return errorToRTCResult(BundleInfoOrError.takeError(),
"Post-link phase failed");
}
auto BundleInfo = std::move(*BundleInfoOrError);

auto BinaryInfoOrError =
translation::KernelTranslator::translateBundleToSPIRV(
*Module, JITContext::getInstance());
if (!BinaryInfoOrError) {
return errorToRTCResult(BinaryInfoOrError.takeError(),
"SPIR-V translation failed");
}
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);

return JITResult{Kernel};
return RTCResult{std::move(BundleInfo)};
}

extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {
Expand Down
115 changes: 113 additions & 2 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "DeviceCompilation.h"

#include "PostLinkActions.h"

#include <clang/Basic/DiagnosticDriver.h>
#include <clang/Basic/Version.h>
#include <clang/CodeGen/CodeGenAction.h>
Expand All @@ -20,15 +22,22 @@

#include <llvm/IRReader/IRReader.h>
#include <llvm/Linker/Linker.h>

#include <array>
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
#include <llvm/Support/PropertySetIO.h>

using namespace clang;
using namespace clang::tooling;
using namespace clang::driver;
using namespace clang::driver::options;
using namespace llvm;
using namespace llvm::opt;
using namespace llvm::sycl;
using namespace llvm::module_split;
using namespace llvm::util;
using namespace jit_compiler;
using namespace jit_compiler::post_link;

#ifdef _GNU_SOURCE
#include <dlfcn.h>
Expand Down Expand Up @@ -356,6 +365,96 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
return Error::success();
}

Expected<RTCBundleInfo> jit_compiler::performPostLink(
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
// This is a simplified version of `processInputModule` in
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
// left out of the algorithm for now.

// After linking device bitcode "llvm.used" holds references to the kernels
// that are defined in the device image. But after splitting device image into
// separate kernels we may end up with having references to kernel declaration
// originating from "llvm.used" in the IR that is passed to llvm-spirv tool,
// and these declarations cause an assertion in llvm-spirv. To workaround this
// issue remove "llvm.used" from the input module before performing any other
// actions.
removeSYCLKernelsConstRefArray(Module);

// There may be device_global variables kept alive in "llvm.compiler.used"
// to keep the optimizer from wrongfully removing them. llvm.compiler.used
// symbols are usually removed at backend lowering, but this is handled here
// for SPIR-V since SYCL compilation uses llvm-spirv, not the SPIR-V backend.
removeDeviceGlobalFromCompilerUsed(Module);

assert(!isModuleUsingAsan(Module));
// Otherwise: Need to instrument each image scope device globals if the module
// has been instrumented by sanitizer pass.

// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
// LLVM IR specification.
runModulePass<SYCLJointMatrixTransformPass>(Module);

// TODO: Implement actual device code splitting. We're just using the splitter
// to obtain additional information about the module for now.
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
// `shouldEmitOnlyKernelsAsEntryPoints` in
// `clang/lib/Driver/ToolChains/Clang.cpp`.
std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
/*IROutputOnly=*/false,
/*EmitOnlyKernelsAsEntryPoints=*/true);
bool SplitOccurred = Splitter->remainingSplits() > 1;
assert(!SplitOccurred);

// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
// be processed.

assert(Splitter->hasMoreSplits());
ModuleDesc MDesc = Splitter->nextSplit();
assert(&Module == &MDesc.getModule());
MDesc.saveSplitInformationAsMetadata();

RTCBundleInfo BundleInfo;
BundleInfo.SymbolTable =
decltype(BundleInfo.SymbolTable){MDesc.entries().size()};
transform(MDesc.entries(), BundleInfo.SymbolTable.begin(),
[](Function *F) { return F->getName(); });

// TODO: Determine what is requested.
GlobalBinImageProps PropReq{
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
/*DeviceGlobals=*/false};
PropertySetRegistry Properties =
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
// TODO: Manually add `compile_target` property as in
// `saveModuleProperties`?
const auto &PropertySets = Properties.getPropSets();

BundleInfo.Properties = decltype(BundleInfo.Properties){PropertySets.size()};
for (auto &&[KV, FrozenPropSet] : zip(PropertySets, BundleInfo.Properties)) {
const auto &PropertySetName = KV.first;
const auto &PropertySet = KV.second;
FrozenPropSet =
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
for (auto &&[KV2, FrozenProp] : zip(PropertySet, FrozenPropSet.Values)) {
const auto &PropertyName = KV2.first;
const auto &PropertyValue = KV2.second;
FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32
? FrozenPropertyValue{PropertyName.str(),
PropertyValue.asUint32()}
: FrozenPropertyValue{
PropertyName.str(), PropertyValue.asRawByteArray(),
PropertyValue.getRawByteArraySize()};
}
};

// Regain ownership of the module.
MDesc.releaseModulePtr().release();

return BundleInfo;
}

Expected<InputArgList>
jit_compiler::parseUserArgs(View<const char *> UserArgs) {
unsigned MissingArgIndex, MissingArgCount;
Expand Down Expand Up @@ -410,5 +509,17 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
}
}

if (auto DCSMode = AL.getLastArgValue(OPT_fsycl_device_code_split_EQ, "none");
DCSMode != "none" && DCSMode != "auto") {
return createStringError("Device code splitting is not yet supported");
}

if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
OPT_fno_sycl_device_code_split_esimd)) {
// TODO: There are more ESIMD-related options.
return createStringError(
"Runtime compilation of ESIMD kernels is not yet supported");
}

return Expected<InputArgList>{std::move(AL)};
}
4 changes: 4 additions & 0 deletions sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ compileDeviceCode(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
llvm::Error linkDeviceLibraries(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<RTCBundleInfo>
performPostLink(llvm::Module &Module,
const llvm::opt::InputArgList &UserArgList);

llvm::Expected<llvm::opt::InputArgList>
parseUserArgs(View<const char *> UserArgs);

Expand Down
Loading
Loading