Skip to content

Commit 4d463b7

Browse files
committed
Merge branch 'sycl' into aaron/updatePIDocs
2 parents c1ba534 + 3c274a8 commit 4d463b7

File tree

53 files changed

+461
-390
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+461
-390
lines changed

sycl-jit/common/include/Kernel.h

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <cstdint>
1818
#include <cstring>
1919
#include <functional>
20+
#include <string_view>
2021
#include <type_traits>
2122

2223
namespace jit_compiler {
@@ -350,11 +351,60 @@ struct SYCLKernelInfo {
350351
: Name{KernelName}, Args{NumArgs}, Attributes{}, NDR{}, BinaryInfo{} {}
351352
};
352353

354+
// RTC-related datastructures
355+
// TODO: Consider moving into separate header.
356+
353357
struct InMemoryFile {
354358
const char *Path;
355359
const char *Contents;
356360
};
357361

362+
using RTCBundleBinaryInfo = SYCLKernelBinaryInfo;
363+
using FrozenSymbolTable = DynArray<sycl::detail::string>;
364+
365+
// Note: `FrozenPropertyValue` and `FrozenPropertySet` constructors take
366+
// `std::string_view` arguments instead of `const char *` because they will be
367+
// created from `llvm::SmallString`s, which don't contain the trailing '\0'
368+
// byte. Hence obtaining a C-string would cause an additional copy.
369+
370+
struct FrozenPropertyValue {
371+
sycl::detail::string Name;
372+
bool IsUIntValue;
373+
uint32_t UIntValue;
374+
DynArray<uint8_t> Bytes;
375+
376+
FrozenPropertyValue() = default;
377+
FrozenPropertyValue(FrozenPropertyValue &&) = default;
378+
FrozenPropertyValue &operator=(FrozenPropertyValue &&) = default;
379+
380+
FrozenPropertyValue(std::string_view Name, uint32_t Value)
381+
: Name{Name}, IsUIntValue{true}, UIntValue{Value}, Bytes{0} {}
382+
FrozenPropertyValue(std::string_view Name, const uint8_t *Ptr, size_t Size)
383+
: Name{Name}, IsUIntValue{false}, Bytes{Size} {
384+
std::memcpy(Bytes.begin(), Ptr, Size);
385+
}
386+
};
387+
388+
struct FrozenPropertySet {
389+
sycl::detail::string Name;
390+
DynArray<FrozenPropertyValue> Values;
391+
392+
FrozenPropertySet() = default;
393+
FrozenPropertySet(FrozenPropertySet &&) = default;
394+
FrozenPropertySet &operator=(FrozenPropertySet &&) = default;
395+
396+
FrozenPropertySet(std::string_view Name, size_t Size)
397+
: Name{Name}, Values{Size} {}
398+
};
399+
400+
using FrozenPropertyRegistry = DynArray<FrozenPropertySet>;
401+
402+
struct RTCBundleInfo {
403+
RTCBundleBinaryInfo BinaryInfo;
404+
FrozenSymbolTable SymbolTable;
405+
FrozenPropertyRegistry Properties;
406+
};
407+
358408
} // namespace jit_compiler
359409

360410
#endif // SYCL_FUSION_COMMON_KERNEL_H

