Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions sycl-jit/common/include/Kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdint>
#include <cstring>
#include <functional>
#include <string_view>
#include <type_traits>

namespace jit_compiler {
Expand Down Expand Up @@ -350,11 +351,60 @@ struct SYCLKernelInfo {
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
};

// RTC-related datastructures
// TODO: Consider moving into separate header.

struct InMemoryFile {
const char *Path;
const char *Contents;
};

using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
using FrozenSymbolTable = DynArray<sycl::detail::string>;

// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
// `std::string_view` arguments instead of `const char *` because they will be
// created from `llvm::SmallString`s, which don't contain the trailing '\0'
// byte. Hence obtaining a C-string would cause an additional copy.

struct FrozenPropertyValue {
sycl::detail::string Name;
bool IsUIntValue;
uint32_t UIntValue;
DynArray<uint8_t> Bytes;

FrozenPropertyValue() = default;
FrozenPropertyValue(FrozenPropertyValue &&) = default;
FrozenPropertyValue &operator=(FrozenPropertyValue &&) = default;

FrozenPropertyValue(std::string_view Name, uint32_t Value)
: Name{Name}, IsUIntValue{true}, UIntValue{Value}, Bytes{0} {}
FrozenPropertyValue(std::string_view Name, const uint8_t *Ptr, size_t Size)
: Name{Name}, IsUIntValue{false}, Bytes{Size} {
std::memcpy(Bytes.begin(), Ptr, Size);
}
};

struct FrozenPropertySet {
sycl::detail::string Name;
DynArray<FrozenPropertyValue> Values;

FrozenPropertySet() = default;
FrozenPropertySet(FrozenPropertySet &&) = default;
FrozenPropertySet &operator=(FrozenPropertySet &&) = default;

FrozenPropertySet(std::string_view Name, size_t Size)
: Name{Name}, Values{Size} {}
};

using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;

struct RTCBundleInfo {
RTCBundleBinaryInfo BinaryInfo;
FrozenSymbolTable SymbolTable;
FrozenPropertyRegistry Properties;
};

} // namespace jit_compiler