sycl-jit/jit-compiler/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_llvm_library(sycl-jit
3131
Target
3232
TargetParser
3333
MC
34+
SYCLLowerIR
3435
${LLVM_TARGETS_TO_BUILD}
3536

3637
LINK_LIBS

sycl-jit/jit-compiler/include/KernelFusion.h

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ class JITResult {
5656
sycl::detail::string ErrorMessage;
5757
};
5858

59+
class RTCResult {
60+
public:
61+
explicit RTCResult(const char *ErrorMessage)
62+
: Failed{true}, BundleInfo{}, ErrorMessage{ErrorMessage} {}
63+
64+
explicit RTCResult(RTCBundleInfo &&BundleInfo)
65+
: Failed{false}, BundleInfo{std::move(BundleInfo)}, ErrorMessage{} {}
66+
67+
bool failed() const { return Failed; }
68+
69+
const char *getErrorMessage() const {
70+
assert(failed() && "No error message present");
71+
return ErrorMessage.c_str();
72+
}
73+
74+
const RTCBundleInfo &getBundleInfo() const {
75+
assert(!failed() && "No bundle info");
76+
return BundleInfo;
77+
}
78+
79+
private:
80+
bool Failed;
81+
RTCBundleInfo BundleInfo;
82+
sycl::detail::string ErrorMessage;
83+
};
84+
5985
extern "C" {
6086

6187
#ifdef __clang__
@@ -77,7 +103,7 @@ KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
77103
const char *KernelName, jit_compiler::SYCLKernelBinaryInfo &BinInfo,
78104
View<unsigned char> SpecConstBlob);
79105

80-
KF_EXPORT_SYMBOL JITResult compileSYCL(InMemoryFile SourceFile,
106+
KF_EXPORT_SYMBOL RTCResult compileSYCL(InMemoryFile SourceFile,
81107
View<InMemoryFile> IncludeFiles,
82108
View<const char *> UserArgs);
83109

sycl-jit/jit-compiler/lib/KernelFusion.cpp

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ using namespace jit_compiler;
2525
using FusedFunction = helper::FusionHelper::FusedFunction;
2626
using FusedFunctionList = std::vector<FusedFunction>;
2727

28-
static JITResult errorToFusionResult(llvm::Error &&Err,
29-
const std::string &Msg) {
28+
template <typename ResultType>
29+
static ResultType errorTo(llvm::Error &&Err, const std::string &Msg) {
3030
std::stringstream ErrMsg;
3131
ErrMsg << Msg << "\nDetailed information:\n";
3232
llvm::handleAllErrors(std::move(Err),
@@ -35,7 +35,7 @@ static JITResult errorToFusionResult(llvm::Error &&Err,
3535
// compiled without exception support.
3636
ErrMsg << "\t" << StrErr.getMessage() << "\n";
3737
});
38-
return JITResult{ErrMsg.str().c_str()};
38+
return ResultType{ErrMsg.str().c_str()};
3939
}
4040

4141
static std::vector<jit_compiler::NDRange>
@@ -95,7 +95,7 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
9595
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
9696
ModuleInfo.kernels());
9797
if (auto Error = ModOrError.takeError()) {
98-
return errorToFusionResult(std::move(Error), "Failed to load kernels");
98+
return errorTo<JITResult>(std::move(Error), "Failed to load kernels");
9999
}
100100
std::unique_ptr<llvm::Module> NewMod = std::move(*ModOrError);
101101
if (!fusion::FusionPipeline::runMaterializerPasses(
@@ -107,8 +107,8 @@ extern "C" KF_EXPORT_SYMBOL JITResult materializeSpecConstants(
107107
SYCLKernelInfo &MaterializerKernelInfo = *ModuleInfo.getKernelFor(KernelName);
108108
if (auto Error = translation::KernelTranslator::translateKernel(
109109
MaterializerKernelInfo, *NewMod, JITCtx, TargetFormat)) {
110-
return errorToFusionResult(std::move(Error),
111-
"Translation to output format failed");
110+
return errorTo<JITResult>(std::move(Error),
111+
"Translation to output format failed");
112112
}
113113

114114
return JITResult{MaterializerKernelInfo};
@@ -133,7 +133,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
133133
llvm::Expected<jit_compiler::FusedNDRange> FusedNDR =
134134
jit_compiler::FusedNDRange::get(NDRanges);
135135
if (llvm::Error Err = FusedNDR.takeError()) {
136-
return errorToFusionResult(std::move(Err), "Illegal ND-range combination");
136+
return errorTo<JITResult>(std::move(Err), "Illegal ND-range combination");
137137
}
138138

139139
if (!isTargetFormatSupported(TargetFormat)) {
@@ -180,7 +180,7 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
180180
translation::KernelTranslator::loadKernels(*JITCtx.getLLVMContext(),
181181
ModuleInfo.kernels());
182182
if (auto Error = ModOrError.takeError()) {
183-
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
183+
return errorTo<JITResult>(std::move(Error), "SPIR-V translation failed");
184184
}
185185
std::unique_ptr<llvm::Module> LLVMMod = std::move(*ModOrError);
186186

@@ -197,8 +197,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
197197
llvm::Expected<std::unique_ptr<llvm::Module>> NewModOrError =
198198
helper::FusionHelper::addFusedKernel(LLVMMod.get(), FusedKernelList);
199199
if (auto Error = NewModOrError.takeError()) {
200-
return errorToFusionResult(std::move(Error),
201-
"Insertion of fused kernel stub failed");
200+
return errorTo<JITResult>(std::move(Error),
201+
"Insertion of fused kernel stub failed");
202202
}
203203
std::unique_ptr<llvm::Module> NewMod = std::move(*NewModOrError);
204204

@@ -221,8 +221,8 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
221221

222222
if (auto Error = translation::KernelTranslator::translateKernel(
223223
FusedKernelInfo, *NewMod, JITCtx, TargetFormat)) {
224-
return errorToFusionResult(std::move(Error),
225-
"Translation to output format failed");
224+
return errorTo<JITResult>(std::move(Error),
225+
"Translation to output format failed");
226226
}
227227

228228
FusedKernelInfo.NDR = FusedNDR->getNDR();
@@ -234,37 +234,47 @@ fuseKernels(View<SYCLKernelInfo> KernelInformation, const char *FusedKernelName,
234234
return JITResult{FusedKernelInfo};
235235
}
236236

237-
extern "C" KF_EXPORT_SYMBOL JITResult
237+
extern "C" KF_EXPORT_SYMBOL RTCResult
238238
compileSYCL(InMemoryFile SourceFile, View<InMemoryFile> IncludeFiles,
239239
View<const char *> UserArgs) {
240240
auto UserArgListOrErr = parseUserArgs(UserArgs);
241241
if (!UserArgListOrErr) {
242-
return errorToFusionResult(UserArgListOrErr.takeError(),
243-
"Parsing of user arguments failed");
242+
return errorTo<RTCResult>(UserArgListOrErr.takeError(),
243+
"Parsing of user arguments failed");
244244
}
245245
llvm::opt::InputArgList UserArgList = std::move(*UserArgListOrErr);
246246

247247
auto ModuleOrErr = compileDeviceCode(SourceFile, IncludeFiles, UserArgList);
248248
if (!ModuleOrErr) {
249-
return errorToFusionResult(ModuleOrErr.takeError(),
250-
"Device compilation failed");
249+
return errorTo<RTCResult>(ModuleOrErr.takeError(),
250+
"Device compilation failed");
251251
}
252252

253253
std::unique_ptr<llvm::LLVMContext> Context;
254254
std::unique_ptr<llvm::Module> Module = std::move(*ModuleOrErr);
255255
Context.reset(&Module->getContext());
256256

257257
if (auto Error = linkDeviceLibraries(*Module, UserArgList)) {
258-
return errorToFusionResult(std::move(Error), "Device linking failed");
258+
return errorTo<RTCResult>(std::move(Error), "Device linking failed");
259259
}
260260

261-
SYCLKernelInfo Kernel;
262-
if (auto Error = translation::KernelTranslator::translateKernel(
263-
Kernel, *Module, JITContext::getInstance(), BinaryFormat::SPIRV)) {
264-
return errorToFusionResult(std::move(Error), "SPIR-V translation failed");
261+
auto BundleInfoOrError = performPostLink(*Module, UserArgList);
262+
if (!BundleInfoOrError) {
263+
return errorTo<RTCResult>(BundleInfoOrError.takeError(),
264+
"Post-link phase failed");
265+
}
266+
auto BundleInfo = std::move(*BundleInfoOrError);
267+
268+
auto BinaryInfoOrError =
269+
translation::KernelTranslator::translateBundleToSPIRV(
270+
*Module, JITContext::getInstance());
271+
if (!BinaryInfoOrError) {
272+
return errorTo<RTCResult>(BinaryInfoOrError.takeError(),
273+
"SPIR-V translation failed");
265274
}
275+
BundleInfo.BinaryInfo = std::move(*BinaryInfoOrError);
266276

267-
return JITResult{Kernel};
277+
return RTCResult{std::move(BundleInfo)};
268278
}
269279

270280
extern "C" KF_EXPORT_SYMBOL void resetJITConfiguration() {

sycl-jit/jit-compiler/lib/rtc/DeviceCompilation.cpp

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,25 @@
1818
#include <clang/Tooling/CompilationDatabase.h>
1919
#include <clang/Tooling/Tooling.h>
2020

21+
#include <llvm/IR/PassInstrumentation.h>
22+
#include <llvm/IR/PassManager.h>
2123
#include <llvm/IRReader/IRReader.h>
2224
#include <llvm/Linker/Linker.h>
23-
24-
#include <array>
25+
#include <llvm/SYCLLowerIR/ComputeModuleRuntimeInfo.h>
26+
#include <llvm/SYCLLowerIR/ModuleSplitter.h>
27+
#include <llvm/SYCLLowerIR/SYCLJointMatrixTransform.h>
28+
#include <llvm/Support/PropertySetIO.h>
2529

2630
using namespace clang;
2731
using namespace clang::tooling;
2832
using namespace clang::driver;
2933
using namespace clang::driver::options;
3034
using namespace llvm;
3135
using namespace llvm::opt;
36+
using namespace llvm::sycl;
37+
using namespace llvm::module_split;
38+
using namespace llvm::util;
39+
using namespace jit_compiler;
3240

3341
#ifdef _GNU_SOURCE
3442
#include <dlfcn.h>
@@ -356,6 +364,94 @@ Error jit_compiler::linkDeviceLibraries(llvm::Module &Module,
356364
return Error::success();
357365
}
358366

367+
template <class PassClass> static bool runModulePass(llvm::Module &M) {
368+
ModulePassManager MPM;
369+
ModuleAnalysisManager MAM;
370+
// Register required analysis
371+
MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
372+
MPM.addPass(PassClass{});
373+
PreservedAnalyses Res = MPM.run(M, MAM);
374+
return !Res.areAllPreserved();
375+
}
376+
377+
Expected<RTCBundleInfo> jit_compiler::performPostLink(
378+
llvm::Module &Module, [[maybe_unused]] const InputArgList &UserArgList) {
379+
// This is a simplified version of `processInputModule` in
380+
// `llvm/tools/sycl-post-link.cpp`. Assertions/TODOs point to functionality
381+
// left out of the algorithm for now.
382+
383+
assert(!Module.getGlobalVariable("llvm.used") &&
384+
!Module.getGlobalVariable("llvm.compiler.used"));
385+
// Otherwise: Port over the `removeSYCLKernelsConstRefArray` and
386+
// `removeDeviceGlobalFromCompilerUsed` methods.
387+
388+
assert(!isModuleUsingAsan(Module));
389+
// Otherwise: Need to instrument each image scope device globals if the module
390+
// has been instrumented by sanitizer pass.
391+
392+
// Transform Joint Matrix builtin calls to align them with SPIR-V friendly
393+
// LLVM IR specification.
394+
runModulePass<SYCLJointMatrixTransformPass>(Module);
395+
396+
// TODO: Implement actual device code splitting. We're just using the splitter
397+
// to obtain additional information about the module for now.
398+
// TODO: EmitOnlyKernelsAsEntryPoints is controlled by
399+
// `shouldEmitOnlyKernelsAsEntryPoints` in
400+
// `clang/lib/Driver/ToolChains/Clang.cpp`.
401+
std::unique_ptr<ModuleSplitterBase> Splitter = getDeviceCodeSplitter(
402+
ModuleDesc{std::unique_ptr<llvm::Module>{&Module}}, SPLIT_NONE,
403+
/*IROutputOnly=*/false,
404+
/*EmitOnlyKernelsAsEntryPoints=*/true);
405+
assert(Splitter->remainingSplits() == 1);
406+
407+
// TODO: Call `verifyNoCrossModuleDeviceGlobalUsage` if device globals shall
408+
// be processed.
409+
410+
assert(Splitter->hasMoreSplits());
411+
ModuleDesc MDesc = Splitter->nextSplit();
412+
assert(&Module == &MDesc.getModule());
413+
MDesc.saveSplitInformationAsMetadata();
414+
415+
RTCBundleInfo BundleInfo;
416+
BundleInfo.SymbolTable = FrozenSymbolTable{MDesc.entries().size()};
417+
transform(MDesc.entries(), BundleInfo.SymbolTable.begin(),
418+
[](Function *F) { return F->getName(); });
419+
420+
// TODO: Determine what is requested.
421+
GlobalBinImageProps PropReq{
422+
/*EmitKernelParamInfo=*/true, /*EmitProgramMetadata=*/true,
423+
/*EmitExportedSymbols=*/true, /*EmitImportedSymbols=*/true,
424+
/*DeviceGlobals=*/false};
425+
PropertySetRegistry Properties =
426+
computeModuleProperties(MDesc.getModule(), MDesc.entries(), PropReq);
427+
// TODO: Manually add `compile_target` property as in
428+
// `saveModuleProperties`?
429+
const auto &PropertySets = Properties.getPropSets();
430+
431+
BundleInfo.Properties = FrozenPropertyRegistry{PropertySets.size()};
432+
for (auto &&[KV, FrozenPropSet] : zip(PropertySets, BundleInfo.Properties)) {
433+
const auto &PropertySetName = KV.first;
434+
const auto &PropertySet = KV.second;
435+
FrozenPropSet =
436+
FrozenPropertySet{PropertySetName.str(), PropertySet.size()};
437+
for (auto &&[KV2, FrozenProp] : zip(PropertySet, FrozenPropSet.Values)) {
438+
const auto &PropertyName = KV2.first;
439+
const auto &PropertyValue = KV2.second;
440+
FrozenProp = PropertyValue.getType() == PropertyValue::Type::UINT32
441+
? FrozenPropertyValue{PropertyName.str(),
442+
PropertyValue.asUint32()}
443+
: FrozenPropertyValue{
444+
PropertyName.str(), PropertyValue.asRawByteArray(),
445+
PropertyValue.getRawByteArraySize()};
446+
}
447+
};
448+
449+
// Regain ownership of the module.
450+
MDesc.releaseModulePtr().release();
451+
452+
return std::move(BundleInfo);
453+
}
454+
359455
Expected<InputArgList>
360456
jit_compiler::parseUserArgs(View<const char *> UserArgs) {
361457
unsigned MissingArgIndex, MissingArgCount;
@@ -410,5 +506,17 @@ jit_compiler::parseUserArgs(View<const char *> UserArgs) {
410506
}
411507
}
412508

413-
return Expected<InputArgList>{std::move(AL)};
509+
if (auto DCSMode = AL.getLastArgValue(OPT_fsycl_device_code_split_EQ, "none");
510+
DCSMode != "none" && DCSMode != "auto") {
511+
return createStringError("Device code splitting is not yet supported");
512+
}
513+
514+
if (AL.hasArg(OPT_fsycl_device_code_split_esimd,
515+
OPT_fno_sycl_device_code_split_esimd)) {
516+
// TODO: There are more ESIMD-related options.
517+
return createStringError(
518+
"Runtime compilation of ESIMD kernels is not yet supported");
519+
}
520+
521+
return std::move(AL);
414522
}

0 commit comments

Comments
 (0)