#endif // SYCL_FUSION_COMMON_KERNEL_H
2 changes: 2 additions & 0 deletions sycl-jit/jit-compiler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_llvm_library(sycl-jit
lib/fusion/JITContext.cpp
lib/fusion/ModuleHelper.cpp
lib/rtc/DeviceCompilation.cpp
lib/rtc/PostLinkActions.cpp
lib/helper/ConfigHelper.cpp

SHARED
Expand All @@ -31,6 +32,7 @@ add_llvm_library(sycl-jit
Target
TargetParser
MC
SYCLLowerIR
${LLVM_TARGETS_TO_BUILD}

LINK_LIBS
Expand Down
28 changes: 27 additions & 1 deletion sycl-jit/jit-compiler/include/KernelFusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,32 @@ class JITResult {
sycl::detail::string ErrorMessage;
};

class RTCResult {
public:
explicit RTCResult(const char *ErrorMessage)
: Failed{true}, BundleInfo{}, ErrorMessage{ErrorMessage} {}

explicit RTCResult(RTCBundleInfo &&BundleInfo)
: Failed{false}, BundleInfo{std::move(BundleInfo)}, ErrorMessage{} {}

bool failed() const { return Failed; }

const char *getErrorMessage() const {
assert(failed() && "No error message present");
return ErrorMessage.c_str();
}

const RTCBundleInfo &getBundleInfo() const {
assert(!failed() && "No bundle info");
return BundleInfo;
}

private:
bool Failed;
RTCBundleInfo BundleInfo;
sycl::detail::string ErrorMessage;
};

extern "C" {

#ifdef __clang__
Expand All @@ -77,7 +103,7 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
View<unsigned char> SpecConstBlob);

KF_EXPORT_SYMBOL JITResult compileSYCL(InMemoryFile SourceFile,
KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs);

Expand Down
56 changes: 33 additions & 23 deletions sycl-jit/jit-compiler/lib/KernelFusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ using namespace jit_compiler;
using FusedFunction = helper::FusionHelper::FusedFunction;
using FusedFunctionList = std::vector<FusedFunction>;

static JITResult errorToFusionResult(llvm::Error &&Err,
const std::string &Msg) {
template <typename ResultType>
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
std::stringstream ErrMsg;
ErrMsg << Msg << "\nDetailed information:\n";
llvm::handleAllErrors(std::move(Err),
Expand All @@ -35,7 +35,7 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
// compiled without exception support.
ErrMsg << "\t" << StrErr.getMessage() << "\n";
});
return JITResult{ErrMsg.str().c_str()};
return ResultType{ErrMsg.str().c_str()};
}

static std::vector<jit_compiler::NDRange>
Expand Down Expand Up @@ -95,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
ModuleInfo.kernels());
if (auto Error = ModOrError.takeError()) {
return errorToFusionResult(std::move(Error), "Failed to load kernels");
return errorTo<JITResult>(std::move(Error), "Failed to load kernels");
}
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
if (!fusion::FusionPipeline::runMaterializerPasses(
Expand All @@ -107,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
if (auto Error = translation::KernelTranslator::translateKernel(
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
return errorToFusionResult(std::move(Error),
"Translation to output format failed");
return errorTo<JITResult>(std::move(Error),
"Translation to output format failed");
}

return JITResult{MaterializerKernelInfo};
Expand All @@ -133,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
jit_compiler::FusedNDRange::get(NDRanges);
if (llvm::Error Err = FusedNDR.takeError()) {
return errorToFusionResult(std::move(Err), "Illegal ND-range combination");
return errorTo<JITResult>(std::move(Err), "Illegal ND-range combination");
}

if (!isTargetFormatSupported(TargetFormat)) {
Expand Down Expand Up @@ -180,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
ModuleInfo.kernels());
if (auto Error = ModOrError.takeError()) {
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
return errorTo<JITResult>(std::move(Error), "SPIR-V translation failed");
}
std::unique_ptr<llvm::Module> LLVMMod = std::move(*ModOrError);

Expand All @@ -197,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
helper::FusionHelper::addFusedKernel(LLVMMod.get(), FusedKernelList);
if (auto Error = NewModOrError.takeError()) {
return errorToFusionResult(std::move(Error),
"Insertion of fused kernel stub failed");
return errorTo<JITResult>(std::move(Error),
"Insertion of fused kernel stub failed");
}
std::unique_ptr<llvm::Module> NewMod = std::move(*NewModOrError);

Expand All @@ -221,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,

if (auto Error = translation::KernelTranslator::translateKernel(
FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
return errorToFusionResult(std::move(Error),
"Translation to output format failed");
return errorTo<JITResult>(std::move(Error),
"Translation to output format failed");
}

FusedKernelInfo.NDR = FusedNDR->getNDR();
Expand All @@ -234,37 +234,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
return JITResult{FusedKernelInfo};
}

extern "C" KF_EXPORT_SYMBOL JITResult
extern "C" KF_EXPORT_SYMBOL RTCResult
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
View<const char *> UserArgs) {
auto UserArgListOrErr = parseUserArgs(UserArgs);
if (!UserArgListOrErr) {
return errorToFusionResult(UserArgListOrErr.takeError(),
"Parsing of user arguments failed");
return errorTo<RTCResult>(UserArgListOrErr.takeError(),
"Parsing of user arguments failed");
}
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);

auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
if (!ModuleOrErr) {
return errorToFusionResult(ModuleOrErr.takeError(),
"Device compilation failed");
return errorTo<RTCResult>(ModuleOrErr.takeError(),
"Device compilation failed");
}

std::unique_ptr<llvm::LLVMContext> Context;
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
Context.reset(&Module->getContext());

if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
return errorToFusionResult(std::move(Error), "Device linking failed");
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
}

SYCLKernelInfo Kernel;
if (auto Error = translation::KernelTranslator::translateKernel(
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
if (!BundleInfoOrError) {
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
"Post-link phase failed");
}
auto BundleInfo = std::move(*BundleInfoOrError);

auto BinaryInfoOrError =
translation::KernelTranslator::translateBundleToSPIRV(
*Module, JITContext::getInstance());
if (!BinaryInfoOrError) {
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
"SPIR-V translation failed");
}
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);

return JITResult{Kernel};
return RTCResult{std::move(BundleInfo)};
}

extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {
Expand Down
Loading
Loading